From 7f44236b9ac2c8c193a967529635fd7aa502c0a0 Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Wed, 4 Mar 2026 11:33:20 +0800 Subject: [PATCH 01/34] rm hnsw_builder --- src/core/algorithm/hnsw/hnsw_builder.cc | 407 --- src/core/algorithm/hnsw/hnsw_builder.h | 93 - .../algorithm/hnsw/hnsw_builder_entity.cc | 198 -- src/core/algorithm/hnsw/hnsw_builder_entity.h | 138 - .../algorithm/hnsw/hnsw_searcher_entity.h | 1 - .../core/algorithm/hnsw/hnsw_builder_test.cc | 543 ---- .../algorithm/hnsw/hnsw_searcher_test.cpp | 2775 ----------------- 7 files changed, 4155 deletions(-) delete mode 100644 src/core/algorithm/hnsw/hnsw_builder.cc delete mode 100644 src/core/algorithm/hnsw/hnsw_builder.h delete mode 100644 src/core/algorithm/hnsw/hnsw_builder_entity.cc delete mode 100644 src/core/algorithm/hnsw/hnsw_builder_entity.h delete mode 100644 tests/core/algorithm/hnsw/hnsw_builder_test.cc delete mode 100644 tests/core/algorithm/hnsw/hnsw_searcher_test.cpp diff --git a/src/core/algorithm/hnsw/hnsw_builder.cc b/src/core/algorithm/hnsw/hnsw_builder.cc deleted file mode 100644 index da2d9faf..00000000 --- a/src/core/algorithm/hnsw/hnsw_builder.cc +++ /dev/null @@ -1,407 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "hnsw_builder.h" -#include -#include -#include -#include -#include -#include -#include "hnsw_algorithm.h" -#include "hnsw_params.h" - -namespace zvec { -namespace core { - -HnswBuilder::HnswBuilder() = default; - -int HnswBuilder::init(const IndexMeta &meta, const ailego::Params ¶ms) { - LOG_INFO("Begin HnswBuilder::init"); - - meta_ = meta; - auto params_copy = params; - meta_.set_builder("HnswBuilder", HnswEntity::kRevision, - std::move(params_copy)); - - size_t memory_quota = 0UL; - params.get(PARAM_HNSW_BUILDER_MEMORY_QUOTA, &memory_quota); - params.get(PARAM_HNSW_BUILDER_THREAD_COUNT, &thread_cnt_); - params.get(PARAM_HNSW_BUILDER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_); - params.get(PARAM_HNSW_BUILDER_EFCONSTRUCTION, &ef_construction_); - params.get(PARAM_HNSW_BUILDER_CHECK_INTERVAL_SECS, &check_interval_secs_); - - params.get(PARAM_HNSW_BUILDER_MAX_NEIGHBOR_COUNT, &upper_max_neighbor_cnt_); - float multiplier = HnswEntity::kDefaultL0MaxNeighborCntMultiplier; - params.get(PARAM_HNSW_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER, &multiplier); - l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_; - scaling_factor_ = upper_max_neighbor_cnt_; - params.get(PARAM_HNSW_BUILDER_SCALING_FACTOR, &scaling_factor_); - - multiplier = HnswEntity::kDefaultNeighborPruneMultiplier; - params.get(PARAM_HNSW_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier); - size_t prune_cnt = multiplier * upper_max_neighbor_cnt_; - - if (ef_construction_ == 0) { - ef_construction_ = HnswEntity::kDefaultEfConstruction; - } - if (upper_max_neighbor_cnt_ == 0) { - upper_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt; - } - if (upper_max_neighbor_cnt_ > kMaxNeighborCnt) { - LOG_ERROR("[%s] must be in range (0,%d]", - PARAM_HNSW_BUILDER_MAX_NEIGHBOR_COUNT.c_str(), kMaxNeighborCnt); - return IndexError_InvalidArgument; - } - if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) { - LOG_ERROR("[%s]-[%d] must be <= [%s]-[%d]", - PARAM_HNSW_BUILDER_MIN_NEIGHBOR_COUNT.c_str(), min_neighbor_cnt_, - PARAM_HNSW_BUILDER_MAX_NEIGHBOR_COUNT.c_str(), - upper_max_neighbor_cnt_); - return IndexError_InvalidArgument; - } - if (l0_max_neighbor_cnt_ == 0) { - l0_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt; - } - if (l0_max_neighbor_cnt_ > HnswEntity::kMaxNeighborCnt) { - LOG_ERROR("L0MaxNeighborCnt must be in range (0,%d)", - HnswEntity::kMaxNeighborCnt); - return IndexError_InvalidArgument; - } - if (scaling_factor_ == 0U) { - scaling_factor_ = HnswEntity::kDefaultScalingFactor; - } - if (scaling_factor_ < 5 || scaling_factor_ > 1000) { - LOG_ERROR("[%s] must be in range [5,1000]", - PARAM_HNSW_BUILDER_SCALING_FACTOR.c_str()); - return IndexError_InvalidArgument; - } - if (thread_cnt_ == 0) { - thread_cnt_ = std::thread::hardware_concurrency(); - } - if (thread_cnt_ > std::thread::hardware_concurrency()) { - LOG_WARN("[%s] greater than cpu cores %u", - PARAM_HNSW_BUILDER_THREAD_COUNT.c_str(), - std::thread::hardware_concurrency()); - } - if (prune_cnt == 0UL) { - prune_cnt = upper_max_neighbor_cnt_; - } - - metric_ = IndexFactory::CreateMetric(meta_.metric_name()); - if (!metric_) { - LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str()); - return IndexError_NoExist; - } - int ret = metric_->init(meta_, meta_.metric_params()); - if (ret != 0) { - LOG_ERROR("IndexMetric init failed, ret=%d", ret); - return ret; - } - - entity_.set_vector_size(meta_.element_size()); - - entity_.set_ef_construction(ef_construction_); - entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_); - entity_.set_min_neighbor_cnt(min_neighbor_cnt_); - entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_); - entity_.set_scaling_factor(scaling_factor_); - entity_.set_memory_quota(memory_quota); - entity_.set_prune_cnt(prune_cnt); - - ret = entity_.init(); - if (ret != 0) { - return ret; - } - - alg_ = HnswAlgorithm::UPointer(new HnswAlgorithm(entity_)); - - ret = alg_->init(); - if (ret != 0) { - return ret; - } - - state_ = BUILD_STATE_INITED; - LOG_INFO( - "End HnswBuilder::init, params: vectorSize=%u efConstruction=%u " - "l0NeighborCnt=%u upperNeighborCnt=%u scalingFactor=%u " - "memoryQuota=%zu neighborPruneCnt=%zu metricName=%s ", - meta_.element_size(), ef_construction_, l0_max_neighbor_cnt_, - upper_max_neighbor_cnt_, scaling_factor_, memory_quota, prune_cnt, - meta_.metric_name().c_str()); - - return 0; -} - -int HnswBuilder::cleanup(void) { - LOG_INFO("Begin HnswBuilder::cleanup"); - - l0_max_neighbor_cnt_ = HnswEntity::kDefaultL0MaxNeighborCnt; - min_neighbor_cnt_ = 0; - upper_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt; - ef_construction_ = HnswEntity::kDefaultEfConstruction; - scaling_factor_ = HnswEntity::kDefaultScalingFactor; - check_interval_secs_ = kDefaultLogIntervalSecs; - errcode_ = 0; - error_ = false; - entity_.cleanup(); - alg_->cleanup(); - meta_.clear(); - metric_.reset(); - stats_.clear_attributes(); - stats_.set_trained_count(0UL); - stats_.set_built_count(0UL); - stats_.set_dumped_count(0UL); - stats_.set_discarded_count(0UL); - stats_.set_trained_costtime(0UL); - stats_.set_built_costtime(0UL); - stats_.set_dumped_costtime(0UL); - state_ = BUILD_STATE_INIT; - - LOG_INFO("End HnswBuilder::cleanup"); - - return 0; -} - -int HnswBuilder::train(IndexThreads::Pointer, IndexHolder::Pointer holder) { - if (state_ != BUILD_STATE_INITED) { - LOG_ERROR("Init the builder before HnswBuilder::train"); - return IndexError_NoReady; - } - - if (!holder) { - LOG_ERROR("Input holder is nullptr while training index"); - return IndexError_InvalidArgument; - } - if (!holder->is_matched(meta_)) { - LOG_ERROR("Input holder doesn't match index meta while training index"); - return IndexError_Mismatch; - } - LOG_INFO("Begin HnswBuilder::train"); - size_t trained_cost_time = 0; - size_t trained_count = 0; - - if (metric_->support_train()) { - auto start_time = ailego::Monotime::MilliSeconds(); - auto iter = holder->create_iterator(); - if (!iter) { - LOG_ERROR("Create iterator for holder failed"); - return IndexError_Runtime; - } - while (iter->is_valid()) { - int ret = metric_->train(iter->data(), meta_.dimension()); - if (ailego_unlikely(ret != 0)) { - LOG_ERROR("Hnsw build measure train failed, ret=%d", ret); - return ret; - } - iter->next(); - ++trained_count; - } - trained_cost_time = ailego::Monotime::MilliSeconds() - start_time; - } - stats_.set_trained_count(trained_count); - stats_.set_trained_costtime(trained_cost_time); - state_ = BUILD_STATE_TRAINED; - - LOG_INFO("End HnswBuilder::train"); - - return 0; -} - -int HnswBuilder::train(const IndexTrainer::Pointer & /*trainer*/) { - if (state_ != BUILD_STATE_INITED) { - LOG_ERROR("Init the builder before HnswBuilder::train"); - return IndexError_NoReady; - } - - LOG_INFO("Begin HnswBuilder::train by trainer"); - - stats_.set_trained_count(0UL); - stats_.set_trained_costtime(0UL); - state_ = BUILD_STATE_TRAINED; - - LOG_INFO("End HnswBuilder::train by trainer"); - - return 0; -} - -int HnswBuilder::build(IndexThreads::Pointer threads, - IndexHolder::Pointer holder) { - if (state_ != BUILD_STATE_TRAINED) { - LOG_ERROR("Train the index before HnswBuilder::build"); - return IndexError_NoReady; - } - - if (!holder) { - LOG_ERROR("Input holder is nullptr while building index"); - return IndexError_InvalidArgument; - } - if (!holder->is_matched(meta_)) { - LOG_ERROR("Input holder doesn't match index meta while building index"); - return IndexError_Mismatch; - } - if (!threads) { - threads = std::make_shared(thread_cnt_, false); - if (!threads) { - return IndexError_NoMemory; - } - } - - auto start_time = ailego::Monotime::MilliSeconds(); - LOG_INFO("Begin HnswBuilder::build"); - - if (holder->count() != static_cast(-1)) { - LOG_DEBUG("HnswBuilder holder documents count %lu", holder->count()); - int ret = entity_.reserve_space(holder->count()); - if (ret != 0) { - LOG_ERROR("HnswBuilde reserver space failed"); - return ret; - } - } - auto iter = holder->create_iterator(); - if (!iter) { - LOG_ERROR("Create iterator for holder failed"); - return IndexError_Runtime; - } - int ret; - error_ = false; - while (iter->is_valid()) { - level_t level = alg_->get_random_level(); - node_id_t id; - - const void *vec = iter->data(); - ret = entity_.add_vector(level, iter->key(), vec, &id); - if (ailego_unlikely(ret != 0)) { - return ret; - } - iter->next(); - } - // Holder is not needed, cleanup it. - holder.reset(); - - LOG_INFO("Finished save vector, start build graph..."); - - auto task_group = threads->make_group(); - if (!task_group) { - LOG_ERROR("Failed to create task group"); - return IndexError_Runtime; - } - - std::atomic finished{0}; - for (size_t i = 0; i < threads->count(); ++i) { - task_group->submit(ailego::Closure ::New(this, &HnswBuilder::do_build, i, - threads->count(), &finished)); - } - - while (!task_group->is_finished()) { - std::unique_lock lk(mutex_); - cond_.wait_until(lk, std::chrono::system_clock::now() + - std::chrono::seconds(check_interval_secs_)); - if (error_.load(std::memory_order_acquire)) { - LOG_ERROR("Failed to build index while waiting finish"); - return errcode_; - } - LOG_INFO("Built cnt %u, finished percent %.3f%%", finished.load(), - finished.load() * 100.0f / entity_.doc_cnt()); - } - if (error_.load(std::memory_order_acquire)) { - LOG_ERROR("Failed to build index while waiting finish"); - return errcode_; - } - task_group->wait_finish(); - - stats_.set_built_count(finished.load()); - stats_.set_built_costtime(ailego::Monotime::MilliSeconds() - start_time); - state_ = BUILD_STATE_BUILT; - - LOG_INFO("End HnswBuilder::build"); - return 0; -} - -void HnswBuilder::do_build(node_id_t idx, size_t step_size, - std::atomic *finished) { - AILEGO_DEFER([&]() { - std::lock_guard latch(mutex_); - cond_.notify_one(); - }); - HnswContext *ctx = new (std::nothrow) - HnswContext(meta_.dimension(), metric_, - std::shared_ptr(&entity_, [](HnswEntity *) {})); - if (ailego_unlikely(ctx == nullptr)) { - if (!error_.exchange(true)) { - LOG_ERROR("Failed to create context"); - errcode_ = IndexError_NoMemory; - } - return; - } - HnswContext::Pointer auto_ptr(ctx); - ctx->set_max_scan_num(entity_.doc_cnt()); - int ret = ctx->init(HnswContext::kBuilderContext); - if (ret != 0) { - if (!error_.exchange(true)) { - LOG_ERROR("Failed to init context"); - errcode_ = IndexError_Runtime; - } - return; - } - - IndexQueryMeta qmeta(meta_.data_type(), meta_.dimension()); - for (node_id_t id = idx; id < entity_.doc_cnt(); id += step_size) { - ctx->reset_query(entity_.get_vector(id)); - ret = alg_->add_node(id, entity_.get_level(id), ctx); - if (ailego_unlikely(ret != 0)) { - if (!error_.exchange(true)) { - LOG_ERROR("Hnsw graph add node failed"); - errcode_ = ret; - } - return; - } - ctx->clear(); - (*finished)++; - } -} - -int HnswBuilder::dump(const IndexDumper::Pointer &dumper) { - if (state_ != BUILD_STATE_BUILT) { - LOG_INFO("Build the index before HnswBuilder::dump"); - return IndexError_NoReady; - } - - LOG_INFO("Begin HnswBuilder::dump"); - - meta_.set_searcher("HnswSearcher", HnswEntity::kRevision, ailego::Params()); - auto start_time = ailego::Monotime::MilliSeconds(); - - int ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); - if (ret != 0) { - LOG_ERROR("Failed to serialize meta into dumper."); - return ret; - } - - ret = entity_.dump(dumper); - if (ret != 0) { - LOG_ERROR("HnswBuilder dump index failed"); - return ret; - } - - stats_.set_dumped_count(entity_.doc_cnt()); - stats_.set_dumped_costtime(ailego::Monotime::MilliSeconds() - start_time); - - LOG_INFO("EndHnswBuilder::dump"); - return 0; -} - -INDEX_FACTORY_REGISTER_BUILDER(HnswBuilder); - -} // namespace core -} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_builder.h b/src/core/algorithm/hnsw/hnsw_builder.h deleted file mode 100644 index e1145283..00000000 --- a/src/core/algorithm/hnsw/hnsw_builder.h +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -#include -#include -#include "hnsw_algorithm.h" -#include "hnsw_builder_entity.h" - -namespace zvec { -namespace core { - -class HnswBuilder : public IndexBuilder { - public: - //! Constructor - HnswBuilder(); - - //! Initialize the builder - virtual int init(const IndexMeta &meta, - const ailego::Params ¶ms) override; - - //! Cleanup the builder - virtual int cleanup(void) override; - - //! Train the data - virtual int train(IndexThreads::Pointer, - IndexHolder::Pointer holder) override; - - //! Train the data - virtual int train(const IndexTrainer::Pointer &trainer) override; - - - //! Build the index - virtual int build(IndexThreads::Pointer threads, - IndexHolder::Pointer holder) override; - - //! Dump index into storage - virtual int dump(const IndexDumper::Pointer &dumper) override; - - //! Retrieve statistics - virtual const Stats &stats(void) const override { - return stats_; - } - - private: - void do_build(node_id_t idx, size_t step_size, - std::atomic *finished); - - constexpr static uint32_t kDefaultLogIntervalSecs = 15U; - constexpr static uint32_t kMaxNeighborCnt = 65535; - - private: - enum BUILD_STATE { - BUILD_STATE_INIT = 0, - BUILD_STATE_INITED = 1, - BUILD_STATE_TRAINED = 2, - BUILD_STATE_BUILT = 3 - }; - - HnswBuilderEntity entity_{}; - HnswAlgorithm::UPointer alg_; // impl graph algorithm - uint32_t thread_cnt_{0}; - uint32_t min_neighbor_cnt_{0}; - uint32_t upper_max_neighbor_cnt_{HnswEntity::kDefaultUpperMaxNeighborCnt}; - uint32_t l0_max_neighbor_cnt_{HnswEntity::kDefaultL0MaxNeighborCnt}; - uint32_t ef_construction_{HnswEntity::kDefaultEfConstruction}; - uint32_t scaling_factor_{HnswEntity::kDefaultScalingFactor}; - uint32_t check_interval_secs_{kDefaultLogIntervalSecs}; - - int errcode_{0}; - std::atomic_bool error_{false}; - IndexMeta meta_{}; - IndexMetric::Pointer metric_{}; - std::mutex mutex_{}; - std::condition_variable cond_{}; - Stats stats_{}; - - BUILD_STATE state_{BUILD_STATE_INIT}; -}; - -} // namespace core -} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_builder_entity.cc b/src/core/algorithm/hnsw/hnsw_builder_entity.cc deleted file mode 100644 index 472d98ee..00000000 --- a/src/core/algorithm/hnsw/hnsw_builder_entity.cc +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "hnsw_builder_entity.h" -#include -#include -#include "utility/sparse_utility.h" - -namespace zvec { -namespace core { - -HnswBuilderEntity::HnswBuilderEntity() { - update_ep_and_level(kInvalidNodeId, 0U); -} - -int HnswBuilderEntity::cleanup() { - memory_quota_ = 0UL; - neighbors_size_ = 0U; - upper_neighbors_size_ = 0U; - padding_size_ = 0U; - vectors_buffer_.clear(); - keys_buffer_.clear(); - neighbors_buffer_.clear(); - upper_neighbors_buffer_.clear(); - neighbors_index_.clear(); - - vectors_buffer_.shrink_to_fit(); - keys_buffer_.shrink_to_fit(); - neighbors_buffer_.shrink_to_fit(); - upper_neighbors_buffer_.shrink_to_fit(); - neighbors_index_.shrink_to_fit(); - - this->HnswEntity::cleanup(); - - return 0; -} - -int HnswBuilderEntity::init() { - size_t size = vector_size(); - - //! aligned size to 32 - set_node_size(AlignSize(size)); - //! if node size is aligned to 1k, the build performance will downgrade - if (node_size() % 1024 == 0) { - set_node_size(AlignSize(node_size() + 1)); - } - - padding_size_ = node_size() - size; - - neighbors_size_ = neighbors_size(); - upper_neighbors_size_ = upper_neighbors_size(); - - return 0; -} - -int HnswBuilderEntity::reserve_space(size_t docs) { - if (memory_quota_ > 0 && (node_size() * docs + neighbors_size_ * docs + - sizeof(NeighborIndex) * docs > - memory_quota_)) { - return IndexError_NoMemory; - } - - vectors_buffer_.reserve(node_size() * docs); - keys_buffer_.reserve(sizeof(key_t) * docs); - neighbors_buffer_.reserve(neighbors_size_ * docs); - neighbors_index_.reserve(docs); - - return 0; -} - -int HnswBuilderEntity::add_vector(level_t level, key_t key, const void *vec, - node_id_t *id) { - if (memory_quota_ > 0 && - (vectors_buffer_.capacity() + keys_buffer_.capacity() + - neighbors_buffer_.capacity() + upper_neighbors_buffer_.capacity() + - neighbors_index_.capacity() * sizeof(NeighborIndex)) > memory_quota_) { - LOG_ERROR("Add vector failed, used memory exceed quota, cur_doc=%u", - doc_cnt()); - return IndexError_NoMemory; - } - - vectors_buffer_.append(reinterpret_cast(vec), vector_size()); - vectors_buffer_.append(padding_size_, '\0'); - keys_buffer_.append(reinterpret_cast(&key), sizeof(key)); - - // init level 0 neighbors - neighbors_buffer_.append(neighbors_size_, '\0'); - - neighbors_index_.emplace_back(upper_neighbors_buffer_.size(), level); - - // init upper layer neighbors - for (level_t cur_level = 1; cur_level <= level; ++cur_level) { - upper_neighbors_buffer_.append(upper_neighbors_size_, '\0'); - } - - *id = (*mutable_doc_cnt())++; - - return 0; -} - -key_t HnswBuilderEntity::get_key(node_id_t id) const { - return *(reinterpret_cast(keys_buffer_.data() + - id * sizeof(key_t))); -} - -const void *HnswBuilderEntity::get_vector(node_id_t id) const { - return vectors_buffer_.data() + id * node_size(); -} - -int HnswBuilderEntity::get_vector(const node_id_t id, - IndexStorage::MemoryBlock &block) const { - const void *vec = get_vector(id); - block.reset((void *)vec); - return 0; -} - -int HnswBuilderEntity::get_vector(const node_id_t *ids, uint32_t count, - const void **vecs) const { - for (uint32_t i = 0; i < count; ++i) { - vecs[i] = vectors_buffer_.data() + ids[i] * node_size(); - } - - return 0; -} - -int HnswBuilderEntity::get_vector( - const node_id_t *ids, uint32_t count, - std::vector &vec_blocks) const { - const void *vecs[count]; - get_vector(ids, count, vecs); - for (uint32_t i = 0; i < count; ++i) { - vec_blocks.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i])); - } - return 0; -} - -const Neighbors HnswBuilderEntity::get_neighbors(level_t level, - node_id_t id) const { - const NeighborsHeader *hd = get_neighbor_header(level, id); - return {hd->neighbor_cnt, hd->neighbors}; -} - -int HnswBuilderEntity::update_neighbors( - level_t level, node_id_t id, - const std::vector> &neighbors) { - NeighborsHeader *hd = - const_cast(get_neighbor_header(level, id)); - for (size_t i = 0; i < neighbors.size(); ++i) { - hd->neighbors[i] = neighbors[i].first; - } - hd->neighbor_cnt = neighbors.size(); - - // std::cout << "id: " << id << ", neighbour, id: "; - // for (size_t i = 0; i < neighbors.size(); ++i) { - // if (i == neighbors.size()-1) - // std::cout << neighbors[i].first << ", score:" << neighbors[i].second << - // std::endl; - // else - // std::cout << neighbors[i].first << ", score:" << neighbors[i].second << - // ", id: "; - // } - - return 0; -} - -void HnswBuilderEntity::add_neighbor(level_t level, node_id_t id, - uint32_t /*size*/, node_id_t neighbor_id) { - NeighborsHeader *hd = - const_cast(get_neighbor_header(level, id)); - hd->neighbors[hd->neighbor_cnt++] = neighbor_id; - - return; -} - -int HnswBuilderEntity::dump(const IndexDumper::Pointer &dumper) { - key_t *keys = - reinterpret_cast(const_cast(keys_buffer_.data())); - auto ret = - dump_segments(dumper, keys, [&](node_id_t id) { return get_level(id); }); - if (ailego_unlikely(ret < 0)) { - return ret; - } - - return 0; -} - -} // namespace core -} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_builder_entity.h b/src/core/algorithm/hnsw/hnsw_builder_entity.h deleted file mode 100644 index 1708e338..00000000 --- a/src/core/algorithm/hnsw/hnsw_builder_entity.h +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -#include -#include "hnsw_entity.h" - -namespace zvec { -namespace core { - -class HnswBuilderEntity : public HnswEntity { - public: - //! Add vector and key to hnsw entity, and local id will be saved to id - virtual int add_vector(level_t level, key_t key, const void *vec, - node_id_t *id) override; - - //! Get primary key of the node id - virtual key_t get_key(node_id_t id) const override; - - //! Get vector feature data by key - virtual const void *get_vector(node_id_t id) const override; - - //! Batch get vectors feature data by keys - virtual int get_vector(const node_id_t *ids, uint32_t count, - const void **vecs) const override; - - virtual int get_vector(const node_id_t id, - IndexStorage::MemoryBlock &block) const override; - virtual int get_vector( - const node_id_t *ids, uint32_t count, - std::vector &vec_blocks) const override; - - //! Get the node id's neighbors on graph level - const NeighborsHeader *get_neighbor_header(level_t level, - node_id_t id) const { - if (level == 0) { - return reinterpret_cast( - neighbors_buffer_.data() + neighbors_size_ * id); - } else { - size_t offset = neighbors_index_[id].offset; - return reinterpret_cast( - upper_neighbors_buffer_.data() + offset + - (level - 1) * upper_neighbors_size_); - } - } - - //! Get the node id's neighbors on graph level - virtual const Neighbors get_neighbors(level_t level, - node_id_t id) const override; - - //! Replace node id in level's neighbors - virtual int update_neighbors( - level_t level, node_id_t id, - const std::vector> &neighbors) override; - - //! add a neighbor to id in graph level - virtual void add_neighbor(level_t level, node_id_t id, uint32_t size, - node_id_t neighbor_id) override; - - //! Dump the hnsw graph to dumper - virtual int dump(const IndexDumper::Pointer &dumper) override; - - //! Cleanup the entity - virtual int cleanup(void) override; - - public: - //! Constructor - HnswBuilderEntity(); - - //! Get the node graph level by id - level_t get_level(node_id_t id) const { - return neighbors_index_[id].level; - } - - //! Init builerEntity - int init(); - - //! reserve buffer space for documents - //! @param docs number of documents - int reserve_space(size_t docs); - - //! Set memory quota params - inline void set_memory_quota(size_t memory_quota) { - memory_quota_ = memory_quota; - } - - //! Get neighbors size - inline size_t neighbors_size() const { - return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); - } - - //! Get upper neighbors size - inline size_t upper_neighbors_size() const { - return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); - } - - public: - HnswBuilderEntity(const HnswBuilderEntity &) = delete; - HnswBuilderEntity &operator=(const HnswBuilderEntity &) = delete; - - private: - friend class HnswSearcherEntity; - //! class internal used only - struct NeighborIndex { - NeighborIndex(size_t off, level_t l) : offset(off), level(l) {} - uint64_t offset : 48; - uint64_t level : 16; - }; - - std::string vectors_buffer_{}; // aligned vectors - std::string keys_buffer_{}; // aligned vectors - std::string neighbors_buffer_{}; // level 0 neighbors buffer - std::string upper_neighbors_buffer_{}; // upper layer neighbors buffer - - std::string sparse_data_buffer_{}; // aligned spase data buffer - size_t sparse_data_offset_{0}; // - - // upper layer offset + level in upper_neighbors_buffer_ - std::vector neighbors_index_{}; - size_t memory_quota_{0UL}; - size_t neighbors_size_{0U}; // level 0 neighbors size - size_t upper_neighbors_size_{0U}; // level 0 neighbors size - size_t padding_size_{}; // padding size for each vector element -}; - -} // namespace core -} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_searcher_entity.h b/src/core/algorithm/hnsw/hnsw_searcher_entity.h index 34200976..6fcd6b9b 100644 --- a/src/core/algorithm/hnsw/hnsw_searcher_entity.h +++ b/src/core/algorithm/hnsw/hnsw_searcher_entity.h @@ -13,7 +13,6 @@ // limitations under the License. #pragma once -#include "hnsw_builder_entity.h" #include "hnsw_entity.h" namespace zvec { diff --git a/tests/core/algorithm/hnsw/hnsw_builder_test.cc b/tests/core/algorithm/hnsw/hnsw_builder_test.cc deleted file mode 100644 index 402ae972..00000000 --- a/tests/core/algorithm/hnsw/hnsw_builder_test.cc +++ /dev/null @@ -1,543 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "hnsw_builder.h" -#include -#include -#include -#include -#include -#include -#include "zvec/core/framework/index_framework.h" - -#if defined(__GNUC__) || defined(__GNUG__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-result" -#endif - -using namespace std; -using namespace zvec::ailego; - -namespace zvec { -namespace core { - -constexpr size_t static dim = 16; - -class HnswBuilderTest : public testing::Test { - protected: - void SetUp(void); - void TearDown(void); - - static std::string _dir; - static shared_ptr _index_meta_ptr; -}; - -std::string HnswBuilderTest::_dir("hnswBuilderTest"); -shared_ptr HnswBuilderTest::_index_meta_ptr; - -void HnswBuilderTest::SetUp(void) { - _index_meta_ptr.reset(new (nothrow) - IndexMeta(IndexMeta::DataType::DT_FP32, dim)); - _index_meta_ptr->set_metric("SquaredEuclidean", 0, ailego::Params()); -} - -void HnswBuilderTest::TearDown(void) { - char cmdBuf[100]; - snprintf(cmdBuf, 100, "rm -rf %s", _dir.c_str()); - system(cmdBuf); -} - -TEST_F(HnswBuilderTest, TestGeneral) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - ailego::Params params; - // params.set("proxima.hnsw.builder.thread_count", 1); - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - - ASSERT_EQ(0, builder->train(holder)); - - ASSERT_EQ(0, builder->build(holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - - string path = _dir + "/TestGeneral"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - auto &stats = builder->stats(); - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt, stats.built_count()); - ASSERT_EQ(doc_cnt, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); - // ASSERT_GT(stats.dumped_costtime(), 0UL); - - // cleanup and rebuild - ASSERT_EQ(0, builder->cleanup()); - - auto holder2 = - make_shared>(dim); - size_t doc_cnt2 = 2000UL; - for (size_t i = 0; i < doc_cnt2; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder2->emplace(i, vec)); - } - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder2)); - ASSERT_EQ(0, builder->build(holder2)); - auto dumper2 = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper2, nullptr); - ASSERT_EQ(0, dumper2->create(path)); - ASSERT_EQ(0, builder->dump(dumper2)); - ASSERT_EQ(0, dumper2->close()); - - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt2, stats.built_count()); - ASSERT_EQ(doc_cnt2, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); -} - -TEST_F(HnswBuilderTest, TestMemquota) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - ailego::Params params; - params.set("proxima.hnsw.builder.memory_quota", 100000UL); - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(IndexError_NoMemory, builder->build(holder)); -} - -TEST_F(HnswBuilderTest, TestIndexThreads) { - IndexBuilder::Pointer builder1 = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder1, nullptr); - IndexBuilder::Pointer builder2 = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder2, nullptr); - - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - ailego::Params params; - std::srand(ailego::Realtime::MilliSeconds()); - auto threads = - std::make_shared(std::rand() % 4, false); - ASSERT_EQ(0, builder1->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder2->init(*_index_meta_ptr, params)); - - auto build_index1 = [&]() { - ASSERT_EQ(0, builder1->train(threads, holder)); - ASSERT_EQ(0, builder1->build(threads, holder)); - }; - auto build_index2 = [&]() { - ASSERT_EQ(0, builder2->train(threads, holder)); - ASSERT_EQ(0, builder2->build(threads, holder)); - }; - - auto t1 = std::async(std::launch::async, build_index1); - auto t2 = std::async(std::launch::async, build_index2); - t1.wait(); - t2.wait(); - - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - - string path = _dir + "/TestIndexThreads"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder1->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder2->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - auto &stats1 = builder1->stats(); - ASSERT_EQ(doc_cnt, stats1.built_count()); - auto &stats2 = builder2->stats(); - ASSERT_EQ(doc_cnt, stats2.built_count()); -} - -TEST_F(HnswBuilderTest, TestCosine) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineFp32Converter"); - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - converter->transform(holder); - - auto converted_holder = converter->result(); - - ailego::Params params; - // params.set("proxima.hnsw.builder.thread_count", 1); - ASSERT_EQ(0, builder->init(index_meta, params)); - - ASSERT_EQ(0, builder->train(converted_holder)); - - ASSERT_EQ(0, builder->build(converted_holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - - string path = _dir + "/TestCosine"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - auto &stats = builder->stats(); - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt, stats.built_count()); - ASSERT_EQ(doc_cnt, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); - // ASSERT_GT(stats.dumped_costtime(), 0UL); - - // cleanup and rebuild - ASSERT_EQ(0, builder->cleanup()); - - auto holder2 = - make_shared>(dim); - size_t doc_cnt2 = 2000UL; - for (size_t i = 0; i < doc_cnt2; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder2->emplace(i, vec)); - } - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder2)); - ASSERT_EQ(0, builder->build(holder2)); - auto dumper2 = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper2, nullptr); - ASSERT_EQ(0, dumper2->create(path)); - ASSERT_EQ(0, builder->dump(dumper2)); - ASSERT_EQ(0, dumper2->close()); - - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt2, stats.built_count()); - ASSERT_EQ(doc_cnt2, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); -} - -TEST_F(HnswBuilderTest, TestCosineFp16Converter) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineFp16Converter"); - - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - converter->transform(holder); - - auto converted_holder = converter->result(); - - ailego::Params params; - - // params.set("proxima.hnsw.builder.thread_count", 1); - ASSERT_EQ(0, builder->init(index_meta, params)); - - ASSERT_EQ(0, builder->train(converted_holder)); - - ASSERT_EQ(0, builder->build(converted_holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - - string path = _dir + "/TestCosineFp16Converter"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - auto &stats = builder->stats(); - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt, stats.built_count()); - ASSERT_EQ(doc_cnt, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); - // ASSERT_GT(stats.dumped_costtime(), 0UL); - - // cleanup and rebuild - ASSERT_EQ(0, builder->cleanup()); - - auto holder2 = - make_shared>(dim); - size_t doc_cnt2 = 2000UL; - for (size_t i = 0; i < doc_cnt2; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder2->emplace(i, vec)); - } - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder2)); - ASSERT_EQ(0, builder->build(holder2)); - auto dumper2 = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper2, nullptr); - ASSERT_EQ(0, dumper2->create(path)); - ASSERT_EQ(0, builder->dump(dumper2)); - ASSERT_EQ(0, dumper2->close()); - - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt2, stats.built_count()); - ASSERT_EQ(doc_cnt2, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); -} - -TEST_F(HnswBuilderTest, TestCosineInt8Converter) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineInt8Converter"); - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - converter->transform(holder); - - auto converted_holder = converter->result(); - - ailego::Params params; - // params.set("proxima.hnsw.builder.thread_count", 1); - ASSERT_EQ(0, builder->init(index_meta, params)); - - ASSERT_EQ(0, builder->train(converted_holder)); - - ASSERT_EQ(0, builder->build(converted_holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - - string path = _dir + "/TestCosineInt8Converter"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - auto &stats = builder->stats(); - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt, stats.built_count()); - ASSERT_EQ(doc_cnt, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); - // ASSERT_GT(stats.dumped_costtime(), 0UL); - - // cleanup and rebuild - ASSERT_EQ(0, builder->cleanup()); - - auto holder2 = - make_shared>(dim); - size_t doc_cnt2 = 2000UL; - for (size_t i = 0; i < doc_cnt2; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder2->emplace(i, vec)); - } - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder2)); - ASSERT_EQ(0, builder->build(holder2)); - auto dumper2 = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper2, nullptr); - ASSERT_EQ(0, dumper2->create(path)); - ASSERT_EQ(0, builder->dump(dumper2)); - ASSERT_EQ(0, dumper2->close()); - - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt2, stats.built_count()); - ASSERT_EQ(doc_cnt2, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); -} - -TEST_F(HnswBuilderTest, TestCosineInt4Converter) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineInt4Converter"); - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - converter->transform(holder); - - auto converted_holder = converter->result(); - - ailego::Params params; - // params.set("proxima.hnsw.builder.thread_count", 1); - ASSERT_EQ(0, builder->init(index_meta, params)); - - ASSERT_EQ(0, builder->train(converted_holder)); - - ASSERT_EQ(0, builder->build(converted_holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - - string path = _dir + "/TestCosineInt4Converter"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - auto &stats = builder->stats(); - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt, stats.built_count()); - ASSERT_EQ(doc_cnt, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); - // ASSERT_GT(stats.dumped_costtime(), 0UL); - - // cleanup and rebuild - ASSERT_EQ(0, builder->cleanup()); - - auto holder2 = - make_shared>(dim); - size_t doc_cnt2 = 2000UL; - for (size_t i = 0; i < doc_cnt2; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder2->emplace(i, vec)); - } - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder2)); - ASSERT_EQ(0, builder->build(holder2)); - auto dumper2 = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper2, nullptr); - ASSERT_EQ(0, dumper2->create(path)); - ASSERT_EQ(0, builder->dump(dumper2)); - ASSERT_EQ(0, dumper2->close()); - - ASSERT_EQ(0UL, stats.trained_count()); - ASSERT_EQ(doc_cnt2, stats.built_count()); - ASSERT_EQ(doc_cnt2, stats.dumped_count()); - ASSERT_EQ(0UL, stats.discarded_count()); - ASSERT_EQ(0UL, stats.trained_costtime()); - ASSERT_GT(stats.built_costtime(), 0UL); -} - -} // namespace core -} // namespace zvec - -#if defined(__GNUC__) || defined(__GNUG__) -#pragma GCC diagnostic pop -#endif \ No newline at end of file diff --git a/tests/core/algorithm/hnsw/hnsw_searcher_test.cpp b/tests/core/algorithm/hnsw/hnsw_searcher_test.cpp deleted file mode 100644 index d3a7004f..00000000 --- a/tests/core/algorithm/hnsw/hnsw_searcher_test.cpp +++ /dev/null @@ -1,2775 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include -#include -#include "zvec/core/framework/index_builder.h" -#include "zvec/core/framework/index_factory.h" -#include "zvec/core/framework/index_meta.h" -#include "hnsw_params.h" - -using namespace std; -using namespace testing; -using namespace zvec::ailego; - -#if defined(__GNUC__) || defined(__GNUG__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-result" -#endif - -namespace zvec { -namespace core { - -constexpr size_t static dim = 16; - -class HnswSearcherTest : public testing::Test { - protected: - void SetUp(void); - void TearDown(void); - - static std::string _dir; - static shared_ptr _index_meta_ptr; -}; - -std::string HnswSearcherTest::_dir("HnswSearcherTest/"); -shared_ptr HnswSearcherTest::_index_meta_ptr; - -void HnswSearcherTest::SetUp(void) { - _index_meta_ptr.reset(new (nothrow) - IndexMeta(IndexMeta::DataType::DT_FP32, dim)); - _index_meta_ptr->set_metric("SquaredEuclidean", 0, ailego::Params()); -} - -void HnswSearcherTest::TearDown(void) { - char cmdBuf[100]; - snprintf(cmdBuf, 100, "rm -rf %s", _dir.c_str()); - system(cmdBuf); -} - -TEST_F(HnswSearcherTest, TestRnnSearch) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - ASSERT_EQ(0, builder->init(*_index_meta_ptr, ailego::Params())); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestRnnSearch"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ASSERT_EQ(0, searcher->init(ailego::Params())); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - auto ctx = searcher->create_context(); - ASSERT_TRUE(!!ctx); - - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = 0.0; - } - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t topk = 50; - ctx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - auto &results = ctx->result(); - ASSERT_EQ(topk, results.size()); - - float radius = results[topk / 2].score(); - ctx->set_threshold(radius); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - ASSERT_GT(topk, results.size()); - for (size_t k = 0; k < results.size(); ++k) { - ASSERT_GE(radius, results[k].score()); - } - - // Test Reset Threshold - ctx->reset_threshold(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - ASSERT_EQ(topk, results.size()); - ASSERT_LT(radius, results[topk - 1].score()); -} - -TEST_F(HnswSearcherTest, TestRnnSearchInnerProduct) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - IndexMeta index_meta(IndexMeta::DataType::DT_FP32, dim); - index_meta.set_metric("InnerProduct", 0, ailego::Params()); - - ASSERT_EQ(0, builder->init(index_meta, ailego::Params())); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestRnnSearchInnerProduct"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ASSERT_EQ(0, searcher->init(ailego::Params())); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - auto ctx = searcher->create_context(); - ASSERT_TRUE(!!ctx); - - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = 1.0; - } - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t topk = 50; - ctx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - auto &results = ctx->result(); - ASSERT_EQ(topk, results.size()); - - float radius = -results[topk / 2].score(); - ctx->set_threshold(radius); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - ASSERT_GT(topk, results.size()); - for (size_t k = 0; k < results.size(); ++k) { - ASSERT_GE(radius, results[k].score()); - } - - // Test Reset Threshold - ctx->reset_threshold(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - ASSERT_EQ(topk, results.size()); - ASSERT_LT(-radius, results[topk - 1].score()); -} - -TEST_F(HnswSearcherTest, TestRnnSearchCosine) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - - std::random_device rd; - std::mt19937 gen(rd()); - - std::uniform_real_distribution dist(-1.0, 1.0); - - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = dist(gen); - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineFp32Converter"); - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - converter->transform(holder); - - auto converted_holder = converter->result(); - - ASSERT_EQ(0, builder->init(index_meta, ailego::Params())); - ASSERT_EQ(0, builder->train(converted_holder)); - ASSERT_EQ(0, builder->build(converted_holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestRnnSearchCosine"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ASSERT_EQ(0, searcher->init(ailego::Params())); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - auto ctx = searcher->create_context(); - ASSERT_TRUE(!!ctx); - - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = 1.0; - } - - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); - ASSERT_TRUE(reformer != nullptr); - - ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); - - std::string new_query; - IndexQueryMeta new_meta; - ASSERT_EQ(0, reformer->transform(vec.data(), qmeta, &new_query, &new_meta)); - - size_t topk = 50; - ctx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, ctx)); - auto &results = ctx->result(); - ASSERT_EQ(topk, results.size()); - - float radius = 0.5f; - ctx->set_threshold(radius); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, ctx)); - ASSERT_GT(topk, results.size()); - for (size_t k = 0; k < results.size(); ++k) { - ASSERT_GE(radius, results[k].score()); - } - - // Test Reset Threshold - ctx->reset_threshold(); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, ctx)); - ASSERT_EQ(topk, results.size()); - ASSERT_LT(radius, results[topk - 1].score()); -} - -TEST_F(HnswSearcherTest, TestRnnSearchMipsSquaredEuclidean) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_NE(streamer, nullptr); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 10); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 10); - params.set(PARAM_HNSW_STREAMER_EF, 5); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - - IndexMeta index_meta(IndexMeta::DataType::DT_FP32, dim); - index_meta.set_metric("MipsSquaredEuclidean", 0, ailego::Params()); - - ailego::Params stg_params; - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestStreamerDump.index", true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - size_t doc_cnt = 1000UL; - auto streamer_ctx = streamer->create_context(); - ASSERT_TRUE(!!streamer_ctx); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - - streamer->add_impl(i, vec.data(), qmeta, streamer_ctx); - } - - { - // Test Reset Threshold - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = 1.0; - } - - size_t topk = 50; - streamer_ctx->set_topk(topk); - ASSERT_EQ(0, streamer->search_impl(vec.data(), qmeta, streamer_ctx)); - auto &results = streamer_ctx->result(); - ASSERT_EQ(topk, results.size()); - - float radius = -results[topk / 2].score(); - streamer_ctx->set_threshold(radius); - ASSERT_EQ(0, streamer->search_impl(vec.data(), qmeta, streamer_ctx)); - ASSERT_GT(topk, results.size()); - for (size_t k = 0; k < results.size(); ++k) { - ASSERT_GE(radius, results[k].score()); - } - - streamer_ctx->reset_threshold(); - ASSERT_EQ(0, streamer->search_impl(vec.data(), qmeta, streamer_ctx)); - ASSERT_EQ(topk, results.size()); - ASSERT_LT(-radius, results[topk - 1].score()); - } - - auto path = _dir + "/TestStreamerDump"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ASSERT_EQ(0, searcher->init(ailego::Params())); - - auto read_storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - auto searcher_ctx = searcher->create_context(); - ASSERT_TRUE(!!searcher_ctx); - - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = 1.0; - } - - { - size_t topk = 50; - searcher_ctx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, searcher_ctx)); - auto &results = searcher_ctx->result(); - ASSERT_EQ(topk, results.size()); - - float radius = -results[topk / 2].score(); - searcher_ctx->set_threshold(radius); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, searcher_ctx)); - ASSERT_GT(topk, results.size()); - for (size_t k = 0; k < results.size(); ++k) { - ASSERT_GE(radius, results[k].score()); - } - - // Test Reset Threshold - searcher_ctx->reset_threshold(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, searcher_ctx)); - ASSERT_EQ(topk, results.size()); - ASSERT_LT(-radius, results[topk - 1].score()); - } -} - -TEST_F(HnswSearcherTest, TestGeneral) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 5000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - ailego::Params params; - // params.set("proxima.hnsw.builder.max_neighbor_count", 16); - params.set("proxima.hnsw.builder.scaling_factor", 16); - params.set("proxima.hnsw.builder.ef_construction", 10); - params.set("proxima.hnsw.builder.thread_count", 2); - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestGeneral"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 1); - ASSERT_EQ(0, searcher->init(searcherParams)); - - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - auto linearCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - ASSERT_TRUE(!!linearCtx); - ASSERT_TRUE(!!linearByPKeysCtx); - ASSERT_TRUE(!!knnCtx); - NumericalVector vec(dim); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t topk = 200; - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - int totalHits = 0; - int totalCnts = 0; - int topk1Hits = 0; - linearCtx->set_topk(topk); - linearByPKeysCtx->set_topk(topk); - knnCtx->set_topk(topk); - - // do linear search test - { - std::vector query(dim); - for (size_t i = 0; i < dim; ++i) { - query[i] = 3.1f; - } - ASSERT_EQ(0, searcher->search_bf_impl(query.data(), qmeta, linearCtx)); - auto &linearResult = linearCtx->result(); - ASSERT_EQ(3UL, linearResult[0].key()); - ASSERT_EQ(4UL, linearResult[1].key()); - ASSERT_EQ(2UL, linearResult[2].key()); - ASSERT_EQ(5UL, linearResult[3].key()); - ASSERT_EQ(1UL, linearResult[4].key()); - ASSERT_EQ(6UL, linearResult[5].key()); - ASSERT_EQ(0UL, linearResult[6].key()); - ASSERT_EQ(7UL, linearResult[7].key()); - for (size_t i = 8; i < topk; ++i) { - ASSERT_EQ(i, linearResult[i].key()); - } - } - - // do linear search by p_keys test - std::vector> p_keys; - p_keys.resize(1); - p_keys[0] = {8, 9, 10, 11, 3, 2, 1, 0}; - { - std::vector query(dim); - for (size_t i = 0; i < dim; ++i) { - query[i] = 3.1f; - } - ASSERT_EQ(0, searcher->search_bf_by_p_keys_impl(query.data(), p_keys, qmeta, - linearByPKeysCtx)); - auto &linearByPKeysResult = linearByPKeysCtx->result(); - ASSERT_EQ(8, linearByPKeysResult.size()); - ASSERT_EQ(3UL, linearByPKeysResult[0].key()); - ASSERT_EQ(2UL, linearByPKeysResult[1].key()); - ASSERT_EQ(1UL, linearByPKeysResult[2].key()); - ASSERT_EQ(0UL, linearByPKeysResult[3].key()); - ASSERT_EQ(8UL, linearByPKeysResult[4].key()); - ASSERT_EQ(9UL, linearByPKeysResult[5].key()); - ASSERT_EQ(10UL, linearByPKeysResult[6].key()); - ASSERT_EQ(11UL, linearByPKeysResult[7].key()); - } - - size_t step = 50; - for (size_t i = 0; i < doc_cnt; i += step) { - for (size_t j = 0; j < dim; ++j) { - vec[j] = i + 0.1f; - } - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - // TODO: check - // ASSERT_EQ(topk, knnResult.size()); - topk1Hits += i == knnResult[0].key(); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - for (size_t k = 0; k < topk; ++k) { - totalCnts++; - for (size_t j = 0; j < topk; ++j) { - if (linearResult[j].key() == knnResult[k].key()) { - totalHits++; - break; - } - } - } - } - float recall = totalHits * step * step * 1.0f / totalCnts; - float topk1Recall = topk1Hits * step * 1.0f / doc_cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; -#if 0 - printf("knnTotalTime=%zd linearTotalTime=%zd totalHits=%d totalCnts=%d " - "R@%zd=%f R@1=%f cost=%f\n", - knnTotalTime, linearTotalTime, totalHits, totalCnts, topk, recall, - topk1Recall, cost); -#endif - EXPECT_GT(recall, 0.90f); - EXPECT_GT(topk1Recall, 0.90f); - // EXPECT_GT(cost, 2.0f); -} - -TEST_F(HnswSearcherTest, TestClearAndReload) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 1000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - ailego::Params params; - params.set("proxima.hnsw.builder.thread_count", 3); - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestGeneral"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.check_crc_enable", true); - searcherParams.set("proxima.hnsw.searcher.max_scan_ratio", - 1.1f); // including upper layer - ASSERT_EQ(0, searcher->init(searcherParams)); - - - auto storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - ASSERT_TRUE(!!linearCtx); - ASSERT_TRUE(!!knnCtx); - NumericalVector vec(dim); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t topk = 100; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - auto &stats = searcher->stats(); - ASSERT_EQ(doc_cnt, stats.loaded_count()); - // ASSERT_GT(stats.loaded_costtime(), 0UL); - - //! cleanup - ASSERT_EQ(0, searcher->cleanup()); - ASSERT_EQ(nullptr, searcher->create_context()); - ASSERT_EQ(IndexError_Runtime, - searcher->load(storage, IndexMetric::Pointer())); - ASSERT_EQ(0UL, stats.loaded_count()); - - ASSERT_EQ(0, searcher->init(searcherParams)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - linearCtx = searcher->create_context(); - knnCtx = searcher->create_context(); - ASSERT_TRUE(!!linearCtx); - ASSERT_TRUE(!!knnCtx); - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - auto &knnResult1 = knnCtx->result(); - ASSERT_EQ(topk, knnResult1.size()); - auto &linearResult1 = linearCtx->result(); - ASSERT_EQ(topk, linearResult1.size()); - ASSERT_EQ(doc_cnt, stats.loaded_count()); - - //! unload - ASSERT_EQ(0, searcher->unload()); - ASSERT_EQ(nullptr, searcher->create_context()); - ASSERT_EQ(0UL, stats.loaded_count()); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - linearCtx = searcher->create_context(); - ASSERT_TRUE(!!linearCtx); - linearCtx->set_topk(topk); - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - auto &linearResult2 = linearCtx->result(); - ASSERT_EQ(topk, linearResult2.size()); - ASSERT_EQ(doc_cnt, stats.loaded_count()); -} - -TEST_F(HnswSearcherTest, TestFilter) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 100UL; - std::vector> p_keys; - p_keys.resize(1); - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - p_keys[0].push_back(i); - } - ailego::Params params; - params.set("proxima.hnsw.builder.thread_count", 3); - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestGeneral"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.check_crc_enable", true); - searcherParams.set("proxima.hnsw.searcher.max_scan_ratio", 1.0f); - ASSERT_EQ(0, searcher->init(searcherParams)); - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - auto linearCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - ASSERT_TRUE(!!linearCtx); - ASSERT_TRUE(!!linearByPKeysCtx); - ASSERT_TRUE(!!knnCtx); - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = 10.1f; - } - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t topk = 10; - linearCtx->set_topk(topk); - linearByPKeysCtx->set_topk(topk); - knnCtx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - ASSERT_EQ(0, searcher->search_bf_by_p_keys_impl(vec.data(), p_keys, qmeta, - linearByPKeysCtx)); - - auto filterFunc = [](uint64_t key) { - if (key == 10UL || key == 11UL) { - return true; - } - return false; - }; - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - ASSERT_EQ(10UL, knnResult[0].key()); - ASSERT_EQ(11UL, knnResult[1].key()); - ASSERT_EQ(9UL, knnResult[2].key()); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(10UL, linearResult[0].key()); - ASSERT_EQ(11UL, linearResult[1].key()); - ASSERT_EQ(9UL, linearResult[2].key()); - - auto &linearByPKeysResult = linearByPKeysCtx->result(); - ASSERT_EQ(topk, linearByPKeysResult.size()); - ASSERT_EQ(10UL, linearByPKeysResult[0].key()); - ASSERT_EQ(11UL, linearByPKeysResult[1].key()); - ASSERT_EQ(9UL, linearByPKeysResult[2].key()); - - knnCtx->set_filter(filterFunc); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); - auto &knnResult1 = knnCtx->result(); - ASSERT_EQ(topk, knnResult1.size()); - ASSERT_EQ(9UL, knnResult1[0].key()); - ASSERT_EQ(12UL, knnResult1[1].key()); - ASSERT_EQ(8UL, knnResult1[2].key()); - - linearCtx->set_filter(filterFunc); - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - auto &linearResult1 = linearCtx->result(); - ASSERT_EQ(topk, linearResult1.size()); - ASSERT_EQ(9UL, linearResult1[0].key()); - ASSERT_EQ(12UL, linearResult1[1].key()); - ASSERT_EQ(8UL, linearResult1[2].key()); - - linearByPKeysCtx->set_filter(filterFunc); - ASSERT_EQ(0, searcher->search_bf_by_p_keys_impl(vec.data(), p_keys, qmeta, - linearByPKeysCtx)); - auto &linearByPKeysResult1 = linearByPKeysCtx->result(); - ASSERT_EQ(topk, linearByPKeysResult1.size()); - ASSERT_EQ(9UL, linearByPKeysResult1[0].key()); - ASSERT_EQ(12UL, linearByPKeysResult1[1].key()); - ASSERT_EQ(8UL, linearByPKeysResult1[2].key()); -} - -TEST_F(HnswSearcherTest, TestStreamerDump) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_NE(streamer, nullptr); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 10); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 10); - params.set(PARAM_HNSW_STREAMER_EF, 5); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - ailego::Params stg_params; - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestStreamerDump.index", true)); - ASSERT_EQ(0, streamer->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, streamer->open(storage)); - - NumericalVector vec(dim); - size_t cnt = 5000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - for (size_t i = 0; i < cnt; i++) { - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - streamer->add_impl(i, vec.data(), qmeta, ctx); - } - auto path = _dir + "/TestStreamerDump"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // do searcher knn - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - auto read_storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_TRUE(searcher != nullptr); - ASSERT_EQ(0, searcher->init(ailego::Params())); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - size_t topk = 200; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - int totalHits = 0; - int totalCnts = 0; - int topk1Hits = 0; - size_t step = 50; - for (size_t i = 0; i < cnt; i += step) { - for (size_t j = 0; j < dim; ++j) { - vec[j] = i + 0.1f; - } - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - // ASSERT_EQ(topk, knnResult.size()); - topk1Hits += i == knnResult[0].key(); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - for (size_t k = 0; k < topk; ++k) { - totalCnts++; - for (size_t j = 0; j < topk; ++j) { - if (linearResult[j].key() == knnResult[k].key()) { - totalHits++; - break; - } - } - } - } - float recall = totalHits * step * 1.0f / totalCnts; - float topk1Recall = topk1Hits * step * 1.0f / cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; -#if 0 - printf("knnTotalTime=%zd linearTotalTime=%zd totalHits=%d totalCnts=%d " - "R@%zd=%f R@1=%f cost=%f\n", - knnTotalTime, linearTotalTime, totalHits, totalCnts, topk, recall, - topk1Recall, cost); -#endif - EXPECT_GT(recall, 0.90f); - EXPECT_GT(topk1Recall, 0.95f); - // EXPECT_GT(cost, 2.0f); -} - -TEST_F(HnswSearcherTest, TestSharedContext) { - auto gen_holder = [](int start, size_t doc_cnt) { - auto holder = - make_shared>(dim); - uint64_t key = start; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - key += 3; - holder->emplace(key, vec); - } - return holder; - }; - auto gen_index = [&gen_holder](int start, size_t docs, std::string path) { - auto holder = gen_holder(start, docs); - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ailego::Params params; - builder->init(*_index_meta_ptr, params); - builder->train(holder); - builder->build(holder); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - dumper->create(path); - builder->dump(dumper); - dumper->close(); - - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - auto name = rand() % 2 ? "FileReadStorage" : "MMapFileReadStorage"; - auto storage = IndexFactory::CreateStorage(name); - storage->open(path, false); - params.set("proxima.hnsw.searcher.visit_bloomfilter_enable", rand() % 2); - searcher->init(ailego::Params()); - searcher->load(storage, IndexMetric::Pointer()); - return searcher; - }; - - srand(ailego::Realtime::MilliSeconds()); - size_t docs1 = rand() % 500 + 100; - size_t docs2 = rand() % 5000 + 100; - size_t docs3 = rand() % 50000 + 100; - auto path1 = _dir + "/TestSharedContext.index1"; - auto path2 = _dir + "/TestSharedContext.index2"; - auto path3 = _dir + "/TestSharedContext.index3"; - auto searcher1 = gen_index(0, docs1, path1); - auto searcher2 = gen_index(1, docs2, path2); - auto searcher3 = gen_index(2, docs3, path3); - - srand(ailego::Realtime::MilliSeconds()); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - auto do_test = [&]() { - IndexSearcher::Context::Pointer ctx; - switch (rand() % 3) { - case 0: - ctx = searcher1->create_context(); - break; - case 1: - ctx = searcher2->create_context(); - break; - case 2: - ctx = searcher3->create_context(); - break; - } - ctx->set_topk(10); - - int ret = 0; - for (int i = 0; i < 100; ++i) { - NumericalVector query(dim); - for (size_t j = 0; j < dim; ++j) { - query[j] = i + 0.1f; - } - - auto code = rand() % 6; - switch (code) { - case 0: - ret = searcher1->search_impl(query.data(), qmeta, ctx); - break; - case 1: - ret = searcher2->search_impl(query.data(), qmeta, ctx); - break; - case 2: - ret = searcher3->search_impl(query.data(), qmeta, ctx); - break; - case 3: - ret = searcher1->search_bf_impl(query.data(), qmeta, ctx); - break; - case 4: - ret = searcher2->search_bf_impl(query.data(), qmeta, ctx); - break; - case 5: - ret = searcher3->search_bf_impl(query.data(), qmeta, ctx); - break; - } - - EXPECT_EQ(0, ret); - auto &results = ctx->result(); - EXPECT_EQ(10, results.size()); - for (int k = 0; k < 10; ++k) { - EXPECT_EQ(code % 3, results[k].key() % 3); - } - } - }; - auto t1 = std::async(std::launch::async, do_test); - auto t2 = std::async(std::launch::async, do_test); - t1.wait(); - t2.wait(); - - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - storage->init(ailego::Params()); - storage->open(_dir + "/TestSharedContext.index4", true); - streamer->init(*_index_meta_ptr, ailego::Params()); - streamer->open(storage); - NumericalVector query(dim); - auto ctx1 = streamer->create_context(); - EXPECT_EQ(IndexError_Unsupported, - searcher1->search_impl(query.data(), qmeta, ctx1)); - - auto ctx2 = searcher1->create_context(); - EXPECT_EQ(IndexError_Unsupported, - streamer->search_impl(query.data(), qmeta, ctx2)); -} - -TEST_F(HnswSearcherTest, TestProvider) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 5000UL; - std::vector keys(doc_cnt); - srand(ailego::Realtime::MilliSeconds()); - bool rand_key = rand() % 2; - bool rand_order = rand() % 2; - size_t step = rand() % 2 + 1; - LOG_DEBUG("randKey=%u randOrder=%u step=%zu", rand_key, rand_order, step); - if (rand_key) { - std::mt19937 mt; - std::uniform_int_distribution dt( - 0, std::numeric_limits::max()); - for (size_t i = 0; i < doc_cnt; ++i) { - keys[i] = dt(mt); - } - } else { - std::iota(keys.begin(), keys.end(), 0U); - std::transform(keys.begin(), keys.end(), keys.begin(), - [&](key_t k) { return step * k; }); - if (rand_order) { - uint32_t seed = ailego::Realtime::Seconds(); - std::shuffle(keys.begin(), keys.end(), std::default_random_engine(seed)); - } - } - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = keys[i]; - } - ASSERT_TRUE(holder->emplace(keys[i], vec)); - } - ailego::Params params; - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestProvider"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 1); - ASSERT_EQ(0, searcher->init(searcherParams)); - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - - auto provider = searcher->create_provider(); - for (size_t i = 0; i < keys.size(); ++i) { - const float *d1 = - reinterpret_cast(provider->get_vector(keys[i])); - ASSERT_TRUE(d1); - for (size_t j = 0; j < dim; ++j) { - ASSERT_FLOAT_EQ(d1[j], keys[i]); - } - } - - auto iter = provider->create_iterator(); - size_t cnt = 0; - while (iter->is_valid()) { - auto key = iter->key(); - const float *d = reinterpret_cast(iter->data()); - for (size_t j = 0; j < dim; ++j) { - ASSERT_FLOAT_EQ(d[j], key); - } - cnt++; - iter->next(); - } - ASSERT_EQ(cnt, doc_cnt); - - ASSERT_EQ(dim, provider->dimension()); - ASSERT_EQ(_index_meta_ptr->element_size(), provider->element_size()); - ASSERT_EQ(_index_meta_ptr->data_type(), provider->data_type()); -} - -TEST_F(HnswSearcherTest, TestMipsEuclideanMetric) { - constexpr size_t static dim = 32; - IndexMeta meta(IndexMeta::DataType::DT_FP32, dim); - meta.set_metric("MipsSquaredEuclidean", 0, ailego::Params()); - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - const size_t COUNT = 10000UL; - for (size_t i = 0; i < COUNT; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i / 100.0f; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - ASSERT_EQ(0, builder->init(meta, ailego::Params())); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestMipsEuclideanMetric"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ailego::Params params; - params.set("proxima.hnsw.searcher.ef", 10); - ASSERT_TRUE(searcher != nullptr); - ASSERT_EQ(0, searcher->init(params)); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - auto ctx = searcher->create_context(); - ASSERT_TRUE(!!ctx); - - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = 1.0; - } - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t topk = 50; - ctx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - auto &results = ctx->result(); - EXPECT_EQ(results.size(), topk); - EXPECT_NEAR((uint64_t)(COUNT - 1), results[0].key(), 20); -} - -TEST_F(HnswSearcherTest, TestRandomPaddingTopk) { - std::mt19937 mt{}; - std::uniform_real_distribution gen(0.0f, 1.0f); - constexpr size_t static dim = 8; - IndexMeta meta(IndexMeta::DataType::DT_FP32, dim); - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - const size_t COUNT = 10000UL; - for (size_t i = 0; i < COUNT; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = gen(mt); - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - ASSERT_EQ(0, builder->init(meta, ailego::Params())); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestRandomPadding"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ailego::Params params; - params.set("proxima.hnsw.searcher.force_padding_result_enable", true); - params.set("proxima.hnsw.searcher.scan_ratio", 0.01f); - ASSERT_TRUE(searcher != nullptr); - ASSERT_EQ(0, searcher->init(params)); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - auto ctx = searcher->create_context(); - ASSERT_TRUE(!!ctx); - - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = 1.0; - } - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - std::uniform_int_distribution gen_int(1, COUNT); - size_t topk = gen_int(mt); - ctx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - auto &results = ctx->result(); - EXPECT_EQ(results.size(), topk); - for (size_t i = 0; i < results.size(); ++i) { - for (size_t j = 0; j < i; ++j) { - EXPECT_NE(results[i].key(), results[j].key()); - } - } - - ctx->set_filter([](uint64_t key) { return true; }); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - auto &results1 = ctx->result(); - EXPECT_EQ(results1.size(), 0); -} - - -TEST_F(HnswSearcherTest, TestBruteForceSetupInContext) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 5000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - ailego::Params params; - // params.set("proxima.hnsw.builder.max_neighbor_count", 16); - params.set("proxima.hnsw.builder.scaling_factor", 16); - params.set("proxima.hnsw.builder.ef_construction", 10); - params.set("proxima.hnsw.builder.thread_count", 2); - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestGeneral"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 1); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - - NumericalVector vec(dim); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t topk = 200; - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - int totalHits = 0; - int totalCnts = 0; - int topk1Hits = 0; - - bool set_bf_threshold = false; - bool use_update = false; - - size_t step = 50; - for (size_t i = 0; i < doc_cnt; i += step) { - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - - ASSERT_TRUE(!!linearCtx); - ASSERT_TRUE(!!linearCtx); - - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - - for (size_t j = 0; j < dim; ++j) { - vec[j] = i + 0.1f; - } - auto t1 = ailego::Realtime::MicroSeconds(); - - if (set_bf_threshold) { - if (use_update) { - ailego::Params searcherParamsExtra; - - searcherParamsExtra.set("proxima.hnsw.searcher.brute_force_threshold", - doc_cnt); - knnCtx->update(searcherParamsExtra); - } else { - knnCtx->set_bruteforce_threshold(doc_cnt); - } - - use_update = !use_update; - } - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); - - auto t2 = ailego::Realtime::MicroSeconds(); - - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - // auto t3 = ailego::Realtime::MicroSeconds(); - - if (set_bf_threshold) { - linearTotalTime += t2 - t1; - } else { - knnTotalTime += t2 - t1; - } - - set_bf_threshold = !set_bf_threshold; - - auto &knnResult = knnCtx->result(); - // TODO: check - // ASSERT_EQ(topk, knnResult.size()); - topk1Hits += i == knnResult[0].key(); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - for (size_t k = 0; k < topk; ++k) { - totalCnts++; - for (size_t j = 0; j < topk; ++j) { - if (linearResult[j].key() == knnResult[k].key()) { - totalHits++; - break; - } - } - } - } - float recall = totalHits * step * step * 1.0f / totalCnts; - float topk1Recall = topk1Hits * step * 1.0f / doc_cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; -#if 0 - printf("knnTotalTime=%zd linearTotalTime=%zd totalHits=%d totalCnts=%d " - "R@%zd=%f R@1=%f cost=%f\n", - knnTotalTime, linearTotalTime, totalHits, totalCnts, topk, recall, - topk1Recall, cost); -#endif - EXPECT_GT(recall, 0.90f); - EXPECT_GT(topk1Recall, 0.90f); - // EXPECT_GT(cost, 2.0f); -} - -TEST_F(HnswSearcherTest, TestCosine) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_NE(streamer, nullptr); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 50); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 100); - params.set(PARAM_HNSW_STREAMER_EF, 100); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - ailego::Params stg_params; - - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineFp32Converter"); - ASSERT_TRUE(converter != nullptr); - - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); - ASSERT_TRUE(reformer != nullptr); - - ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); - - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestCosine.index", true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - NumericalVector vec(dim); - size_t cnt = 5000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - - float fixed_value = float(cnt) / 2; - for (size_t i = 0; i < cnt; i++) { - float add_on = i * 10; - for (size_t j = 0; j < dim; ++j) { - if (j < dim / 4) - vec[j] = fixed_value; - else - vec[j] = fixed_value + add_on; - } - - std::string new_vec; - IndexQueryMeta new_meta; - - ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); - ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); - } - - auto path = _dir + "/TestCosine"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 100); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto read_storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - - size_t query_cnt = 200U; - auto linearCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - - ASSERT_TRUE(!!linearCtx); - ASSERT_TRUE(!!linearByPKeysCtx); - ASSERT_TRUE(!!knnCtx); - - size_t topk = 200; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - int totalHits = 0; - int totalCnts = 0; - int topk1Hits = 0; - - NumericalVector qvec(dim); - for (size_t i = 0; i < query_cnt; i++) { - float add_on = i * 10; - for (size_t j = 0; j < dim; ++j) { - if (j < dim / 4) - qvec[j] = fixed_value; - else - qvec[j] = fixed_value + add_on; - } - - std::string new_query; - IndexQueryMeta new_meta; - ASSERT_EQ(0, - reformer->transform(qvec.data(), qmeta, &new_query, &new_meta)); - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, - searcher->search_bf_impl(new_query.data(), new_meta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - topk1Hits += i == knnResult[0].key(); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - for (size_t k = 0; k < topk; ++k) { - totalCnts++; - for (size_t j = 0; j < topk; ++j) { - if (linearResult[j].key() == knnResult[k].key()) { - totalHits++; - break; - } - } - } - } - - float recall = totalHits * 1.0f / totalCnts; - float topk1Recall = topk1Hits * 1.0f / query_cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; - - EXPECT_GT(recall, 0.90f); - EXPECT_GT(topk1Recall, 0.90f); - // EXPECT_GT(cost, 2.0f); -} - -TEST_F(HnswSearcherTest, TestFetchVector) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_TRUE(streamer != nullptr); - - IndexMeta index_meta(IndexMeta::DataType::DT_FP32, dim); - index_meta.set_metric("SquaredEuclidean", 0, ailego::Params()); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 50); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 100); - params.set(PARAM_HNSW_STREAMER_EF, 100); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - ailego::Params stg_params; - - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestFetchVector.index", true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - NumericalVector vec(dim); - size_t cnt = 2000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - - for (size_t i = 0; i < cnt; i++) { - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - - streamer->add_impl(i, vec.data(), qmeta, ctx); - } - - auto path = _dir + "/TestFetchVector"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 100); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto read_storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - - for (size_t i = 0; i < cnt; i++) { - const void *vector = searcher->get_vector(i); - ASSERT_NE(vector, nullptr); - - float vector_value = *(float *)(vector); - ASSERT_EQ(vector_value, i); - } - - size_t query_cnt = 200U; - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - knnCtx->set_fetch_vector(true); - - size_t topk = 200; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - - for (size_t i = 0; i < query_cnt; i++) { - for (size_t j = 0; j < dim; ++j) { - vec[j] = i; - } - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - ASSERT_NE(knnResult[0].vector(), nullptr); - float vector_value = *((float *)(knnResult[0].vector())); - ASSERT_EQ(vector_value, i); - } - - std::cout << "knnTotalTime: " << knnTotalTime << std::endl; - std::cout << "linearTotalTime: " << linearTotalTime << std::endl; -} - -TEST_F(HnswSearcherTest, TestFetchVectorCosine) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_NE(streamer, nullptr); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 50); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 100); - params.set(PARAM_HNSW_STREAMER_EF, 100); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - params.set(PARAM_HNSW_STREAMER_GET_VECTOR_ENABLE, true); - - ailego::Params stg_params; - - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineFp32Converter"); - ASSERT_TRUE(converter != nullptr); - - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); - ASSERT_TRUE(reformer != nullptr); - - ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); - - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestFetchVectorCosine.index", true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - NumericalVector vec(dim); - size_t cnt = 2000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - IndexQueryMeta new_meta; - - const float epsilon = 1e-2; - float fixed_value = float(cnt) / 2; - for (size_t i = 0; i < cnt; i++) { - float add_on = i * 10; - - for (size_t j = 0; j < dim; ++j) { - if (j < dim / 4) - vec[j] = fixed_value; - else - vec[j] = fixed_value + add_on; - } - - std::string new_vec; - - ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); - ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); - } - - auto path = _dir + "/TestFetchVectorCosine"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 100); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto read_storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - - for (size_t i = 0; i < cnt; i++) { - float add_on = i * 10; - - const void *vector = searcher->get_vector(i); - ASSERT_NE(vector, nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(float)); - reformer->revert(vector, new_meta, &denormalized_vec); - - float vector_value = *((float *)(denormalized_vec.data()) + dim - 1); - EXPECT_NEAR(vector_value, fixed_value + add_on, epsilon); - } - - size_t query_cnt = 200U; - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - knnCtx->set_fetch_vector(true); - - size_t topk = 200; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - - NumericalVector qvec(dim); - for (size_t i = 0; i < query_cnt; i++) { - float add_on = i * 10; - - for (size_t j = 0; j < dim; ++j) { - if (j < dim / 4) - qvec[j] = fixed_value; - else - qvec[j] = fixed_value + add_on; - } - - std::string new_query; - IndexQueryMeta new_meta; - ASSERT_EQ(0, - reformer->transform(qvec.data(), qmeta, &new_query, &new_meta)); - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, - searcher->search_bf_impl(new_query.data(), new_meta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - ASSERT_NE(knnResult[0].vector(), nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(float)); - reformer->revert(knnResult[0].vector(), new_meta, &denormalized_vec); - - float vector_value = *(((float *)(denormalized_vec.data()) + dim - 1)); - EXPECT_NEAR(vector_value, fixed_value + add_on, epsilon); - } - - std::cout << "knnTotalTime: " << knnTotalTime << std::endl; - std::cout << "linearTotalTime: " << linearTotalTime << std::endl; -} - - -TEST_F(HnswSearcherTest, TestFetchVectorCosineHalfFloatConverter) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_NE(streamer, nullptr); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 50); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 100); - params.set(PARAM_HNSW_STREAMER_EF, 100); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - params.set(PARAM_HNSW_STREAMER_GET_VECTOR_ENABLE, true); - - ailego::Params stg_params; - - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP16, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineHalfFloatConverter"); - ASSERT_TRUE(converter != nullptr); - - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); - ASSERT_TRUE(reformer != nullptr); - - ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); - - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ( - 0, storage->open(_dir + "/TestFetchVectorCosineHalfFloatConverter.index", - true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - size_t cnt = 2000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP16, dim); - IndexQueryMeta new_meta; - - const float epsilon = 0.1; - - std::random_device rd; - std::mt19937 gen(rd()); - - std::uniform_real_distribution dist(-2.0, 2.0); - - std::vector> vecs; - for (size_t i = 0; i < cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - float value = dist(gen); - vec[j] = ailego::FloatHelper::ToFP16(value); - } - - std::string new_vec; - - ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); - ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); - - vecs.push_back(vec); - } - - auto path = _dir + "/TestFetchVectorCosineHalfFloatConverter"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 100); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto read_storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - - for (size_t i = 0; i < cnt; i++) { - uint16_t expected_vec_value = vecs[i][dim - 1]; - - const void *vector = searcher->get_vector(i); - ASSERT_NE(vector, nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(uint16_t)); - reformer->revert(vector, new_meta, &denormalized_vec); - - uint16_t vector_value = *((uint16_t *)(denormalized_vec.data()) + dim - 1); - float vector_value_float = ailego::FloatHelper::ToFP32(vector_value); - - float expected_vec_float = ailego::FloatHelper::ToFP32(expected_vec_value); - - EXPECT_NEAR(expected_vec_float, vector_value_float, epsilon); - } - - size_t query_cnt = 200U; - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - knnCtx->set_fetch_vector(true); - - size_t topk = 200; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - - NumericalVector qvec(dim); - - for (size_t i = 0; i < query_cnt; i++) { - auto &vec = vecs[i]; - - std::string new_query; - IndexQueryMeta new_meta; - ASSERT_EQ(0, reformer->transform(vec.data(), qmeta, &new_query, &new_meta)); - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, - searcher->search_bf_impl(new_query.data(), new_meta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - ASSERT_NE(knnResult[0].vector(), nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(uint16_t)); - reformer->revert(knnResult[0].vector(), new_meta, &denormalized_vec); - - uint16_t expected_vec_value = vec[dim - 1]; - uint16_t vector_value = - *(((uint16_t *)(denormalized_vec.data()) + dim - 1)); - - float vector_value_float = ailego::FloatHelper::ToFP32(vector_value); - float expected_vec_float = ailego::FloatHelper::ToFP32(expected_vec_value); - - EXPECT_NEAR(expected_vec_float, vector_value_float, epsilon); - } - - std::cout << "knnTotalTime: " << knnTotalTime << std::endl; - std::cout << "linearTotalTime: " << linearTotalTime << std::endl; -} - -TEST_F(HnswSearcherTest, TestFetchVectorCosineFp16Converter) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_NE(streamer, nullptr); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 50); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 100); - params.set(PARAM_HNSW_STREAMER_EF, 100); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - params.set(PARAM_HNSW_STREAMER_GET_VECTOR_ENABLE, true); - - ailego::Params stg_params; - - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineFp16Converter"); - ASSERT_TRUE(converter != nullptr); - - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); - ASSERT_TRUE(reformer != nullptr); - - ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); - - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestFetchVectorCosineFp16Converter.index", - true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - size_t cnt = 2000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - IndexQueryMeta new_meta; - - const float epsilon = 0.1; - - std::random_device rd; - std::mt19937 gen(rd()); - - std::uniform_real_distribution dist(-2.0, 2.0); - - std::vector> vecs; - for (size_t i = 0; i < cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = dist(gen); - } - - std::string new_vec; - - ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); - ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); - - vecs.push_back(vec); - } - - auto path = _dir + "/TestFetchVectorCosineFp16Converter"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 100); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto read_storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - - for (size_t i = 0; i < cnt; i++) { - float expected_vec_value = vecs[i][dim - 1]; - - const void *vector = searcher->get_vector(i); - ASSERT_NE(vector, nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(float)); - reformer->revert(vector, new_meta, &denormalized_vec); - float vector_value = *((float *)(denormalized_vec.data()) + dim - 1); - - EXPECT_NEAR(expected_vec_value, vector_value, epsilon); - } - - size_t query_cnt = 200U; - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - knnCtx->set_fetch_vector(true); - - size_t topk = 200; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - - NumericalVector qvec(dim); - - for (size_t i = 0; i < query_cnt; i++) { - auto &vec = vecs[i]; - - std::string new_query; - IndexQueryMeta new_meta; - ASSERT_EQ(0, reformer->transform(vec.data(), qmeta, &new_query, &new_meta)); - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, - searcher->search_bf_impl(new_query.data(), new_meta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - ASSERT_NE(knnResult[0].vector(), nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(float)); - reformer->revert(knnResult[0].vector(), new_meta, &denormalized_vec); - - float expected_vec_value = vec[dim - 1]; - float vector_value = *(((float *)(denormalized_vec.data()) + dim - 1)); - - EXPECT_NEAR(expected_vec_value, vector_value, epsilon); - } - - std::cout << "knnTotalTime: " << knnTotalTime << std::endl; - std::cout << "linearTotalTime: " << linearTotalTime << std::endl; -} - -TEST_F(HnswSearcherTest, TestFetchVectorCosineInt8Converter) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_NE(streamer, nullptr); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 50); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 100); - params.set(PARAM_HNSW_STREAMER_EF, 100); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - params.set(PARAM_HNSW_STREAMER_GET_VECTOR_ENABLE, true); - - ailego::Params stg_params; - - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineInt8Converter"); - ASSERT_TRUE(converter != nullptr); - - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); - ASSERT_TRUE(reformer != nullptr); - - ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); - - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestFetchVectorCosineInt8Converter.index", - true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - NumericalVector vec(dim); - size_t cnt = 2000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - IndexQueryMeta new_meta; - - const float epsilon = 1e-2; - float fixed_value = float(cnt) / 2; - for (size_t i = 0; i < cnt; i++) { - float add_on = i * 10; - - for (size_t j = 0; j < dim; ++j) { - if (j < dim / 4) - vec[j] = fixed_value; - else - vec[j] = fixed_value + add_on; - } - - std::string new_vec; - - ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); - ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); - } - - auto path = _dir + "/TestFetchVectorCosineInt8Converter"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 100); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto read_storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - - for (size_t i = 0; i < cnt; i++) { - float add_on = i * 10; - - const void *vector = searcher->get_vector(i); - ASSERT_NE(vector, nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(float)); - reformer->revert(vector, new_meta, &denormalized_vec); - - float vector_value = *((float *)(denormalized_vec.data()) + dim - 1); - EXPECT_NEAR(vector_value, fixed_value + add_on, epsilon); - } - - size_t query_cnt = 200U; - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - knnCtx->set_fetch_vector(true); - - size_t topk = 200; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - - NumericalVector qvec(dim); - for (size_t i = 0; i < query_cnt; i++) { - float add_on = i * 10; - - for (size_t j = 0; j < dim; ++j) { - if (j < dim / 4) - qvec[j] = fixed_value; - else - qvec[j] = fixed_value + add_on; - } - - std::string new_query; - IndexQueryMeta new_meta; - ASSERT_EQ(0, - reformer->transform(qvec.data(), qmeta, &new_query, &new_meta)); - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, - searcher->search_bf_impl(new_query.data(), new_meta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - ASSERT_NE(knnResult[0].vector(), nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(float)); - reformer->revert(knnResult[0].vector(), new_meta, &denormalized_vec); - - float vector_value = *(((float *)(denormalized_vec.data()) + dim - 1)); - EXPECT_NEAR(vector_value, fixed_value + add_on, epsilon); - } - - std::cout << "knnTotalTime: " << knnTotalTime << std::endl; - std::cout << "linearTotalTime: " << linearTotalTime << std::endl; -} - -TEST_F(HnswSearcherTest, TestFetchVectorCosineInt4Converter) { - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_NE(streamer, nullptr); - - ailego::Params params; - params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 50); - params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 100); - params.set(PARAM_HNSW_STREAMER_EF, 100); - params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - params.set(PARAM_HNSW_STREAMER_GET_VECTOR_ENABLE, true); - - ailego::Params stg_params; - - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); - index_meta_raw.set_metric("Cosine", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("CosineInt4Converter"); - ASSERT_TRUE(converter != nullptr); - - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); - ASSERT_TRUE(reformer != nullptr); - - ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); - - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestFetchVectorCosineInt4Converter.index", - true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - NumericalVector vec(dim); - size_t cnt = 2000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - IndexQueryMeta new_meta; - - const float epsilon = 1e-2; - float fixed_value = float(cnt) / 2; - for (size_t i = 0; i < cnt; i++) { - float add_on = i * 10; - - for (size_t j = 0; j < dim; ++j) { - if (j < dim / 4) - vec[j] = fixed_value; - else - vec[j] = fixed_value + add_on; - } - - std::string new_vec; - - ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); - ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); - } - - auto path = _dir + "/TestFetchVectorCosineInt4Converter"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 100); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto read_storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - - for (size_t i = 0; i < cnt; i++) { - float add_on = i * 10; - - const void *vector = searcher->get_vector(i); - ASSERT_NE(vector, nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(float)); - reformer->revert(vector, new_meta, &denormalized_vec); - - float vector_value = *((float *)(denormalized_vec.data()) + dim - 1); - EXPECT_NEAR(vector_value, fixed_value + add_on, epsilon); - } - - size_t query_cnt = 200U; - auto linearCtx = searcher->create_context(); - auto knnCtx = searcher->create_context(); - auto linearByPKeysCtx = searcher->create_context(); - knnCtx->set_fetch_vector(true); - - size_t topk = 100; - linearCtx->set_topk(topk); - knnCtx->set_topk(topk); - uint64_t knnTotalTime = 0; - uint64_t linearTotalTime = 0; - - NumericalVector qvec(dim); - for (size_t i = 0; i < query_cnt; i++) { - float add_on = i * 10; - - for (size_t j = 0; j < dim; ++j) { - if (j < dim / 4) - qvec[j] = fixed_value; - else - qvec[j] = fixed_value + add_on; - } - - std::string new_query; - IndexQueryMeta new_meta; - ASSERT_EQ(0, - reformer->transform(qvec.data(), qmeta, &new_query, &new_meta)); - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, knnCtx)); - auto t2 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, - searcher->search_bf_impl(new_query.data(), new_meta, linearCtx)); - auto t3 = ailego::Realtime::MicroSeconds(); - - knnTotalTime += t2 - t1; - linearTotalTime += t3 - t2; - - auto &knnResult = knnCtx->result(); - ASSERT_EQ(topk, knnResult.size()); - - auto &linearResult = linearCtx->result(); - ASSERT_EQ(topk, linearResult.size()); - ASSERT_EQ(i, linearResult[0].key()); - - ASSERT_NE(knnResult[0].vector(), nullptr); - - std::string denormalized_vec; - denormalized_vec.resize(dim * sizeof(float)); - reformer->revert(knnResult[0].vector(), new_meta, &denormalized_vec); - - float vector_value = *(((float *)(denormalized_vec.data()) + dim - 1)); - EXPECT_NEAR(vector_value, fixed_value + add_on, epsilon); - } - - std::cout << "knnTotalTime: " << knnTotalTime << std::endl; - std::cout << "linearTotalTime: " << linearTotalTime << std::endl; -} - -TEST_F(HnswSearcherTest, TestGroup) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 5000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i / 10.0; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - ailego::Params params; - - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestGroup"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_NE(searcher, nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 50); - searcherParams.set("proxima.hnsw.searcher.max_scan_ratio", 0.8); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - - auto ctx = searcher->create_context(); - ASSERT_TRUE(!!ctx); - - NumericalVector vec(dim); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t group_topk = 20; - uint64_t total_time = 0; - - auto groupbyFunc = [](uint64_t key) { - uint32_t group_id = key / 10 % 10; - - // std::cout << "key: " << key << ", group id: " << group_id << std::endl; - - return std::string("g_") + std::to_string(group_id); - }; - - size_t group_num = 5; - - ctx->set_group_params(group_num, group_topk); - ctx->set_group_by(groupbyFunc); - - size_t query_value = doc_cnt / 2; - for (size_t j = 0; j < dim; ++j) { - vec[j] = float(query_value) / 10 + 0.1f; - } - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - auto t2 = ailego::Realtime::MicroSeconds(); - - total_time += t2 - t1; - - std::cout << "total time: " << total_time << std::endl; - - auto &group_result = ctx->group_result(); - - for (uint32_t i = 0; i < group_result.size(); ++i) { - // const std::string &group_id = group_result[i].group_id(); - auto &result = group_result[i].docs(); - - ASSERT_GT(result.size(), 0); - // std::cout << "Group ID: " << group_id << std::endl; - - // for (uint32_t j = 0; j < result.size(); ++j) { - // std::cout << "\tKey: " << result[j].key() << std::fixed - // << std::setprecision(3) << ", Score: " << result[j].score() - // << std::endl; - // } - } - - // do linear search by p_keys test - auto groupbyFuncLinear = [](uint64_t key) { - uint32_t group_id = key % 10; - - return std::string("g_") + std::to_string(group_id); - }; - - auto linear_pk_ctx = searcher->create_context(); - - linear_pk_ctx->set_group_params(group_num, group_topk); - linear_pk_ctx->set_group_by(groupbyFuncLinear); - - std::vector> p_keys; - p_keys.resize(1); - p_keys[0] = {4, 3, 2, 1, 5, 6, 7, 8, 9, 10}; - - ASSERT_EQ(0, searcher->search_bf_by_p_keys_impl(vec.data(), p_keys, qmeta, - linear_pk_ctx)); - auto &linear_by_pkeys_group_result = linear_pk_ctx->group_result(); - ASSERT_EQ(linear_by_pkeys_group_result.size(), group_num); - - for (uint32_t i = 0; i < linear_by_pkeys_group_result.size(); ++i) { - // const std::string &group_id = linear_by_pkeys_group_result[i].group_id(); - auto &result = linear_by_pkeys_group_result[i].docs(); - - ASSERT_GT(result.size(), 0); - // std::cout << "Group ID: " << group_id << std::endl; - - // for (uint32_t j = 0; j < result.size(); ++j) { - // std::cout << "\tKey: " << result[j].key() << std::fixed - // << std::setprecision(3) << ", Score: " << result[j].score() - // << std::endl; - // } - - ASSERT_EQ(10 - i, result[0].key()); - } -} - -TEST_F(HnswSearcherTest, TestGroupNotEnoughNum) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 5000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i / 10.0; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - ailego::Params params; - - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestGroupNotEnoughNum"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_NE(searcher, nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 50); - searcherParams.set("proxima.hnsw.searcher.max_scan_ratio", 0.8); - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - - auto ctx = searcher->create_context(); - ASSERT_TRUE(!!ctx); - - NumericalVector vec(dim); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t group_topk = 20; - uint64_t total_time = 0; - - auto groupbyFunc = [](uint64_t key) { - uint32_t group_id = key / 10 % 10; - - // std::cout << "key: " << key << ", group id: " << group_id << std::endl; - - return std::string("g_") + std::to_string(group_id); - }; - - size_t group_num = 12; - ctx->set_group_params(group_num, group_topk); - ctx->set_group_by(groupbyFunc); - - size_t query_value = doc_cnt / 2; - for (size_t j = 0; j < dim; ++j) { - vec[j] = float(query_value) / 10 + 0.1f; - } - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - auto t2 = ailego::Realtime::MicroSeconds(); - total_time += t2 - t1; - - std::cout << "total time: " << total_time << std::endl; - - auto &group_result = ctx->group_result(); - ASSERT_EQ(group_result.size(), 10); - - for (uint32_t i = 0; i < group_result.size(); ++i) { - // const std::string &group_id = group_result[i].group_id(); - auto &result = group_result[i].docs(); - - ASSERT_GT(result.size(), 0); - // std::cout << "Group ID: " << group_id << std::endl; - - // for (uint32_t j = 0; j < result.size(); ++j) { - // std::cout << "\tKey: " << result[j].key() << std::fixed - // << std::setprecision(3) << ", Score: " << result[j].score() - // << std::endl; - // } - } -} - -TEST_F(HnswSearcherTest, TestGroupInBruteforceSearch) { - IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("HnswBuilder"); - ASSERT_NE(builder, nullptr); - auto holder = - make_shared>(dim); - size_t doc_cnt = 5000UL; - for (size_t i = 0; i < doc_cnt; i++) { - NumericalVector vec(dim); - for (size_t j = 0; j < dim; ++j) { - vec[j] = i / 10.0; - } - ASSERT_TRUE(holder->emplace(i, vec)); - } - - ailego::Params params; - - ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); - ASSERT_EQ(0, builder->train(holder)); - ASSERT_EQ(0, builder->build(holder)); - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - string path = _dir + "/TestGroupInBruteforceSearch"; - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, builder->dump(dumper)); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_NE(searcher, nullptr); - ailego::Params searcherParams; - searcherParams.set("proxima.hnsw.searcher.ef", 50); - searcherParams.set("proxima.hnsw.searcher.max_scan_ratio", 0.8); - searcherParams.set("proxima.hnsw.searcher.brute_force_threshold", - 2 * doc_cnt); - - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto storage = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, storage->open(path, false)); - ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); - - auto ctx = searcher->create_context(); - ASSERT_TRUE(!!ctx); - - NumericalVector vec(dim); - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); - size_t group_topk = 20; - uint64_t total_time = 0; - - auto groupbyFunc = [](uint64_t key) { - uint32_t group_id = key / 10 % 10; - - // std::cout << "key: " << key << ", group id: " << group_id << std::endl; - - return std::string("g_") + std::to_string(group_id); - }; - - size_t group_num = 5; - ctx->set_group_params(group_num, group_topk); - ctx->set_group_by(groupbyFunc); - - size_t query_value = doc_cnt / 2; - for (size_t j = 0; j < dim; ++j) { - vec[j] = float(query_value) / 10 + 0.1f; - } - - auto t1 = ailego::Realtime::MicroSeconds(); - ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); - auto t2 = ailego::Realtime::MicroSeconds(); - total_time += t2 - t1; - - std::cout << "total time: " << total_time << std::endl; - - auto &group_result = ctx->group_result(); - ASSERT_EQ(group_result.size(), 5); - - for (uint32_t i = 0; i < group_result.size(); ++i) { - // const std::string &group_id = group_result[i].group_id(); - auto &result = group_result[i].docs(); - - ASSERT_GT(result.size(), 0); - // std::cout << "Group ID: " << group_id << std::endl; - - // for (uint32_t j = 0; j < result.size(); ++j) { - // std::cout << "\tKey: " << result[j].key() << std::fixed - // << std::setprecision(3) << ", Score: " << result[j].score() - // << std::endl; - // } - } -} - -TEST_F(HnswSearcherTest, TestBinaryConverter) { - uint32_t dimension = 256; - - IndexStreamer::Pointer streamer = - IndexFactory::CreateStreamer("HnswStreamer"); - ASSERT_TRUE(streamer != nullptr); - - ailego::Params params; - // params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 50); - // params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 16); - // params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 10); - // params.set(PARAM_HNSW_STREAMER_EF, 5); - // params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); - - ailego::Params stg_params; - - IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dimension); - index_meta_raw.set_metric("InnerProduct", 0, ailego::Params()); - - ailego::Params converter_params; - auto converter = IndexFactory::CreateConverter("BinaryConverter"); - ASSERT_TRUE(converter != nullptr); - - converter->init(index_meta_raw, converter_params); - - IndexMeta index_meta = converter->meta(); - - auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); - ASSERT_TRUE(reformer != nullptr); - - ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); - - auto storage = IndexFactory::CreateStorage("MMapFileStorage"); - ASSERT_EQ(0, storage->init(stg_params)); - ASSERT_EQ(0, storage->open(_dir + "/TestBinaryConverter.index", true)); - ASSERT_EQ(0, streamer->init(index_meta, params)); - ASSERT_EQ(0, streamer->open(storage)); - - size_t cnt = 5000U; - auto ctx = streamer->create_context(); - ASSERT_TRUE(!!ctx); - - IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dimension); - - std::random_device rd; - std::mt19937 gen(rd()); - - std::uniform_real_distribution dist(-2.0, 2.0); - std::vector> vecs; - - for (size_t i = 0; i < cnt; i++) { - NumericalVector vec(dimension); - for (size_t j = 0; j < dimension; ++j) { - vec[j] = dist(gen); - } - - std::string new_vec; - IndexQueryMeta new_meta; - - ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); - ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); - - vecs.push_back(vec); - } - - auto path = _dir + "/TestBinaryConverter"; - auto dumper = IndexFactory::CreateDumper("FileDumper"); - ASSERT_NE(dumper, nullptr); - ASSERT_EQ(0, dumper->create(path)); - ASSERT_EQ(0, streamer->dump(dumper)); - ASSERT_EQ(0, streamer->close()); - ASSERT_EQ(0, dumper->close()); - - // test searcher - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - ASSERT_TRUE(searcher != nullptr); - - ailego::Params searcherParams; - ASSERT_EQ(0, searcher->init(searcherParams)); - - auto read_storage = IndexFactory::CreateStorage("MMapFileReadStorage"); - ASSERT_EQ(0, read_storage->open(path, false)); - ASSERT_EQ(0, searcher->load(read_storage, IndexMetric::Pointer())); - - size_t query_cnt = 200U; - auto knnCtx = searcher->create_context(); - - float epison = 1e-6; - for (size_t i = 0; i < query_cnt; i++) { - auto &vec = vecs[i]; - std::string new_query; - IndexQueryMeta new_meta; - ASSERT_EQ(0, reformer->transform(vec.data(), qmeta, &new_query, &new_meta)); - - size_t topk = 50; - knnCtx->set_topk(topk); - ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, knnCtx)); - auto &results = knnCtx->result(); - ASSERT_EQ(topk, results.size()); - ASSERT_EQ(i, results[0].key()); - ASSERT_NEAR(0, results[0].score(), epison); - } -} - -} // namespace core -} // namespace zvec - -#if defined(__GNUC__) || defined(__GNUG__) -#pragma GCC diagnostic pop -#endif \ No newline at end of file From 1b03901f2f0f76285366ce24d85307c1f505ba8d Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Wed, 4 Mar 2026 11:42:18 +0800 Subject: [PATCH 02/34] rm hnsw searcher --- src/core/algorithm/hnsw/hnsw_searcher.cc | 460 ---------------- src/core/algorithm/hnsw/hnsw_searcher.h | 139 ----- .../algorithm/hnsw/hnsw_searcher_entity.cc | 515 ------------------ .../algorithm/hnsw/hnsw_searcher_entity.h | 155 ------ 4 files changed, 1269 deletions(-) delete mode 100644 src/core/algorithm/hnsw/hnsw_searcher.cc delete mode 100644 src/core/algorithm/hnsw/hnsw_searcher.h delete mode 100644 src/core/algorithm/hnsw/hnsw_searcher_entity.cc delete mode 100644 src/core/algorithm/hnsw/hnsw_searcher_entity.h diff --git a/src/core/algorithm/hnsw/hnsw_searcher.cc b/src/core/algorithm/hnsw/hnsw_searcher.cc deleted file mode 100644 index c68a146d..00000000 --- a/src/core/algorithm/hnsw/hnsw_searcher.cc +++ /dev/null @@ -1,460 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "hnsw_searcher.h" -#include "hnsw_algorithm.h" -#include "hnsw_index_provider.h" -#include "hnsw_params.h" - -namespace zvec { -namespace core { - -HnswSearcher::HnswSearcher() = default; - -HnswSearcher::~HnswSearcher() = default; - -int HnswSearcher::init(const ailego::Params &search_params) { - params_ = search_params; - params_.get(PARAM_HNSW_SEARCHER_EF, &ef_); - params_.get(PARAM_HNSW_SEARCHER_MAX_SCAN_RATIO, &max_scan_ratio_); - params_.get(PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_ENABLE, &bf_enabled_); - params_.get(PARAM_HNSW_SEARCHER_CHECK_CRC_ENABLE, &check_crc_enabled_); - params_.get(PARAM_HNSW_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE, - &neighbors_in_memory_enabled_); - params_.get(PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB, - &bf_negative_probability_); - params_.get(PARAM_HNSW_SEARCHER_BRUTE_FORCE_THRESHOLD, - &bruteforce_threshold_); - params_.get(PARAM_HNSW_SEARCHER_FORCE_PADDING_RESULT_ENABLE, - &force_padding_topk_enabled_); - - if (ef_ == 0) { - ef_ = HnswEntity::kDefaultEf; - } - if (bf_negative_probability_ <= 0.0f || bf_negative_probability_ >= 1.0f) { - LOG_ERROR("[%s] must be in range (0,1)", - PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str()); - return IndexError_InvalidArgument; - } - - entity_.set_neighbors_in_memory(neighbors_in_memory_enabled_); - - state_ = STATE_INITED; - - LOG_DEBUG( - "Init params: ef=%u maxScanRatio=%f bfEnabled=%u checkCrcEnabled=%u " - "neighborsInMemoryEnabled=%u bfNagtiveProb=%f bruteForceThreshold=%u " - "forcePadding=%u", - ef_, max_scan_ratio_, bf_enabled_, check_crc_enabled_, - neighbors_in_memory_enabled_, bf_negative_probability_, - bruteforce_threshold_, force_padding_topk_enabled_); - - return 0; -} - -void HnswSearcher::print_debug_info() { - for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { - Neighbors neighbours = entity_.get_neighbors(0, id); - std::cout << "node: " << id << "; "; - for (uint32_t i = 0; i < neighbours.size(); ++i) { - std::cout << neighbours[i]; - - if (i == neighbours.size() - 1) { - std::cout << std::endl; - } else { - std::cout << ", "; - } - } - } -} - -int HnswSearcher::cleanup() { - LOG_INFO("Begin HnswSearcher:cleanup"); - - metric_.reset(); - meta_.clear(); - stats_.clear_attributes(); - stats_.set_loaded_count(0UL); - stats_.set_loaded_costtime(0UL); - max_scan_ratio_ = HnswEntity::kDefaultScanRatio; - max_scan_num_ = 0U; - ef_ = HnswEntity::kDefaultEf; - bf_enabled_ = false; - bf_negative_probability_ = HnswEntity::kDefaultBFNegativeProbability; - bruteforce_threshold_ = HnswEntity::kDefaultBruteForceThreshold; - check_crc_enabled_ = false; - neighbors_in_memory_enabled_ = false; - entity_.cleanup(); - state_ = STATE_INIT; - - LOG_INFO("End HnswSearcher:cleanup"); - - return 0; -} - -int HnswSearcher::load(IndexStorage::Pointer container, - IndexMetric::Pointer metric) { - if (state_ != STATE_INITED) { - LOG_ERROR("Init the searcher first before load index"); - return IndexError_Runtime; - } - - LOG_INFO("Begin HnswSearcher:load"); - - auto start_time = ailego::Monotime::MilliSeconds(); - - int ret = IndexHelper::DeserializeFromStorage(container.get(), &meta_); - if (ret != 0) { - LOG_ERROR("Failed to deserialize meta from container"); - return ret; - } - - ret = entity_.load(container, check_crc_enabled_); - if (ret != 0) { - LOG_ERROR("HnswSearcher load index failed"); - return ret; - } - - alg_ = HnswAlgorithm::UPointer(new HnswAlgorithm(entity_)); - - if (metric) { - metric_ = metric; - } else { - metric_ = IndexFactory::CreateMetric(meta_.metric_name()); - if (!metric_) { - LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str()); - return IndexError_NoExist; - } - ret = metric_->init(meta_, meta_.metric_params()); - if (ret != 0) { - LOG_ERROR("IndexMetric init failed, ret=%d", ret); - return ret; - } - if (metric_->query_metric()) { - metric_ = metric_->query_metric(); - } - } - - if (!metric_->is_matched(meta_)) { - LOG_ERROR("IndexMetric not match index meta"); - return IndexError_Mismatch; - } - - max_scan_num_ = static_cast(max_scan_ratio_ * entity_.doc_cnt()); - max_scan_num_ = std::max(4096U, max_scan_num_); - - stats_.set_loaded_count(entity_.doc_cnt()); - stats_.set_loaded_costtime(ailego::Monotime::MilliSeconds() - start_time); - state_ = STATE_LOADED; - magic_ = IndexContext::GenerateMagic(); - - LOG_INFO("End HnswSearcher::load"); - - return 0; -} - -int HnswSearcher::unload() { - LOG_INFO("HnswSearcher unload index"); - - meta_.clear(); - entity_.cleanup(); - metric_.reset(); - max_scan_num_ = 0; - stats_.set_loaded_count(0UL); - stats_.set_loaded_costtime(0UL); - state_ = STATE_INITED; - - return 0; -} - -int HnswSearcher::update_context(HnswContext *ctx) const { - const HnswEntity::Pointer entity = entity_.clone(); - if (!entity) { - LOG_ERROR("Failed to clone search context entity"); - return IndexError_Runtime; - } - ctx->set_max_scan_num(max_scan_num_); - ctx->set_bruteforce_threshold(bruteforce_threshold_); - - return ctx->update_context(HnswContext::kSearcherContext, meta_, metric_, - entity, magic_); -} - -int HnswSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta, - uint32_t count, Context::Pointer &context) const { - if (ailego_unlikely(!query || !context)) { - LOG_ERROR("The context is not created by this searcher"); - return IndexError_Mismatch; - } - HnswContext *ctx = dynamic_cast(context.get()); - ailego_do_if_false(ctx) { - LOG_ERROR("Cast context to HnswContext failed"); - return IndexError_Cast; - } - - if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) { - return search_bf_impl(query, qmeta, count, context); - } - - if (ctx->magic() != magic_) { - //! context is created by another searcher or streamer - int ret = update_context(ctx); - if (ret != 0) { - return ret; - } - } - - ctx->clear(); - ctx->resize_results(count); - for (size_t q = 0; q < count; ++q) { - ctx->reset_query(query); - int ret = alg_->search(ctx); - if (ailego_unlikely(ret != 0)) { - LOG_ERROR("Hnsw searcher fast search failed"); - return ret; - } - ctx->topk_to_result(q); - query = static_cast(query) + qmeta.element_size(); - } - - if (ailego_unlikely(ctx->error())) { - return IndexError_Runtime; - } - - return 0; -} - -int HnswSearcher::search_bf_impl(const void *query, const IndexQueryMeta &qmeta, - uint32_t count, - Context::Pointer &context) const { - if (ailego_unlikely(!query || !context)) { - LOG_ERROR("The context is not created by this searcher"); - return IndexError_Mismatch; - } - HnswContext *ctx = dynamic_cast(context.get()); - ailego_do_if_false(ctx) { - LOG_ERROR("Cast context to HnswContext failed"); - return IndexError_Cast; - } - if (ctx->magic() != magic_) { - //! context is created by another searcher or streamer - int ret = update_context(ctx); - if (ret != 0) { - return ret; - } - } - - ctx->clear(); - ctx->resize_results(count); - - if (ctx->group_by_search()) { - if (!ctx->group_by().is_valid()) { - LOG_ERROR("Invalid group-by function"); - return IndexError_InvalidArgument; - } - - std::function group_by = [&](node_id_t id) { - return ctx->group_by()(entity_.get_key(id)); - }; - - for (size_t q = 0; q < count; ++q) { - ctx->reset_query(query); - ctx->group_topk_heaps().clear(); - - for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { - if (entity_.get_key(id) == kInvalidKey) { - continue; - } - if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { - dist_t dist = ctx->dist_calculator().batch_dist(id); - - std::string group_id = group_by(id); - - auto &topk_heap = ctx->group_topk_heaps()[group_id]; - if (topk_heap.empty()) { - topk_heap.limit(ctx->group_topk()); - } - topk_heap.emplace_back(id, dist); - } - } - ctx->topk_to_result(q); - query = static_cast(query) + qmeta.element_size(); - } - } else { - for (size_t q = 0; q < count; ++q) { - ctx->reset_query(query); - ctx->topk_heap().clear(); - for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { - if (entity_.get_key(id) == kInvalidKey) { - continue; - } - if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { - dist_t dist = ctx->dist_calculator().batch_dist(id); - ctx->topk_heap().emplace(id, dist); - } - } - ctx->topk_to_result(q); - query = static_cast(query) + qmeta.element_size(); - } - } - - if (ailego_unlikely(ctx->error())) { - return IndexError_Runtime; - } - - return 0; -} - -int HnswSearcher::search_bf_by_p_keys_impl( - const void *query, const std::vector> &p_keys, - const IndexQueryMeta &qmeta, uint32_t count, - Context::Pointer &context) const { - if (ailego_unlikely(!query || !context)) { - LOG_ERROR("The context is not created by this searcher"); - return IndexError_Mismatch; - } - - if (ailego_unlikely(p_keys.size() != count)) { - LOG_ERROR("The size of p_keys is not equal to count"); - return IndexError_InvalidArgument; - } - - HnswContext *ctx = dynamic_cast(context.get()); - ailego_do_if_false(ctx) { - LOG_ERROR("Cast context to HnswContext failed"); - return IndexError_Cast; - } - if (ctx->magic() != magic_) { - //! context is created by another searcher or streamer - int ret = update_context(ctx); - if (ret != 0) { - return ret; - } - } - - ctx->clear(); - ctx->resize_results(count); - - if (ctx->group_by_search()) { - if (!ctx->group_by().is_valid()) { - LOG_ERROR("Invalid group-by function"); - return IndexError_InvalidArgument; - } - - std::function group_by = [&](node_id_t id) { - return ctx->group_by()(entity_.get_key(id)); - }; - - for (size_t q = 0; q < count; ++q) { - ctx->reset_query(query); - ctx->group_topk_heaps().clear(); - - for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { - uint64_t pk = p_keys[q][idx]; - if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { - node_id_t id = entity_.get_id(pk); - if (id != kInvalidNodeId) { - dist_t dist = ctx->dist_calculator().batch_dist(id); - std::string group_id = group_by(id); - - auto &topk_heap = ctx->group_topk_heaps()[group_id]; - if (topk_heap.empty()) { - topk_heap.limit(ctx->group_topk()); - } - topk_heap.emplace_back(id, dist); - } - } - } - ctx->topk_to_result(q); - query = static_cast(query) + qmeta.element_size(); - } - } else { - for (size_t q = 0; q < count; ++q) { - ctx->reset_query(query); - ctx->topk_heap().clear(); - for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { - uint64_t pk = p_keys[q][idx]; - if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { - node_id_t id = entity_.get_id(pk); - if (id != kInvalidNodeId) { - dist_t dist = ctx->dist_calculator().batch_dist(id); - ctx->topk_heap().emplace(id, dist); - } - } - } - ctx->topk_to_result(q); - query = static_cast(query) + qmeta.element_size(); - } - } - - if (ailego_unlikely(ctx->error())) { - return IndexError_Runtime; - } - - return 0; -} - -IndexSearcher::Context::Pointer HnswSearcher::create_context() const { - if (ailego_unlikely(state_ != STATE_LOADED)) { - LOG_ERROR("Load the index first before create context"); - return Context::Pointer(); - } - const HnswEntity::Pointer search_ctx_entity = entity_.clone(); - if (!search_ctx_entity) { - LOG_ERROR("Failed to create search context entity"); - return Context::Pointer(); - } - HnswContext *ctx = new (std::nothrow) - HnswContext(meta_.dimension(), metric_, search_ctx_entity); - if (ailego_unlikely(ctx == nullptr)) { - LOG_ERROR("Failed to new HnswContext"); - return Context::Pointer(); - } - ctx->set_ef(ef_); - ctx->set_max_scan_num(max_scan_num_); - uint32_t filter_mode = - bf_enabled_ ? VisitFilter::BloomFilter : VisitFilter::ByteMap; - ctx->set_filter_mode(filter_mode); - ctx->set_filter_negative_probability(bf_negative_probability_); - ctx->set_magic(magic_); - ctx->set_force_padding_topk(force_padding_topk_enabled_); - ctx->set_bruteforce_threshold(bruteforce_threshold_); - if (ailego_unlikely(ctx->init(HnswContext::kSearcherContext)) != 0) { - LOG_ERROR("Init HnswContext failed"); - delete ctx; - return Context::Pointer(); - } - - return Context::Pointer(ctx); -} - -IndexProvider::Pointer HnswSearcher::create_provider(void) const { - LOG_DEBUG("HnswSearcher create provider"); - - auto entity = entity_.clone(); - if (ailego_unlikely(!entity)) { - LOG_ERROR("Clone HnswEntity failed"); - return Provider::Pointer(); - } - return Provider::Pointer( - new (std::nothrow) HnswIndexProvider(meta_, entity, "HnswSearcher")); -} - -const void *HnswSearcher::get_vector(uint64_t key) const { - return entity_.get_vector_by_key(key); -} - -INDEX_FACTORY_REGISTER_SEARCHER(HnswSearcher); - -} // namespace core -} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_searcher.h b/src/core/algorithm/hnsw/hnsw_searcher.h deleted file mode 100644 index d79526df..00000000 --- a/src/core/algorithm/hnsw/hnsw_searcher.h +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -#include -#include "hnsw_searcher_entity.h" -#include "hnsw_streamer.h" - -namespace zvec { -namespace core { - -class HnswSearcher : public IndexSearcher { - public: - using ContextPointer = IndexSearcher::Context::Pointer; - - public: - HnswSearcher(void); - ~HnswSearcher(void); - - HnswSearcher(const HnswSearcher &) = delete; - HnswSearcher &operator=(const HnswSearcher &) = delete; - - protected: - //! Initialize Searcher - virtual int init(const ailego::Params ¶ms) override; - - //! Cleanup Searcher - virtual int cleanup(void) override; - - //! Load Index from storage - virtual int load(IndexStorage::Pointer container, - IndexMetric::Pointer metric) override; - - //! Unload index from storage - virtual int unload(void) override; - - //! KNN Search - virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, - ContextPointer &context) const override { - return search_impl(query, qmeta, 1, context); - } - - //! KNN Search - virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, - uint32_t count, - ContextPointer &context) const override; - - //! Linear Search - virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, - ContextPointer &context) const override { - return search_bf_impl(query, qmeta, 1, context); - } - - //! Linear Search - virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, - uint32_t count, - ContextPointer &context) const override; - - //! Linear search by primary keys - virtual int search_bf_by_p_keys_impl( - const void *query, const std::vector> &p_keys, - const IndexQueryMeta &qmeta, ContextPointer &context) const override { - return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); - } - - //! Linear search by primary keys - virtual int search_bf_by_p_keys_impl( - const void *query, const std::vector> &p_keys, - const IndexQueryMeta &qmeta, uint32_t count, - ContextPointer &context) const override; - - //! Fetch vector by key - virtual const void *get_vector(uint64_t key) const override; - - //! Create a searcher context - virtual ContextPointer create_context() const override; - - //! Create a new iterator - virtual IndexProvider::Pointer create_provider(void) const override; - - //! Retrieve statistics - virtual const Stats &stats(void) const override { - return stats_; - } - - //! Retrieve meta of index - virtual const IndexMeta &meta(void) const override { - return meta_; - } - - //! Retrieve params of index - virtual const ailego::Params ¶ms(void) const override { - return params_; - } - - virtual void print_debug_info() override; - - private: - //! To share ctx across streamer/searcher, we need to update the context for - //! current streamer/searcher - int update_context(HnswContext *ctx) const; - - private: - enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; - - HnswSearcherEntity entity_{}; - HnswAlgorithm::UPointer alg_; // impl graph algorithm - - IndexMetric::Pointer metric_{}; - IndexMeta meta_{}; - ailego::Params params_{}; - Stats stats_; - uint32_t ef_{HnswEntity::kDefaultEf}; - uint32_t max_scan_num_{0U}; - uint32_t bruteforce_threshold_{HnswEntity::kDefaultBruteForceThreshold}; - float max_scan_ratio_{HnswEntity::kDefaultScanRatio}; - bool bf_enabled_{false}; - bool check_crc_enabled_{false}; - bool neighbors_in_memory_enabled_{false}; - bool force_padding_topk_enabled_{false}; - float bf_negative_probability_{HnswEntity::kDefaultBFNegativeProbability}; - uint32_t magic_{0U}; - - State state_{STATE_INIT}; -}; - -} // namespace core -} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/hnsw/hnsw_searcher_entity.cc b/src/core/algorithm/hnsw/hnsw_searcher_entity.cc deleted file mode 100644 index 6661c1db..00000000 --- a/src/core/algorithm/hnsw/hnsw_searcher_entity.cc +++ /dev/null @@ -1,515 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "hnsw_searcher_entity.h" -#include -#include "utility/sparse_utility.h" - -namespace zvec { -namespace core { - -HnswSearcherEntity::HnswSearcherEntity() {} - -int HnswSearcherEntity::cleanup(void) { - storage_.reset(); - vectors_.reset(); - keys_.reset(); - neighbors_.reset(); - neighbors_meta_.reset(); - neighbors_in_memory_enabled_ = false; - loaded_ = false; - - this->HnswEntity::cleanup(); - - return 0; -} - -key_t HnswSearcherEntity::get_key(node_id_t id) const { - const void *key; - if (ailego_unlikely(keys_->read(id * sizeof(key_t), &key, sizeof(key_t)) != - sizeof(key_t))) { - LOG_ERROR("Read key from segment failed"); - return kInvalidKey; - } - return *(reinterpret_cast(key)); -} - -//! Get vector local id by key -node_id_t HnswSearcherEntity::get_id(key_t key) const { - if (ailego_unlikely(!mapping_)) { - LOG_ERROR("Index missing mapping segment"); - return kInvalidNodeId; - } - - //! Do binary search - node_id_t start = 0UL; - node_id_t end = doc_cnt(); - const void *data; - node_id_t idx = 0u; - while (start < end) { - idx = start + (end - start) / 2; - if (ailego_unlikely( - mapping_->read(idx * sizeof(node_id_t), &data, sizeof(node_id_t)) != - sizeof(node_id_t))) { - LOG_ERROR("Read key from segment failed"); - return kInvalidNodeId; - } - const key_t *mkey; - node_id_t local_id = *reinterpret_cast(data); - if (ailego_unlikely(keys_->read(local_id * sizeof(key_t), - (const void **)(&mkey), - sizeof(key_t)) != sizeof(key_t))) { - LOG_ERROR("Read key from segment failed"); - return kInvalidNodeId; - } - if (*mkey < key) { - start = idx + 1; - } else if (*mkey > key) { - end = idx; - } else { - return local_id; - } - } - return kInvalidNodeId; -} - -const void *HnswSearcherEntity::get_vector_by_key(key_t key) const { - node_id_t local_id = get_id(key); - if (ailego_unlikely(local_id == kInvalidNodeId)) { - return nullptr; - } - - return get_vector(local_id); -} - -const void *HnswSearcherEntity::get_vector(node_id_t id) const { - size_t read_size = vector_size(); - size_t offset = node_size() * id; - - const void *vec; - if (ailego_unlikely(vectors_->read(offset, &vec, read_size) != read_size)) { - LOG_ERROR("Read vector from segment failed"); - return nullptr; - } - return vec; -} - -int HnswSearcherEntity::get_vector(const node_id_t id, - IndexStorage::MemoryBlock &block) const { - const void *vec = get_vector(id); - block.reset((void *)vec); - return 0; -} - -const void *HnswSearcherEntity::get_vectors() const { - const void *vec; - size_t len = node_size() * doc_cnt(); - if (vectors_->read(0, &vec, len) != len) { - LOG_ERROR("Read vectors from segment failed"); - return nullptr; - } - return vec; -} - -int HnswSearcherEntity::get_vector(const node_id_t *ids, uint32_t count, - const void **vecs) const { - ailego_assert_with(count <= segment_datas_.size(), "invalid count"); - - size_t read_size = vector_size(); - - for (uint32_t i = 0; i < count; ++i) { - segment_datas_[i].offset = node_size() * ids[i]; - segment_datas_[i].length = read_size; - - ailego_assert_with(segment_datas_[i].offset < vectors_->data_size(), - "invalid offset"); - } - if (ailego_unlikely(!vectors_->read(&segment_datas_[0], count))) { - LOG_ERROR("Read vectors from segment failed"); - return IndexError_ReadData; - } - for (uint32_t i = 0; i < count; ++i) { - vecs[i] = segment_datas_[i].data; - } - - return 0; -} - -int HnswSearcherEntity::get_vector( - const node_id_t *ids, uint32_t count, - std::vector &vec_blocks) const { - const void *vecs[count]; - get_vector(ids, count, vecs); - for (uint32_t i = 0; i < count; ++i) { - vec_blocks.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i])); - } - return 0; -} - -const Neighbors HnswSearcherEntity::get_neighbors(level_t level, - node_id_t id) const { - if (level == 0) { - if (neighbors_in_memory_enabled_) { - auto hd = reinterpret_cast( - fixed_neighbors_.get() + neighbors_size() * id); - return {hd->neighbor_cnt, hd->neighbors}; - } - - const GraphNeighborMeta *m; - if (ailego_unlikely(neighbors_meta_->read(id * sizeof(GraphNeighborMeta), - (const void **)(&m), - sizeof(GraphNeighborMeta)) != - sizeof(GraphNeighborMeta))) { - LOG_ERROR("Read neighbors meta from segment failed"); - return {0, nullptr}; - } - - const void *data; - if (ailego_unlikely(neighbors_->read(m->offset, &data, - m->neighbor_cnt * sizeof(node_id_t)) != - m->neighbor_cnt * sizeof(node_id_t))) { - LOG_ERROR("Read neighbors from segment failed"); - return {0, nullptr}; - } - return {static_cast(m->neighbor_cnt), - reinterpret_cast(data)}; - } - - //! Read level > 0 neighbors - const HnswNeighborMeta *m; - if (ailego_unlikely(upper_neighbors_meta_->read(id * sizeof(HnswNeighborMeta), - (const void **)(&m), - sizeof(HnswNeighborMeta)) != - sizeof(HnswNeighborMeta))) { - LOG_ERROR("Read neighbors meta from segment failed"); - return {0, nullptr}; - } - - ailego_assert_with(level <= m->level, "invalid level"); - size_t offset = m->offset + (level - 1) * upper_neighbors_size(); - ailego_assert_with(offset <= upper_neighbors_->data_size(), "invalid offset"); - const void *data; - if (ailego_unlikely( - upper_neighbors_->read(offset, &data, upper_neighbors_size()) != - upper_neighbors_size())) { - LOG_ERROR("Read neighbors from segment failed"); - return {0, nullptr}; - } - - auto hd = reinterpret_cast(data); - return {hd->neighbor_cnt, hd->neighbors}; -} - -int HnswSearcherEntity::load(const IndexStorage::Pointer &container, - bool check_crc) { - storage_ = container; - - int ret = load_segments(check_crc); - if (ret != 0) { - return ret; - } - - loaded_ = true; - - LOG_INFO( - "Index info: docCnt=%u entryPoint=%u maxLevel=%d efConstruct=%zu " - "l0NeighborCnt=%zu upperNeighborCnt=%zu scalingFactor=%zu " - "vectorSize=%zu nodeSize=%zu vectorSegmentSize=%zu keySegmentSize=%zu " - "neighborsSegmentSize=%zu neighborsMetaSegmentSize=%zu ", - doc_cnt(), entry_point(), cur_max_level(), ef_construction(), - l0_neighbor_cnt(), upper_neighbor_cnt(), scaling_factor(), vector_size(), - node_size(), vectors_->data_size(), keys_->data_size(), - neighbors_ == nullptr ? 0 : neighbors_->data_size(), - neighbors_meta_ == nullptr ? 0 : neighbors_meta_->data_size()); - - return 0; -} - -int HnswSearcherEntity::load_segments(bool check_crc) { - //! load header - const void *data = nullptr; - HNSWHeader hd; - auto graph_hd_segment = storage_->get(kGraphHeaderSegmentId); - if (!graph_hd_segment || graph_hd_segment->data_size() < sizeof(hd.graph)) { - LOG_ERROR("Miss or invalid segment %s", kGraphHeaderSegmentId.c_str()); - return IndexError_InvalidFormat; - } - if (graph_hd_segment->read(0, reinterpret_cast(&data), - sizeof(hd.graph)) != sizeof(hd.graph)) { - LOG_ERROR("Read segment %s failed", kGraphHeaderSegmentId.c_str()); - return IndexError_ReadData; - } - memcpy(&hd.graph, data, sizeof(hd.graph)); - - auto hnsw_hd_segment = storage_->get(kHnswHeaderSegmentId); - if (!hnsw_hd_segment || hnsw_hd_segment->data_size() < sizeof(hd.hnsw)) { - LOG_ERROR("Miss or invalid segment %s", kHnswHeaderSegmentId.c_str()); - return IndexError_InvalidFormat; - } - if (hnsw_hd_segment->read(0, reinterpret_cast(&data), - sizeof(hd.hnsw)) != sizeof(hd.hnsw)) { - LOG_ERROR("Read segment %s failed", kHnswHeaderSegmentId.c_str()); - return IndexError_ReadData; - } - memcpy(&hd.hnsw, data, sizeof(hd.hnsw)); - *mutable_header() = hd; - segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt())); - - vectors_ = storage_->get(kGraphFeaturesSegmentId); - if (!vectors_) { - LOG_ERROR("IndexStorage get segment %s failed", - kGraphFeaturesSegmentId.c_str()); - return IndexError_InvalidFormat; - } - keys_ = storage_->get(kGraphKeysSegmentId); - if (!keys_) { - LOG_ERROR("IndexStorage get segment %s failed", - kGraphKeysSegmentId.c_str()); - return IndexError_InvalidFormat; - } - - neighbors_ = storage_->get(kGraphNeighborsSegmentId); - if (!neighbors_ || (neighbors_->data_size() == 0 && doc_cnt() > 1)) { - LOG_ERROR("IndexStorage get segment %s failed or empty", - kGraphNeighborsSegmentId.c_str()); - return IndexError_InvalidArgument; - } - neighbors_meta_ = storage_->get(kGraphOffsetsSegmentId); - if (!neighbors_meta_ || - neighbors_meta_->data_size() < sizeof(GraphNeighborMeta) * doc_cnt()) { - LOG_ERROR("IndexStorage get segment %s failed or invalid size", - kGraphOffsetsSegmentId.c_str()); - return IndexError_InvalidArgument; - } - - upper_neighbors_ = storage_->get(kHnswNeighborsSegmentId); - if (!upper_neighbors_ || - (upper_neighbors_->data_size() == 0 && cur_max_level() > 0)) { - LOG_ERROR("IndexStorage get segment %s failed or empty", - kHnswNeighborsSegmentId.c_str()); - return IndexError_InvalidArgument; - } - - upper_neighbors_meta_ = storage_->get(kHnswOffsetsSegmentId); - if (!upper_neighbors_meta_ || upper_neighbors_meta_->data_size() < - sizeof(HnswNeighborMeta) * doc_cnt()) { - LOG_ERROR("IndexStorage get segment %s failed or invalid size", - kHnswOffsetsSegmentId.c_str()); - return IndexError_InvalidArgument; - } - - mapping_ = storage_->get(kGraphMappingSegmentId); - if (!mapping_ || mapping_->data_size() < sizeof(node_id_t) * doc_cnt()) { - LOG_ERROR("IndexStorage get segment %s failed or invalid size", - kGraphMappingSegmentId.c_str()); - return IndexError_InvalidArgument; - } - - if (check_crc) { - std::vector segments; - segments.emplace_back(graph_hd_segment); - segments.emplace_back(hnsw_hd_segment); - segments.emplace_back(vectors_); - segments.emplace_back(keys_); - - segments.emplace_back(neighbors_); - segments.emplace_back(neighbors_meta_); - segments.emplace_back(upper_neighbors_); - segments.emplace_back(upper_neighbors_meta_); - - if (!do_crc_check(segments)) { - LOG_ERROR("Check index crc failed, the index may broken"); - return IndexError_Runtime; - } - } - - if (neighbors_in_memory_enabled_) { - int ret = load_and_flat_neighbors(); - if (ret != 0) { - return ret; - } - } - - return 0; -} - -int HnswSearcherEntity::load_and_flat_neighbors() { - fixed_neighbors_.reset( - new (std::nothrow) char[neighbors_size() * doc_cnt()]{}, - std::default_delete()); - if (!fixed_neighbors_) { - LOG_ERROR("Malloc memory failed"); - return IndexError_NoMemory; - } - - //! Get a new segemnt to release the buffer after loading neighbors - auto neighbors_meta = storage_->get(kGraphOffsetsSegmentId); - if (!neighbors_meta) { - LOG_ERROR("IndexStorage get segment graph.offsets failed"); - return IndexError_InvalidArgument; - } - - const GraphNeighborMeta *neighbors_index = nullptr; - if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index), - neighbors_meta->data_size()) != - neighbors_meta->data_size()) { - LOG_ERROR("Read segment %s data failed", kGraphOffsetsSegmentId.c_str()); - return IndexError_InvalidArgument; - } - - const char *neighbor_data; - for (node_id_t id = 0; id < doc_cnt(); ++id) { - size_t rd_size = neighbors_index[id].neighbor_cnt * sizeof(node_id_t); - if (ailego_unlikely( - neighbors_->read(neighbors_index[id].offset, - reinterpret_cast(&neighbor_data), - rd_size) != rd_size)) { - LOG_ERROR("Read neighbors from segment failed"); - return IndexError_ReadData; - } - // copy level 0 neighbors to fixed size neighbors memory - char *dst = fixed_neighbors_.get() + neighbors_size() * id; - *reinterpret_cast(dst) = neighbors_index[id].neighbor_cnt; - memcpy(dst + sizeof(uint32_t), neighbor_data, rd_size); - } - - return 0; -} - -int HnswSearcherEntity::get_fixed_neighbors( - std::vector *fixed_neighbors) const { - //! Get a new segemnt to release the buffer after loading neighbors - auto neighbors_meta = storage_->get(kGraphOffsetsSegmentId); - if (!neighbors_meta) { - LOG_ERROR("IndexStorage get segment graph.offsets failed"); - return IndexError_InvalidArgument; - } - - const GraphNeighborMeta *neighbors_index = nullptr; - size_t meta_size = neighbors_meta->data_size(); - if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index), - meta_size) != meta_size) { - LOG_ERROR("Read segment %s data failed", kGraphOffsetsSegmentId.c_str()); - return IndexError_InvalidArgument; - } - - size_t fixed_neighbor_cnt = l0_neighbor_cnt(); - fixed_neighbors->resize((fixed_neighbor_cnt + 1) * doc_cnt(), kInvalidNodeId); - - size_t neighbors_cnt_offset = fixed_neighbor_cnt * doc_cnt(); - size_t total_neighbor_cnt = 0; - for (node_id_t id = 0; id < doc_cnt(); ++id) { - size_t cur_neighbor_cnt = neighbors_index[id].neighbor_cnt; - if (cur_neighbor_cnt == 0) { - (*fixed_neighbors)[neighbors_cnt_offset + id] = 0; - continue; - } - size_t rd_size = cur_neighbor_cnt * sizeof(node_id_t); - const uint32_t *neighbors; - if (neighbors_->read(neighbors_index[id].offset, - reinterpret_cast(&neighbors), - rd_size) != rd_size) { - LOG_ERROR("Read neighbors from segment failed"); - return IndexError_ReadData; - } - - // copy level 0 neighbors to fixed size neighbors memory - auto it = fixed_neighbors->begin() + id * fixed_neighbor_cnt; - std::copy(neighbors, neighbors + cur_neighbor_cnt, it); - - (*fixed_neighbors)[neighbors_cnt_offset + id] = cur_neighbor_cnt; - total_neighbor_cnt += cur_neighbor_cnt; - } - LOG_INFO("total neighbor cnt: %zu, average neighbor cnt: %zu", - total_neighbor_cnt, total_neighbor_cnt / doc_cnt()); - - return 0; -} - -bool HnswSearcherEntity::do_crc_check( - std::vector &segments) const { - constexpr size_t blk_size = 4096; - const void *data; - for (auto &segment : segments) { - size_t offset = 0; - size_t rd_size; - uint32_t crc = 0; - while (offset < segment->data_size()) { - size_t size = std::min(blk_size, segment->data_size() - offset); - if ((rd_size = segment->read(offset, &data, size)) <= 0) { - break; - } - offset += rd_size; - crc = ailego::Crc32c::Hash(data, rd_size, crc); - } - if (crc != segment->data_crc()) { - return false; - } - } - return true; -} - -const HnswEntity::Pointer HnswSearcherEntity::clone() const { - auto vectors = vectors_->clone(); - if (ailego_unlikely(!vectors)) { - LOG_ERROR("clone segment %s failed", kGraphFeaturesSegmentId.c_str()); - return HnswEntity::Pointer(); - } - auto keys = keys_->clone(); - if (ailego_unlikely(!keys)) { - LOG_ERROR("clone segment %s failed", kGraphKeysSegmentId.c_str()); - return HnswEntity::Pointer(); - } - - auto mapping = mapping_->clone(); - if (ailego_unlikely(!mapping)) { - LOG_ERROR("clone segment %s failed", kGraphMappingSegmentId.c_str()); - return HnswEntity::Pointer(); - } - - auto neighbors = neighbors_->clone(); - if (ailego_unlikely(!neighbors)) { - LOG_ERROR("clone segment %s failed", kGraphNeighborsSegmentId.c_str()); - return HnswEntity::Pointer(); - } - auto upper_neighbors = upper_neighbors_->clone(); - if (ailego_unlikely(!neighbors)) { - LOG_ERROR("clone segment %s failed", kHnswNeighborsSegmentId.c_str()); - return HnswEntity::Pointer(); - } - auto neighbors_meta = neighbors_meta_->clone(); - if (ailego_unlikely(!neighbors_meta)) { - LOG_ERROR("clone segment %s failed", kGraphOffsetsSegmentId.c_str()); - return HnswEntity::Pointer(); - } - auto upper_neighbors_meta = upper_neighbors_meta_->clone(); - if (ailego_unlikely(!upper_neighbors_meta)) { - LOG_ERROR("clone segment %s failed", kHnswOffsetsSegmentId.c_str()); - return HnswEntity::Pointer(); - } - - SegmentGroupParam neighbor_group{neighbors, neighbors_meta, upper_neighbors, - upper_neighbors_meta}; - - HnswSearcherEntity *entity = new (std::nothrow) - HnswSearcherEntity(header(), vectors, keys, mapping, neighbor_group, - fixed_neighbors_, neighbors_in_memory_enabled_); - if (ailego_unlikely(!entity)) { - LOG_ERROR("HnswSearcherEntity new failed"); - } - - return HnswEntity::Pointer(entity); -} - -} // namespace core -} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/hnsw/hnsw_searcher_entity.h b/src/core/algorithm/hnsw/hnsw_searcher_entity.h deleted file mode 100644 index 6fcd6b9b..00000000 --- a/src/core/algorithm/hnsw/hnsw_searcher_entity.h +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -#include "hnsw_entity.h" - -namespace zvec { -namespace core { - -class HnswSearcherEntity : public HnswEntity { - public: - using Pointer = std::shared_ptr; - using SegmentPointer = IndexStorage::Segment::Pointer; - - public: - struct SegmentGroupParam { - SegmentGroupParam(SegmentPointer neighbors_in, - SegmentPointer neighbors_meta_in, - SegmentPointer upper_neighbors_in, - SegmentPointer upper_neighbors_meta_in) - : neighbors{neighbors_in}, - neighbors_meta{neighbors_meta_in}, - upper_neighbors{upper_neighbors_in}, - upper_neighbors_meta{upper_neighbors_meta_in} {} - - SegmentPointer neighbors{nullptr}; - SegmentPointer neighbors_meta{nullptr}; - SegmentPointer upper_neighbors{nullptr}; - SegmentPointer upper_neighbors_meta{nullptr}; - }; - - //! Constructor - HnswSearcherEntity(); - - //! Make a copy of searcher entity, to support thread-safe operation. - //! The segment in container cannot be read concurrenly - virtual const HnswEntity::Pointer clone() const override; - - //! Get primary key of the node id - virtual key_t get_key(node_id_t id) const override; - - //! Get vector local id by key - node_id_t get_id(key_t key) const; - - //! Get vector feature data by key - virtual const void *get_vector_by_key(key_t key) const override; - - //! Get vector feature data by id - virtual const void *get_vector(node_id_t id) const override; - - //! Get vector feature data by id - virtual int get_vector(const node_id_t *ids, uint32_t count, - const void **vecs) const override; - - virtual int get_vector(const node_id_t id, - IndexStorage::MemoryBlock &block) const override; - virtual int get_vector( - const node_id_t *ids, uint32_t count, - std::vector &vec_blocks) const override; - - //! Get all vectors - const void *get_vectors() const; - - //! Get the node id's neighbors on graph level - virtual const Neighbors get_neighbors(level_t level, - node_id_t id) const override; - - virtual int load(const IndexStorage::Pointer &container, - bool check_crc) override; - - int load_segments(bool check_crc); - - virtual int cleanup(void) override; - - public: - bool is_loaded() const { - return loaded_; - } - - void set_neighbors_in_memory(bool enabled) { - neighbors_in_memory_enabled_ = enabled; - } - - //! get fixed length neighbors data - int get_fixed_neighbors(std::vector *fixed_neighbors) const; - - private: - //! Constructor - HnswSearcherEntity(const HNSWHeader &hd, const SegmentPointer &vectors, - const SegmentPointer &keys, const SegmentPointer &mapping, - const SegmentGroupParam &neighbor_group, - const std::shared_ptr &fixed_neighbors, - bool neighbors_in_memory_enabled) - : HnswEntity(hd), - vectors_(vectors), - keys_(keys), - mapping_(mapping), - neighbors_(neighbor_group.neighbors), - neighbors_meta_(neighbor_group.neighbors_meta), - upper_neighbors_(neighbor_group.upper_neighbors), - upper_neighbors_meta_(neighbor_group.upper_neighbors_meta), - neighbors_in_memory_enabled_(neighbors_in_memory_enabled) { - segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt()), - IndexStorage::SegmentData(0U, 0U)); - fixed_neighbors_ = fixed_neighbors; - } - - bool do_crc_check(std::vector &segments) const; - - inline size_t neighbors_size() const { - return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); - } - - inline size_t upper_neighbors_size() const { - return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); - } - - //! If neighbors_in_memory_enabled, load the level0 neighbors to memory - int load_and_flat_neighbors(void); - - public: - HnswSearcherEntity(const HnswSearcherEntity &) = delete; - HnswSearcherEntity &operator=(const HnswSearcherEntity &) = delete; - - private: - IndexStorage::Pointer storage_{}; - - SegmentPointer vectors_{}; - SegmentPointer keys_{}; - SegmentPointer mapping_{}; - - SegmentPointer neighbors_{}; - SegmentPointer neighbors_meta_{}; - SegmentPointer upper_neighbors_{}; - SegmentPointer upper_neighbors_meta_{}; - - mutable std::vector segment_datas_{}; - std::shared_ptr fixed_neighbors_{}; // level 0 fixed size neighbors - bool neighbors_in_memory_enabled_{false}; - bool loaded_{false}; -}; - -} // namespace core -} // namespace zvec From 956eb0a5839f4bba27dd1443120093118e44f3a5 Mon Sep 17 00:00:00 2001 From: Zefeng Yin Date: Wed, 4 Mar 2026 14:58:48 +0800 Subject: [PATCH 03/34] fix ut --- src/core/utility/file_read_storage.cc | 2 +- .../core/algorithm/hnsw/hnsw_streamer_test.cc | 63 +++++-------------- 2 files changed, 18 insertions(+), 47 deletions(-) diff --git a/src/core/utility/file_read_storage.cc b/src/core/utility/file_read_storage.cc index 8f2b79fe..f1b3f732 100644 --- a/src/core/utility/file_read_storage.cc +++ b/src/core/utility/file_read_storage.cc @@ -289,7 +289,7 @@ class FileReadStorage : public IndexStorage { } int append(const std::string & /*id*/, size_t /*size*/) override { - return IndexError_NotImplemented; + return 0; } void refresh(uint64_t) override { diff --git a/tests/core/algorithm/hnsw/hnsw_streamer_test.cc b/tests/core/algorithm/hnsw/hnsw_streamer_test.cc index c04f6712..5fb5e2f4 100644 --- a/tests/core/algorithm/hnsw/hnsw_streamer_test.cc +++ b/tests/core/algorithm/hnsw/hnsw_streamer_test.cc @@ -353,7 +353,7 @@ TEST_F(HnswStreamerTest, TestKnnSearch) { } float recall = totalHits * 1.0f / totalCnts; float topk1Recall = topk1Hits * 1.0f / cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; + // float cost = linearTotalTime * 1.0f / knnTotalTime; #if 0 printf("knnTotalTime=%zd linearTotalTime=%zd totalHits=%d totalCnts=%d " "R@%zd=%f R@1=%f cost=%f\n", @@ -439,7 +439,7 @@ TEST_F(HnswStreamerTest, TestAddAndSearch) { } float recall = totalHits * 1.0f / totalCnts; float topk1Recall = topk1Hits * 100.0f / cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; + // float cost = linearTotalTime * 1.0f / knnTotalTime; #if 0 printf("knnTotalTime=%zd linearTotalTime=%zd totalHits=%d totalCnts=%d " "R@%zd=%f R@1=%f cost=%f\n", @@ -1678,15 +1678,12 @@ TEST_F(HnswStreamerTest, TestDumpIndexAndAdd) { ASSERT_EQ(IndexError_Unsupported, code); // check dump index - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - auto container = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, container->init(ailego::Params())); - ASSERT_EQ(0, container->open(path1, false)); - ASSERT_NE(searcher, nullptr); - ASSERT_EQ(0, searcher->init(ailego::Params())); - ASSERT_EQ(0, searcher->load(container, IndexMetric::Pointer())); - auto iter = searcher->create_provider()->create_iterator(); + IndexStreamer::Pointer read_streamer = + IndexFactory::CreateStreamer("HnswStreamer"); + ASSERT_NE(read_streamer, nullptr); + ASSERT_EQ(0, read_streamer->init(*index_meta_ptr_, params)); + ASSERT_EQ(0, read_streamer->open(storage)); + auto iter = read_streamer->create_provider()->create_iterator(); size_t docs = 0; while (iter->is_valid()) { auto key = iter->key(); @@ -1777,15 +1774,12 @@ TEST_F(HnswStreamerTest, TestProvider) { streamer->close(); // check dump index - IndexSearcher::Pointer searcher = - IndexFactory::CreateSearcher("HnswSearcher"); - auto container = IndexFactory::CreateStorage("FileReadStorage"); - ASSERT_EQ(0, container->init(ailego::Params())); - ASSERT_EQ(0, container->open(path1, false)); - ASSERT_NE(searcher, nullptr); - ASSERT_EQ(0, searcher->init(ailego::Params())); - ASSERT_EQ(0, searcher->load(container, IndexMetric::Pointer())); - auto iter = searcher->create_provider()->create_iterator(); + IndexStreamer::Pointer read_streamer = + IndexFactory::CreateStreamer("HnswStreamer"); + ASSERT_NE(read_streamer, nullptr); + ASSERT_EQ(0, read_streamer->init(*index_meta_ptr_, params)); + ASSERT_EQ(0, read_streamer->open(storage)); + auto iter = read_streamer->create_provider()->create_iterator(); size_t cnt = 0; while (iter->is_valid()) { auto key = iter->key(); @@ -1812,29 +1806,6 @@ TEST_F(HnswStreamerTest, TestProvider) { iter->next(); } ASSERT_EQ(cnt, docs); - - - auto searcher_provider = searcher->create_provider(); - auto streamer_provider = streamer->create_provider(); - for (size_t i = 0; i < keys.size(); ++i) { - const float *d1 = - reinterpret_cast(searcher_provider->get_vector(keys[i])); - ASSERT_TRUE(d1); - for (size_t j = 0; j < dim; ++j) { - ASSERT_FLOAT_EQ(d1[j], keys[i]); - } - - const float *d2 = - reinterpret_cast(streamer_provider->get_vector(keys[i])); - ASSERT_TRUE(d2); - for (size_t j = 0; j < dim; ++j) { - ASSERT_FLOAT_EQ(d2[j], keys[i]); - } - } - - ASSERT_EQ(dim, streamer_provider->dimension()); - ASSERT_EQ(index_meta_ptr_->element_size(), streamer_provider->element_size()); - ASSERT_EQ(index_meta_ptr_->data_type(), streamer_provider->data_type()); } TEST_F(HnswStreamerTest, TestSharedContext) { @@ -2093,7 +2064,7 @@ TEST_F(HnswStreamerTest, TestBruteForceSetupInContext) { } float recall = totalHits * 1.0f / totalCnts; float topk1Recall = topk1Hits * 1.0f / cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; + // float cost = linearTotalTime * 1.0f / knnTotalTime; #if 0 printf("knnTotalTime=%zd linearTotalTime=%zd totalHits=%d totalCnts=%d " "R@%zd=%f R@1=%f cost=%f\n", @@ -2220,7 +2191,7 @@ TEST_F(HnswStreamerTest, TestKnnSearchCosine) { } float recall = totalHits * 1.0f / totalCnts; float topk1Recall = topk1Hits * 1.0f / query_cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; + // float cost = linearTotalTime * 1.0f / knnTotalTime; #if 0 printf("knnTotalTime=%zd linearTotalTime=%zd totalHits=%d totalCnts=%d " "R@%zd=%f R@1=%f cost=%f\n", @@ -3654,7 +3625,7 @@ TEST_F(HnswStreamerTest, TestAddAndSearchWithID) { } float recall = totalHits * 1.0f / totalCnts; float topk1Recall = topk1Hits * 100.0f / cnt; - float cost = linearTotalTime * 1.0f / knnTotalTime; + // float cost = linearTotalTime * 1.0f / knnTotalTime; #if 0 printf("knnTotalTime=%zd linearTotalTime=%zd totalHits=%d totalCnts=%d " "R@%zd=%f R@1=%f cost=%f\n", From 6a104ba7c45fc02d18e7824e4c351e3b05598daa Mon Sep 17 00:00:00 2001 From: Zefeng Yin Date: Tue, 10 Mar 2026 20:21:21 +0800 Subject: [PATCH 04/34] add HnswStreamerEntityNew --- .../hnsw/hnsw_streamer_entity_new.cc | 1038 +++++++++++++++++ .../algorithm/hnsw/hnsw_streamer_entity_new.h | 744 ++++++++++++ 2 files changed, 1782 insertions(+) create mode 100644 src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc create mode 100644 src/core/algorithm/hnsw/hnsw_streamer_entity_new.h diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc new file mode 100644 index 00000000..5d8fc439 --- /dev/null +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc @@ -0,0 +1,1038 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hnsw_streamer_entity_new.h" +#include + +// #define DEBUG_PRINT + +namespace zvec { +namespace core { + +const std::string HnswStreamerEntityNew::kGraphHeaderSegmentId = "graph.header"; +const std::string HnswStreamerEntityNew::kGraphFeaturesSegmentId = "graph.features"; +const std::string HnswStreamerEntityNew::kGraphKeysSegmentId = "graph.keys"; +const std::string HnswStreamerEntityNew::kGraphNeighborsSegmentId = "graph.neighbors"; +const std::string HnswStreamerEntityNew::kGraphOffsetsSegmentId = "graph.offsets"; +const std::string HnswStreamerEntityNew::kGraphMappingSegmentId = "graph.mapping"; +const std::string HnswStreamerEntityNew::kHnswHeaderSegmentId = "hnsw.header"; +const std::string HnswStreamerEntityNew::kHnswNeighborsSegmentId = "hnsw.neighbors"; +const std::string HnswStreamerEntityNew::kHnswOffsetsSegmentId = "hnsw.offsets"; + +int64_t HnswStreamerEntityNew::dump_segment(const IndexDumper::Pointer &dumper, + const std::string &segment_id, + const void *data, size_t size) const { + size_t len = dumper->write(data, size); + if (len != size) { + LOG_ERROR("Dump segment %s data failed, expect: %lu, actual: %lu", + segment_id.c_str(), size, len); + return IndexError_WriteData; + } + + size_t padding_size = AlignSize(size) - size; + if (padding_size > 0) { + std::string padding(padding_size, '\0'); + if (dumper->write(padding.data(), padding_size) != padding_size) { + LOG_ERROR("Append padding failed, size %lu", padding_size); + return IndexError_WriteData; + } + } + + uint32_t crc = ailego::Crc32c::Hash(data, size); + int ret = dumper->append(segment_id, size, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret); + return ret; + } + + return len + padding_size; +} + +int64_t HnswStreamerEntityNew::dump_header(const IndexDumper::Pointer &dumper, + const HNSWHeader &hd) const { + //! dump basic graph header. header is aligned and does not need padding + int64_t graph_hd_size = + dump_segment(dumper, kGraphHeaderSegmentId, &hd.graph, hd.graph.size); + if (graph_hd_size < 0) { + return graph_hd_size; + } + + //! dump basic graph header. header is aligned and does not need padding + int64_t hnsw_hd_size = + dump_segment(dumper, kHnswHeaderSegmentId, &hd.hnsw, hd.hnsw.size); + if (hnsw_hd_size < 0) { + return hnsw_hd_size; + } + + return graph_hd_size + hnsw_hd_size; +} + + +HnswStreamerEntityNew::HnswStreamerEntityNew(IndexStreamer::Stats &stats) + : stats_(stats) {} + +HnswStreamerEntityNew::~HnswStreamerEntityNew() {} + +int HnswStreamerEntityNew::init(size_t max_doc_cnt) { + if (std::pow(scaling_factor(), kMaxGraphLayers) < max_doc_cnt) { + LOG_ERROR("scalingFactor=%zu is too small", scaling_factor()); + return IndexError_InvalidArgument; + } + + std::lock_guard lock(mutex_); + broker_ = std::make_shared(stats_); + upper_neighbor_index_ = std::make_shared(); + keys_map_lock_ = std::make_shared(); + keys_map_ = std::make_shared>(); + if (!keys_map_ || !upper_neighbor_index_ || !broker_ || !keys_map_lock_) { + LOG_ERROR("HnswStreamerEntityNew new object failed"); + return IndexError_NoMemory; + } + keys_map_->set_empty_key(kInvalidKey); + + neighbor_size_ = neighbors_size(); + upper_neighbor_size_ = upper_neighbors_size(); + + //! vector + key + level 0 neighbors + size_t size = vector_size() + sizeof(key_t) + neighbor_size_; + + size = AlignSize(size); + set_node_size(size); + return 0; +} + +int HnswStreamerEntityNew::cleanup() { + std::lock_guard lock(mutex_); + mutable_header()->clear(); + chunk_size_ = kDefaultChunkSize; + node_index_mask_bits_ = 0U; + node_index_mask_ = 0U; + node_cnt_per_chunk_ = 0U; + neighbor_size_ = 0U; + upper_neighbor_size_ = 0U; + if (upper_neighbor_index_) { + upper_neighbor_index_->cleanup(); + } + if (keys_map_) { + keys_map_->clear(); + } + node_chunks_.clear(); + upper_neighbor_chunks_.clear(); + filter_same_key_ = false; + get_vector_enabled_ = false; + broker_.reset(); + + return 0; +} + +int HnswStreamerEntityNew::update_neighbors( + level_t level, node_id_t id, + const std::vector> &neighbors) { + std::vector buffer(neighbor_size_); + NeighborsHeader *hd = reinterpret_cast(buffer.data()); + hd->neighbor_cnt = neighbors.size(); + size_t i = 0; + for (; i < neighbors.size(); ++i) { + hd->neighbors[i] = neighbors[i].first; + } + + auto loc = get_neighbor_chunk_loc(level, id); + size_t size = reinterpret_cast(&hd->neighbors[i]) - &buffer[0]; + size_t ret = loc.first->write(loc.second, hd, size); + if (ailego_unlikely(ret != size)) { + LOG_ERROR("Write neighbor header failed, ret=%zu", ret); + + return IndexError_Runtime; + } + + return 0; +} + +const Neighbors HnswStreamerEntityNew::get_neighbors(level_t level, + node_id_t id) const { + Chunk *chunk = nullptr; + size_t offset = 0UL; + size_t neighbor_size = neighbor_size_; + if (level == 0UL) { + uint32_t chunk_idx = id >> node_index_mask_bits_; + offset = + (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); + + sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); + ailego_assert_with(chunk_idx < node_chunks_.size(), "invalid chunk idx"); + chunk = node_chunks_[chunk_idx].get(); + } else { + auto p = get_upper_neighbor_chunk_loc(level, id); + chunk = upper_neighbor_chunks_[p.first].get(); + offset = p.second; + neighbor_size = upper_neighbor_size_; + } + + ailego_assert_with(offset < chunk->data_size(), "invalid chunk offset"); + IndexStorage::MemoryBlock neighbor_block; + size_t size = chunk->read(offset, neighbor_block, neighbor_size); + if (ailego_unlikely(size != neighbor_size)) { + LOG_ERROR("Read neighbor header failed, ret=%zu", size); + return Neighbors(); + } + return Neighbors(std::move(neighbor_block)); +} + +//! Get vector data by key +const void *HnswStreamerEntityNew::get_vector(node_id_t id) const { + auto loc = get_vector_chunk_loc(id); + const void *vec = nullptr; + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + + size_t read_size = vector_size(); + + size_t ret = node_chunks_[loc.first]->read(loc.second, &vec, read_size); + if (ailego_unlikely(ret != read_size)) { + LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", + loc.second, read_size, ret); + } + + return vec; +} + +int HnswStreamerEntityNew::get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const { + for (auto i = 0U; i < count; ++i) { + auto loc = get_vector_chunk_loc(ids[i]); + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + + size_t read_size = vector_size(); + + size_t ret = node_chunks_[loc.first]->read(loc.second, &vecs[i], read_size); + if (ailego_unlikely(ret != read_size)) { + LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", + loc.second, read_size, ret); + return IndexError_ReadData; + } + } + return 0; +} + +int HnswStreamerEntityNew::get_vector(const node_id_t id, + IndexStorage::MemoryBlock &block) const { + auto loc = get_vector_chunk_loc(id); + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + + size_t read_size = vector_size(); + + size_t ret = node_chunks_[loc.first]->read(loc.second, block, read_size); + if (ailego_unlikely(ret != read_size)) { + LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", + loc.second, read_size, ret); + return IndexError_ReadData; + } + return 0; +} + +int HnswStreamerEntityNew::get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const { + vec_blocks.resize(count); + for (auto i = 0U; i < count; ++i) { + auto loc = get_vector_chunk_loc(ids[i]); + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + + size_t read_size = vector_size(); + + size_t ret = + node_chunks_[loc.first]->read(loc.second, vec_blocks[i], read_size); + if (ailego_unlikely(ret != read_size)) { + LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", + loc.second, read_size, ret); + return IndexError_ReadData; + } + } + return 0; +} + +key_t HnswStreamerEntityNew::get_key(node_id_t id) const { + if (use_key_info_map_) { + auto loc = get_key_chunk_loc(id); + IndexStorage::MemoryBlock key_block; + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + size_t ret = + node_chunks_[loc.first]->read(loc.second, key_block, sizeof(key_t)); + if (ailego_unlikely(ret != sizeof(key_t))) { + LOG_ERROR("Read vector failed, ret=%zu", ret); + return kInvalidKey; + } + + return *reinterpret_cast(key_block.data()); + } else { + return id; + } +} + +void HnswStreamerEntityNew::add_neighbor(level_t level, node_id_t id, + uint32_t size, node_id_t neighbor_id) { + auto loc = get_neighbor_chunk_loc(level, id); + size_t offset = + loc.second + sizeof(NeighborsHeader) + size * sizeof(node_id_t); + ailego_assert_with(size < neighbor_cnt(level), "invalid neighbor size"); + ailego_assert_with(offset < loc.first->data_size(), "invalid chunk offset"); + size_t ret = loc.first->write(offset, &neighbor_id, sizeof(node_id_t)); + if (ailego_unlikely(ret != sizeof(node_id_t))) { + LOG_ERROR("Write neighbor id failed, ret=%zu", ret); + return; + } + + uint32_t neighbors = size + 1; + ret = loc.first->write(loc.second, &neighbors, sizeof(uint32_t)); + if (ailego_unlikely(ret != sizeof(uint32_t))) { + LOG_ERROR("Write neighbor cnt failed, ret=%zu", ret); + } + + return; +} + +int HnswStreamerEntityNew::init_chunks(const Chunk::Pointer &header_chunk) { + if (header_chunk->data_size() < header_size()) { + LOG_ERROR("Invalid header chunk size"); + return IndexError_InvalidFormat; + } + IndexStorage::MemoryBlock header_block; + size_t size = header_chunk->read(0UL, header_block, header_size()); + if (ailego_unlikely(size != header_size())) { + LOG_ERROR("Read header chunk failed"); + return IndexError_ReadData; + } + *mutable_header() = + *reinterpret_cast(header_block.data()); + + int ret = check_hnsw_index(&header()); + if (ret != 0) { + broker_->close(); + return ret; + } + + node_chunks_.resize(broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_NODE)); + for (auto seq = 0UL; seq < node_chunks_.size(); ++seq) { + node_chunks_[seq] = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_NODE, seq); + if (!node_chunks_[seq]) { + LOG_ERROR("Missing hnsw streamer data chunk %zu th of %zu", seq, + node_chunks_.size()); + return IndexError_InvalidFormat; + } + } + + upper_neighbor_chunks_.resize( + broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR)); + for (auto seq = 0UL; seq < upper_neighbor_chunks_.size(); ++seq) { + upper_neighbor_chunks_[seq] = + broker_->get_chunk(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, seq); + if (!upper_neighbor_chunks_[seq]) { + LOG_ERROR("Missing hnsw streamer index chunk %zu th of %zu", seq, + upper_neighbor_chunks_.size()); + return IndexError_InvalidFormat; + } + } + + return 0; +} + +int HnswStreamerEntityNew::open(IndexStorage::Pointer stg, uint64_t max_index_size, + bool check_crc) { + std::lock_guard lock(mutex_); + bool huge_page = stg->isHugePage(); + LOG_DEBUG("huge_page: %d", (int)huge_page); + int ret = init_chunk_params(max_index_size, huge_page); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("init_chunk_params failed for %s", IndexError::What(ret)); + return ret; + } + ret = broker_->open(std::move(stg), max_index_size_, chunk_size_, check_crc); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Open index failed for %s", IndexError::What(ret)); + return ret; + } + ret = upper_neighbor_index_->init(broker_, upper_neighbor_chunk_size_, + scaling_factor(), estimate_doc_capacity(), + kUpperHashMemoryInflateRatio); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Init neighbor hash map failed"); + return ret; + } + + //! init header + auto header_chunk = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_HEADER, + ChunkBroker::kDefaultChunkSeqId); + if (!header_chunk) { // open empty index, create one + auto p = + broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_HEADER, + ChunkBroker::kDefaultChunkSeqId, header_size()); + if (ailego_unlikely(p.first != 0)) { + LOG_ERROR("Alloc header chunk failed"); + return p.first; + } + size_t size = p.second->write(0UL, &header(), header_size()); + if (ailego_unlikely(size != header_size())) { + LOG_ERROR("Write header chunk failed"); + return IndexError_WriteData; + } + return 0; + } + + //! Open an exist hnsw index + ret = init_chunks(header_chunk); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + //! total docs including features wrote in index but neighbors may not ready + node_id_t total_vecs = 0; + if (node_chunks_.size() > 0) { + size_t last_idx = node_chunks_.size() - 1; + auto last_chunk = node_chunks_[last_idx]; + if (last_chunk->data_size() % node_size()) { + LOG_WARN("The index may broken"); + return IndexError_InvalidFormat; + } + total_vecs = last_idx * node_cnt_per_chunk_ + + node_chunks_[last_idx]->data_size() / node_size(); + } + + LOG_INFO( + "Open index, l0NeighborCnt=%zu upperNeighborCnt=%zu " + "efConstruction=%zu curDocCnt=%u totalVecs=%u maxLevel=%u", + l0_neighbor_cnt(), upper_neighbor_cnt(), ef_construction(), doc_cnt(), + total_vecs, cur_max_level()); + //! try to correct the docCnt if index not fully flushed + if (doc_cnt() != total_vecs) { + LOG_WARN("Index closed abnormally, using totalVecs as curDocCnt"); + *mutable_doc_cnt() = total_vecs; + } + if (filter_same_key_ || get_vector_enabled_) { + if (use_key_info_map_) { + for (node_id_t id = 0U; id < doc_cnt(); ++id) { + if (get_key(id) == kInvalidKey) { + continue; + } + (*keys_map_)[get_key(id)] = id; + } + } + } + + stats_.set_loaded_count(doc_cnt()); + + return 0; +} + +int HnswStreamerEntityNew::close() { + LOG_DEBUG("close index"); + + std::lock_guard lock(mutex_); + flush_header(); + mutable_header()->reset(); + upper_neighbor_index_->cleanup(); + keys_map_->clear(); + header_.clear(); + node_chunks_.clear(); + upper_neighbor_chunks_.clear(); + + return broker_->close(); +} + +int HnswStreamerEntityNew::flush(uint64_t checkpoint) { + LOG_INFO("Flush index, curDocs=%u", doc_cnt()); + + std::lock_guard lock(mutex_); + flush_header(); + int ret = broker_->flush(checkpoint); + if (ret != 0) { + return ret; + } + + return 0; +} + +int HnswStreamerEntityNew::dump(const IndexDumper::Pointer &dumper) { + LOG_INFO("Dump index, curDocs=%u", doc_cnt()); + + //! sort by keys, to support get_vector by key in searcher + std::vector keys(doc_cnt()); + for (node_id_t i = 0; i < doc_cnt(); ++i) { + keys[i] = get_key(i); + } + + //! dump neighbors + auto get_level = [&](node_id_t id) { + auto it = upper_neighbor_index_->find(id); + if (it == upper_neighbor_index_->end()) { + return 0U; + }; + auto meta = reinterpret_cast(&it->second); + return meta->level; + }; + auto ret = dump_segments(dumper, keys.data(), get_level); + if (ailego_unlikely(ret < 0)) { + return ret; + } + *stats_.mutable_dumped_size() += ret; + + return 0; +} + +int HnswStreamerEntityNew::check_hnsw_index(const HNSWHeader *hd) const { + if (l0_neighbor_cnt() != hd->l0_neighbor_cnt() || + upper_neighbor_cnt() != hd->upper_neighbor_cnt()) { + LOG_ERROR("Param neighbor cnt: %zu:%zu mismatch index previous %zu:%zu", + l0_neighbor_cnt(), upper_neighbor_cnt(), hd->l0_neighbor_cnt(), + hd->upper_neighbor_cnt()); + return IndexError_Mismatch; + } + if (vector_size() != hd->vector_size()) { + LOG_ERROR("vector size %zu mismatch index previous %zu", vector_size(), + hd->vector_size()); + return IndexError_Mismatch; + } + if (ef_construction() != hd->ef_construction()) { + LOG_WARN("Param efConstruction %zu mismatch index previous %zu", + ef_construction(), hd->ef_construction()); + } + if (scaling_factor() != hd->scaling_factor()) { + LOG_WARN("Param scalingFactor %zu mismatch index previous %zu", + scaling_factor(), hd->scaling_factor()); + return IndexError_Mismatch; + } + if (prune_cnt() != hd->neighbor_prune_cnt()) { + LOG_WARN("Param pruneCnt %zu mismatch index previous %zu", prune_cnt(), + hd->neighbor_prune_cnt()); + return IndexError_Mismatch; + } + if ((hd->entry_point() != kInvalidNodeId && + hd->entry_point() >= hd->doc_cnt()) || + (hd->entry_point() == kInvalidNodeId && hd->doc_cnt() > 0U)) { + LOG_WARN("Invalid entryPoint %u, docCnt %u", hd->entry_point(), + hd->doc_cnt()); + return IndexError_InvalidFormat; + } + if (hd->entry_point() == kInvalidNodeId && + broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_NODE) > 0) { + LOG_WARN("The index is broken, maybe it haven't flush"); + return IndexError_InvalidFormat; + } + + return 0; +} + +int HnswStreamerEntityNew::add_vector(level_t level, key_t key, const void *vec, + node_id_t *id) { + Chunk::Pointer node_chunk; + size_t chunk_offset = -1UL; + + std::lock_guard lock(mutex_); + // duplicate check + if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { + LOG_WARN("Try to add duplicate key, ignore it"); + return IndexError_Duplicate; + } + + node_id_t local_id = static_cast(doc_cnt()); + uint32_t chunk_index = node_chunks_.size() - 1U; + if (chunk_index == -1U || + (node_chunks_[chunk_index]->data_size() >= + node_cnt_per_chunk_ * node_size())) { // no space left and need to alloc + if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { + LOG_ERROR("add vector failed for no memory quota"); + return IndexError_IndexFull; + } + chunk_index++; + auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_NODE, chunk_index, + chunk_size_); + if (ailego_unlikely(p.first != 0)) { + LOG_ERROR("Alloc data chunk failed"); + return p.first; + } + node_chunk = p.second; + chunk_offset = 0UL; + node_chunks_.emplace_back(node_chunk); + } else { + node_chunk = node_chunks_[chunk_index]; + chunk_offset = node_chunk->data_size(); + } + + size_t size = node_chunk->write(chunk_offset, vec, vector_size()); + if (ailego_unlikely(size != vector_size())) { + LOG_ERROR("Chunk write vec failed, ret=%zu", size); + return IndexError_WriteData; + } + size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); + if (ailego_unlikely(size != sizeof(key_t))) { + LOG_ERROR("Chunk write vec failed, ret=%zu", size); + return IndexError_WriteData; + } + //! level 0 neighbors is inited to zero by default + + int ret = add_upper_neighbor(level, local_id); + if (ret != 0) { + return ret; + } + + chunk_offset += node_size(); + if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { + LOG_ERROR("Chunk resize to %zu failed", chunk_offset); + return IndexError_Runtime; + } + if (filter_same_key_ || get_vector_enabled_) { + if (use_key_info_map_) { + keys_map_lock_->lock(); + (*keys_map_)[key] = local_id; + keys_map_lock_->unlock(); + } + } + + *mutable_doc_cnt() += 1; + broker_->mark_dirty(); + *id = local_id; + + return 0; +} + +int HnswStreamerEntityNew::add_vector_with_id(level_t level, node_id_t id, + const void *vec) { + Chunk::Pointer node_chunk; + size_t chunk_offset = -1UL; + key_t key = id; + + std::lock_guard lock(mutex_); + + // duplicate check + if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { + LOG_WARN("Try to add duplicate key, ignore it"); + return IndexError_Duplicate; + } + + // set node_chunk & chunk_offset if succeed + auto func_get_node_chunk_and_offset = [&](node_id_t node_id) -> int { + uint32_t chunk_index = node_id >> node_index_mask_bits_; + ailego_assert_with(chunk_index <= node_chunks_.size(), "invalid chunk idx"); + // belongs to next chunk + if (chunk_index == node_chunks_.size()) { + if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { + LOG_ERROR("add vector failed for no memory quota"); + return IndexError_IndexFull; + } + auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_NODE, chunk_index, + chunk_size_); + if (ailego_unlikely(p.first != 0)) { + LOG_ERROR("Alloc data chunk failed"); + return p.first; + } + node_chunk = p.second; + node_chunks_.emplace_back(node_chunk); + } + + node_chunk = node_chunks_[chunk_index]; + chunk_offset = (node_id & node_index_mask_) * node_size(); + return 0; + }; + + for (size_t start_id = doc_cnt(); start_id < id; ++start_id) { + if (auto ret = func_get_node_chunk_and_offset(start_id); ret != 0) { + LOG_ERROR("func_get_node_chunk_and_offset failed"); + return ret; + } + size_t size = node_chunk->write(chunk_offset + vector_size(), &kInvalidKey, + sizeof(key_t)); + if (ailego_unlikely(size != sizeof(key_t))) { + LOG_ERROR("Chunk write key failed, ret=%zu", size); + return IndexError_WriteData; + } + + chunk_offset += node_size(); + if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { + LOG_ERROR("Chunk resize to %zu failed", chunk_offset); + return IndexError_Runtime; + } + } + + if (auto ret = func_get_node_chunk_and_offset(id); ret != 0) { + LOG_ERROR("func_get_node_chunk_and_offset failed"); + return ret; + } + + size_t size = node_chunk->write(chunk_offset, vec, vector_size()); + if (ailego_unlikely(size != vector_size())) { + LOG_ERROR("Chunk write vec failed, ret=%zu", size); + return IndexError_WriteData; + } + + size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); + if (ailego_unlikely(size != sizeof(key_t))) { + LOG_ERROR("Chunk write vec failed, ret=%zu", size); + return IndexError_WriteData; + } + //! level 0 neighbors is inited to zero by default + + int ret = add_upper_neighbor(level, id); + if (ret != 0) { + return ret; + } + + if (*mutable_doc_cnt() <= id) { + *mutable_doc_cnt() = id + 1; + chunk_offset += node_size(); + if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { + LOG_ERROR("Chunk resize to %zu failed", chunk_offset); + return IndexError_Runtime; + } + } + + if (filter_same_key_ || get_vector_enabled_) { + if (use_key_info_map_) { + keys_map_lock_->lock(); + (*keys_map_)[key] = id; + keys_map_lock_->unlock(); + } + } + + broker_->mark_dirty(); + + return 0; +} + +void HnswStreamerEntityNew::update_ep_and_level(node_id_t ep, level_t level) { + header_.hnsw.entry_point = ep; + header_.hnsw.max_level = level; + flush_header(); + + return; +} + +const HnswStreamerEntityNew::Pointer HnswStreamerEntityNew::clone() const { + std::vector node_chunks; + node_chunks.reserve(node_chunks_.size()); + for (size_t i = 0UL; i < node_chunks_.size(); ++i) { + node_chunks.emplace_back(node_chunks_[i]->clone()); + if (ailego_unlikely(!node_chunks[i])) { + LOG_ERROR("HnswStreamerEntityNew get chunk failed in clone"); + return HnswStreamerEntityNew::Pointer(); + } + } + + std::vector upper_neighbor_chunks; + upper_neighbor_chunks.reserve(upper_neighbor_chunks_.size()); + for (size_t i = 0UL; i < upper_neighbor_chunks_.size(); ++i) { + upper_neighbor_chunks.emplace_back(upper_neighbor_chunks_[i]->clone()); + if (ailego_unlikely(!upper_neighbor_chunks[i])) { + LOG_ERROR("HnswStreamerEntityNew get chunk failed in clone"); + return HnswStreamerEntityNew::Pointer(); + } + } + + HnswStreamerEntityNew *entity = new (std::nothrow) HnswStreamerEntityNew( + stats_, header(), chunk_size_, node_index_mask_bits_, + upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, + upper_neighbor_index_, keys_map_lock_, keys_map_, use_key_info_map_, + std::move(node_chunks), std::move(upper_neighbor_chunks), broker_); + if (ailego_unlikely(!entity)) { + LOG_ERROR("HnswStreamerEntityNew new failed"); + } + return HnswStreamerEntityNew::Pointer(entity); +} + +int64_t HnswStreamerEntityNew::dump_mapping_segment(const IndexDumper::Pointer &dumper, + const key_t *keys) const { + std::vector mapping(doc_cnt()); + + std::iota(mapping.begin(), mapping.end(), 0U); + std::sort(mapping.begin(), mapping.end(), + [&](node_id_t i, node_id_t j) { return keys[i] < keys[j]; }); + + size_t size = mapping.size() * sizeof(node_id_t); + + return dump_segment(dumper, kGraphMappingSegmentId, mapping.data(), size); +} + +int64_t HnswStreamerEntityNew::dump_segments( + const IndexDumper::Pointer &dumper, key_t *keys, + const std::function &get_level) const { + HNSWHeader dump_hd(header()); + + dump_hd.graph.node_size = AlignSize(vector_size()); + + std::vector n2o_mapping; // map new id to origin id + std::vector o2n_mapping; // map origin id to new id + if (!o2n_mapping.empty()) { + dump_hd.hnsw.entry_point = o2n_mapping[entry_point()]; + } + + //! Dump header + int64_t hd_size = dump_header(dumper, dump_hd); + if (hd_size < 0) { + return hd_size; + } + + //! Dump vectors + int64_t vecs_size = dump_vectors(dumper, n2o_mapping); + if (vecs_size < 0) { + return vecs_size; + } + + //! Dump neighbors + auto neighbors_size = + dump_neighbors(dumper, get_level, n2o_mapping, o2n_mapping); + if (neighbors_size < 0) { + return neighbors_size; + } + //! free memory + n2o_mapping = std::vector(); + o2n_mapping = std::vector(); + + //! Dump keys + size_t key_segment_size = doc_cnt() * sizeof(key_t); + int64_t keys_size = + dump_segment(dumper, kGraphKeysSegmentId, keys, key_segment_size); + if (keys_size < 0) { + return keys_size; + } + + //! Dump mapping + int64_t mapping_size = dump_mapping_segment(dumper, keys); + if (mapping_size < 0) { + return mapping_size; + } + + return hd_size + keys_size + vecs_size + neighbors_size + mapping_size; +} + +int64_t HnswStreamerEntityNew::dump_vectors( + const IndexDumper::Pointer &dumper, + const std::vector &reorder_mapping) const { + size_t vector_dump_size = vector_size(); + + size_t padding_size = AlignSize(vector_dump_size) - vector_dump_size; + + std::vector padding(padding_size); + memset(padding.data(), 0, sizeof(char) * padding_size); + const void *data = nullptr; + uint32_t crc = 0U; + size_t vecs_size = 0UL; + + //! dump vectors + for (node_id_t id = 0; id < doc_cnt(); ++id) { + data = get_vector(reorder_mapping.empty() ? id : reorder_mapping[id]); + if (ailego_unlikely(!data)) { + return IndexError_ReadData; + } + size_t len = dumper->write(data, vector_size()); + if (len != vector_size()) { + LOG_ERROR("Dump vectors failed, write=%zu expect=%zu", len, + vector_size()); + return IndexError_WriteData; + } + + crc = ailego::Crc32c::Hash(data, vector_size(), crc); + vecs_size += vector_size(); + + if (padding_size == 0) { + continue; + } + + len = dumper->write(padding.data(), padding_size); + if (len != padding_size) { + LOG_ERROR("Dump vectors failed, write=%zu expect=%zu", len, padding_size); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(padding.data(), padding_size, crc); + vecs_size += padding_size; + } + + int ret = dumper->append(kGraphFeaturesSegmentId, vecs_size, 0UL, crc); + if (ret != 0) { + LOG_ERROR("Dump vectors segment meta failed, ret %d", ret); + return ret; + } + + return vecs_size; +} + +int64_t HnswStreamerEntityNew::dump_graph_neighbors( + const IndexDumper::Pointer &dumper, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const { + std::vector graph_meta; + graph_meta.reserve(doc_cnt()); + size_t offset = 0; + uint32_t crc = 0; + std::vector mapping(l0_neighbor_cnt()); + + uint32_t min_neighbor_count = 10000; + uint32_t max_neighbor_count = 0; + size_t sum_neighbor_count = 0; + + for (node_id_t id = 0; id < doc_cnt(); ++id) { + const Neighbors neighbors = + get_neighbors(0, reorder_mapping.empty() ? id : reorder_mapping[id]); + ailego_assert_with(!!neighbors.data, "invalid neighbors"); + ailego_assert_with(neighbors.size() <= l0_neighbor_cnt(), + "invalid neighbors"); + + uint32_t neighbor_count = neighbors.size(); + if (neighbor_count < min_neighbor_count) { + min_neighbor_count = neighbor_count; + } + if (neighbor_count > max_neighbor_count) { + max_neighbor_count = neighbor_count; + } + sum_neighbor_count += neighbor_count; + + graph_meta.emplace_back(offset, neighbor_count); + size_t size = neighbors.size() * sizeof(node_id_t); + const node_id_t *data = &neighbors[0]; + if (!neighbor_mapping.empty()) { + for (node_id_t i = 0; i < neighbors.size(); ++i) { + mapping[i] = neighbor_mapping[neighbors[i]]; + } + data = mapping.data(); + } + if (dumper->write(data, size) != size) { + LOG_ERROR("Dump graph neighbor id=%u failed, size %lu", id, size); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(data, size, crc); + offset += size; + } + + uint32_t average_neighbor_count = 0; + if (doc_cnt() > 0) { + average_neighbor_count = sum_neighbor_count / doc_cnt(); + } + LOG_INFO( + "Dump hnsw graph: min_neighbor_count[%u] max_neighbor_count[%u] " + "average_neighbor_count[%u]", + min_neighbor_count, max_neighbor_count, average_neighbor_count); + + size_t padding_size = 0; + int ret = CalcAndAddPadding(dumper, offset, &padding_size); + if (ret != 0) { + return ret; + } + ret = dumper->append(kGraphNeighborsSegmentId, offset, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump segment %s failed, ret %d", + kGraphNeighborsSegmentId.c_str(), ret); + return ret; + } + + //! dump level 0 neighbors meta + auto len = dump_segment(dumper, kGraphOffsetsSegmentId, graph_meta.data(), + graph_meta.size() * sizeof(GraphNeighborMeta)); + if (len < 0) { + return len; + } + + return len + offset + padding_size; +} + +int64_t HnswStreamerEntityNew::dump_upper_neighbors( + const IndexDumper::Pointer &dumper, + const std::function &get_level, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const { + std::vector hnsw_meta; + hnsw_meta.reserve(doc_cnt()); + size_t offset = 0; + uint32_t crc = 0; + std::vector buffer(upper_neighbor_cnt() + 1); + for (node_id_t id = 0; id < doc_cnt(); ++id) { + node_id_t new_id = reorder_mapping.empty() ? id : reorder_mapping[id]; + auto level = get_level(new_id); + if (level == 0) { + hnsw_meta.emplace_back(0U, 0U); + continue; + } + hnsw_meta.emplace_back(offset, level); + ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid level"); + for (level_t cur_level = 1; cur_level <= level; ++cur_level) { + const Neighbors neighbors = get_neighbors(cur_level, new_id); + ailego_assert_with(!!neighbors.data, "invalid neighbors"); + ailego_assert_with(neighbors.size() <= neighbor_cnt(cur_level), + "invalid neighbors"); + memset(buffer.data(), 0, sizeof(node_id_t) * buffer.size()); + buffer[0] = neighbors.size(); + if (neighbor_mapping.empty()) { + memcpy(&buffer[1], &neighbors[0], neighbors.size() * sizeof(node_id_t)); + } else { + for (node_id_t i = 0; i < neighbors.size(); ++i) { + buffer[i + 1] = neighbor_mapping[neighbors[i]]; + } + } + if (dumper->write(buffer.data(), sizeof(node_id_t) * buffer.size()) != + sizeof(node_id_t) * buffer.size()) { + LOG_ERROR("Dump graph neighbor id=%u failed, size %lu", id, + sizeof(node_id_t) * buffer.size()); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(buffer.data(), + sizeof(node_id_t) * buffer.size(), crc); + offset += sizeof(node_id_t) * buffer.size(); + } + } + size_t padding_size = 0; + int ret = CalcAndAddPadding(dumper, offset, &padding_size); + if (ret != 0) { + return ret; + } + + ret = dumper->append(kHnswNeighborsSegmentId, offset, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump segment %s failed, ret %d", kHnswNeighborsSegmentId.c_str(), + ret); + return ret; + } + + //! dump level 0 neighbors meta + auto len = dump_segment(dumper, kHnswOffsetsSegmentId, hnsw_meta.data(), + hnsw_meta.size() * sizeof(HnswNeighborMeta)); + if (len < 0) { + return len; + } + + return len + offset + padding_size; +} + +int HnswStreamerEntityNew::CalcAndAddPadding(const IndexDumper::Pointer &dumper, + size_t data_size, size_t *padding_size) { + *padding_size = AlignSize(data_size) - data_size; + if (*padding_size == 0) { + return 0; + } + + std::string padding(*padding_size, '\0'); + if (dumper->write(padding.data(), *padding_size) != *padding_size) { + LOG_ERROR("Append padding failed, size %lu", *padding_size); + return IndexError_WriteData; + } + return 0; +} + + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h new file mode 100644 index 00000000..cf19b63d --- /dev/null +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h @@ -0,0 +1,744 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "hnsw_chunk.h" +#include "hnsw_entity.h" +#include "hnsw_index_hash.h" +#include "hnsw_params.h" + +namespace zvec { +namespace core { + +//! HnswStreamerEntityNew manage vector data, pkey, and node's neighbors +class HnswStreamerEntityNew { + public: // override + typedef std::shared_ptr Pointer; + + //! Cleanup + //! return 0 on success, or errCode in failure + int cleanup(); + + //! Make a copy of streamer entity, to support thread-safe operation. + //! The segment in container cannot be read concurrenly + const HnswStreamerEntityNew::Pointer clone() const; + + //! Get primary key of the node id + key_t get_key(node_id_t id) const; + + //! Get vector feature data by key + const void *get_vector(node_id_t id) const; + + //! Get vectors feature data by local ids + int get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const; + + int get_vector(const node_id_t id, + IndexStorage::MemoryBlock &block) const; + + int get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const; + + //! Get the node id's neighbors on graph level + //! Note: the neighbors cannot be modified, using the following + //! method to get WritableNeighbors if want to + const Neighbors get_neighbors(level_t level, + node_id_t id) const; + + //! Add vector and key to hnsw entity, and local id will be saved in id + int add_vector(level_t level, key_t key, const void *vec, + node_id_t *id); + + //! Add vector and id to hnsw entity + int add_vector_with_id(level_t level, node_id_t id, + const void *vec); + + int update_neighbors( + level_t level, node_id_t id, + const std::vector> &neighbors); + + //! Append neighbor_id to node id neighbors on level + //! Notice: the caller must be ensure the neighbors not full + void add_neighbor(level_t level, node_id_t id, uint32_t size, + node_id_t neighbor_id); + + //! Dump index by dumper + int dump(const IndexDumper::Pointer &dumper); + + void update_ep_and_level(node_id_t ep, level_t level); + + const void *get_vector_by_key(key_t key) const { + auto id = get_id(key); + return id == kInvalidNodeId ? nullptr : get_vector(id); + } + + int get_vector_by_key( + const key_t key, IndexStorage::MemoryBlock &block) const { + auto id = get_id(key); + if (id != kInvalidNodeId) { + return get_vector(id, block); + } else { + return IndexError_InvalidArgument; + } + } + + public: // hnsw entity public + + //! Get max neighbor size of graph level + inline size_t neighbor_cnt(level_t level) const { + return level == 0 ? header_.graph.l0_neighbor_count + : header_.hnsw.upper_neighbor_count; + } + + //! get max neighbor size of graph level 0 + inline size_t l0_neighbor_cnt() const { + return header_.graph.l0_neighbor_count; + } + + //! get min neighbor size of graph + inline size_t min_neighbor_cnt() const { + return header_.graph.min_neighbor_count; + } + + //! get upper neighbor size of graph level other than 0 + inline size_t upper_neighbor_cnt() const { + return header_.hnsw.upper_neighbor_count; + } + + //! Get current total doc of the hnsw graph + inline node_id_t *mutable_doc_cnt() { + return &header_.graph.doc_count; + } + + inline node_id_t doc_cnt() const { + return header_.graph.doc_count; + } + + //! Get hnsw graph scaling params + inline size_t scaling_factor() const { + return header_.hnsw.scaling_factor; + } + + //! Get prune_size + inline size_t prune_cnt() const { + return header_.graph.prune_neighbor_count; + } + + //! Current entity of top level graph + inline node_id_t entry_point() const { + return header_.hnsw.entry_point; + } + + //! Current max graph level + inline level_t cur_max_level() const { + return header_.hnsw.max_level; + } + + //! Retrieve index vector size + size_t vector_size() const { + return header_.graph.vector_size; + } + + //! Retrieve node size + size_t node_size() const { + return header_.graph.node_size; + } + + //! Retrieve ef constuction + size_t ef_construction() const { + return header_.graph.ef_construction; + } + + void set_vector_size(size_t size) { + header_.graph.vector_size = size; + } + + void set_prune_cnt(size_t v) { + header_.graph.prune_neighbor_count = v; + } + + void set_scaling_factor(size_t val) { + header_.hnsw.scaling_factor = val; + } + + void set_l0_neighbor_cnt(size_t cnt) { + header_.graph.l0_neighbor_count = cnt; + } + + void set_min_neighbor_cnt(size_t cnt) { + header_.graph.min_neighbor_count = cnt; + } + + void set_upper_neighbor_cnt(size_t cnt) { + header_.hnsw.upper_neighbor_count = cnt; + } + + void set_ef_construction(size_t ef) { + header_.graph.ef_construction = ef; + } + + static int CalcAndAddPadding(const IndexDumper::Pointer &dumper, + size_t data_size, size_t *padding_size); + + protected: + inline const HNSWHeader &header() const { + return header_; + } + + inline HNSWHeader *mutable_header() { + return &header_; + } + + inline size_t header_size() const { + return sizeof(header_); + } + + void set_node_size(size_t size) { + header_.graph.node_size = size; + } + + //! Dump all segment by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_segments( + const IndexDumper::Pointer &dumper, key_t *keys, + const std::function &get_level) const; + + static inline size_t AlignSize(size_t size) { + return (size + 0x1F) & (~0x1F); + } + + static inline size_t AlignPageSize(size_t size) { + size_t page_mask = ailego::MemoryHelper::PageSize() - 1; + return (size + page_mask) & (~page_mask); + } + + static inline size_t AlignHugePageSize(size_t size) { + size_t page_mask = ailego::MemoryHelper::HugePageSize() - 1; + return (size + page_mask) & (~page_mask); + } + + private: + //! dump mapping segment, for get_vector_by_key in provider + int64_t dump_mapping_segment(const IndexDumper::Pointer &dumper, + const key_t *keys) const; + + //! dump hnsw head by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_header(const IndexDumper::Pointer &dumper, + const HNSWHeader &hd) const; + + //! dump vectors by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_vectors(const IndexDumper::Pointer &dumper, + const std::vector &reorder_mapping) const; + + //! dump hnsw neighbors by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_neighbors(const IndexDumper::Pointer &dumper, + const std::function &get_level, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const { + auto len1 = dump_graph_neighbors(dumper, reorder_mapping, neighbor_mapping); + if (len1 < 0) { + return len1; + } + auto len2 = dump_upper_neighbors(dumper, get_level, reorder_mapping, + neighbor_mapping); + if (len2 < 0) { + return len2; + } + + return len1 + len2; + } + + //! dump segment by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_segment(const IndexDumper::Pointer &dumper, + const std::string &segment_id, const void *data, + size_t size) const; + + //! Dump level 0 neighbors + //! Return dump size if success, errno(<0) in failure + int64_t dump_graph_neighbors( + const IndexDumper::Pointer &dumper, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const; + + //! Dump upper level neighbors + //! Return dump size if success, errno(<0) in failure + int64_t dump_upper_neighbors( + const IndexDumper::Pointer &dumper, + const std::function &get_level, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const; + + public: + const static std::string kGraphHeaderSegmentId; + const static std::string kGraphFeaturesSegmentId; + const static std::string kGraphKeysSegmentId; + const static std::string kGraphNeighborsSegmentId; + const static std::string kGraphOffsetsSegmentId; + const static std::string kGraphMappingSegmentId; + const static std::string kHnswHeaderSegmentId; + const static std::string kHnswNeighborsSegmentId; + const static std::string kHnswOffsetsSegmentId; + + constexpr static uint32_t kRevision = 0U; + constexpr static size_t kMaxGraphLayers = 15; + constexpr static uint32_t kDefaultEfConstruction = 500; + constexpr static uint32_t kDefaultEf = 500; + constexpr static uint32_t kDefaultUpperMaxNeighborCnt = 50; // M of HNSW + constexpr static uint32_t kDefaultL0MaxNeighborCnt = 100; + constexpr static uint32_t kMaxNeighborCnt = 65535; + constexpr static float kDefaultScanRatio = 0.1f; + constexpr static uint32_t kDefaultMinScanLimit = 10000; + constexpr static uint32_t kDefaultMaxScanLimit = + std::numeric_limits::max(); + constexpr static float kDefaultBFNegativeProbability = 0.001f; + constexpr static uint32_t kDefaultScalingFactor = 50U; + constexpr static uint32_t kDefaultBruteForceThreshold = 1000U; + constexpr static uint32_t kDefaultDocsHardLimit = 1 << 30U; // 1 billion + constexpr static float kDefaultDocsSoftLimitRatio = 0.9f; + constexpr static size_t kMaxChunkSize = 0xFFFFFFFF; + constexpr static size_t kDefaultChunkSize = 2UL * 1024UL * 1024UL; + constexpr static size_t kDefaultMaxChunkCnt = 50000UL; + constexpr static float kDefaultNeighborPruneMultiplier = + 1.0f; // prune_cnt = upper_max_neighbor_cnt * multiplier + constexpr static float kDefaultL0MaxNeighborCntMultiplier = + 2.0f; // l0_max_neighbor_cnt = upper_max_neighbor_cnt * multiplier + + public: + //! Constructor + HnswStreamerEntityNew(IndexStreamer::Stats &stats); + + //! Destructor + ~HnswStreamerEntityNew(); + + //! Get vector feature data by key + + + //! Init entity + int init(size_t max_doc_cnt); + + //! Flush graph entity to disk + //! return 0 on success, or errCode in failure + int flush(uint64_t checkpoint); + + //! Open entity from storage + //! return 0 on success, or errCode in failure + int open(IndexStorage::Pointer stg, uint64_t max_index_size, bool check_crc); + + //! Close entity + //! return 0 on success, or errCode in failure + int close(); + + void set_use_key_info_map(bool use_id_map) { + use_key_info_map_ = use_id_map; + LOG_DEBUG("use_key_info_map_: %d", (int)use_key_info_map_); + } + + //! Set meta information from entity + int set_index_meta(const IndexMeta &meta) const { + return IndexHelper::SerializeToStorage(meta, broker_->storage().get()); + } + + //! Get meta information from entity + int get_index_meta(IndexMeta *meta) const { + return IndexHelper::DeserializeFromStorage(broker_->storage().get(), meta); + } + + //! Set params: chunk size + inline void set_chunk_size(size_t val) { + chunk_size_ = val; + } + + //! Set params + inline void set_filter_same_key(bool val) { + filter_same_key_ = val; + } + + //! Set params + inline void set_get_vector(bool val) { + get_vector_enabled_ = val; + } + + //! Get vector local id by key + inline node_id_t get_id(key_t key) const { + if (use_key_info_map_) { + keys_map_lock_->lock_shared(); + auto it = keys_map_->find(key); + keys_map_lock_->unlock_shared(); + return it == keys_map_->end() ? kInvalidNodeId : it->second; + } else { + return key; + } + } + + void print_key_map() const { + std::cout << "key map begins" << std::endl; + + auto iter = keys_map_->begin(); + while (iter != keys_map_->end()) { + std::cout << "key: " << iter->first << ", id: " << iter->second + << std::endl; + ; + iter++; + } + + std::cout << "key map ends" << std::endl; + } + + //! Get l0 neighbors size + inline size_t neighbors_size() const { + return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); + } + + //! Get neighbors size for level > 0 + inline size_t upper_neighbors_size() const { + return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); + } + + + private: + union UpperNeighborIndexMeta { + struct { + uint32_t level : 4; + uint32_t index : 28; // index is composite type: chunk idx, and the + // N th neighbors in chunk, they two composite + // the 28 bits location + }; + uint32_t data; + }; + + template + using HashMap = google::dense_hash_map>; + template + using HashMapPointer = std::shared_ptr>; + + template + using HashSet = google::dense_hash_set>; + template + using HashSetPointer = std::shared_ptr>; + + //! upper neighbor index hashmap + using NIHashMap = HnswIndexHashMap; + using NIHashMapPointer = std::shared_ptr; + + //! Private construct, only be called by clone method + HnswStreamerEntityNew(IndexStreamer::Stats &stats, const HNSWHeader &hd, + size_t chunk_size, uint32_t node_index_mask_bits, + uint32_t upper_neighbor_mask_bits, bool filter_same_key, + bool get_vector_enabled, + const NIHashMapPointer &upper_neighbor_index, + std::shared_ptr &keys_map_lock, + const HashMapPointer &keys_map, + bool use_key_info_map, + std::vector &&node_chunks, + std::vector &&upper_neighbor_chunks, + const ChunkBroker::Pointer &broker) + : stats_(stats), + chunk_size_(chunk_size), + node_index_mask_bits_(node_index_mask_bits), + node_cnt_per_chunk_(1UL << node_index_mask_bits_), + node_index_mask_(node_cnt_per_chunk_ - 1), + upper_neighbor_mask_bits_(upper_neighbor_mask_bits), + upper_neighbor_mask_((1U << upper_neighbor_mask_bits_) - 1), + filter_same_key_(filter_same_key), + get_vector_enabled_(get_vector_enabled), + use_key_info_map_(use_key_info_map), + upper_neighbor_index_(upper_neighbor_index), + keys_map_lock_(keys_map_lock), + keys_map_(keys_map), + node_chunks_(std::move(node_chunks)), + upper_neighbor_chunks_(std::move(upper_neighbor_chunks)), + broker_(broker) { + *mutable_header() = hd; + + neighbor_size_ = neighbors_size(); + upper_neighbor_size_ = upper_neighbors_size(); + } + + //! Called only in searching procedure per context, so no need to lock + void sync_chunks(ChunkBroker::CHUNK_TYPE type, size_t idx, + std::vector *chunks) const { + if (ailego_likely(idx < chunks->size())) { + return; + } + for (size_t i = chunks->size(); i <= idx; ++i) { + auto chunk = broker_->get_chunk(type, i); + // the storage can ensure get chunk will success after the first get + ailego_assert_with(!!chunk, "get chunk failed"); + chunks->emplace_back(std::move(chunk)); + } + } + + //! return pair: chunk index + chunk offset + inline std::pair get_vector_chunk_loc( + node_id_t id) const { + uint32_t chunk_idx = id >> node_index_mask_bits_; + uint32_t offset = (id & node_index_mask_) * node_size(); + + sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); + return std::make_pair(chunk_idx, offset); + } + + //! return pair: chunk index + chunk offset + inline std::pair get_key_chunk_loc(node_id_t id) const { + uint32_t chunk_idx = id >> node_index_mask_bits_; + uint32_t offset = (id & node_index_mask_) * node_size() + vector_size(); + + sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); + return std::make_pair(chunk_idx, offset); + } + + inline std::pair get_upper_neighbor_chunk_loc( + level_t level, node_id_t id) const { + auto it = upper_neighbor_index_->find(id); + ailego_assert_abort(it != upper_neighbor_index_->end(), + "Get upper neighbor header failed"); + auto meta = reinterpret_cast(&it->second); + uint32_t chunk_idx = (meta->index) >> upper_neighbor_mask_bits_; + uint32_t offset = (((meta->index) & upper_neighbor_mask_) + level - 1) * + upper_neighbor_size_; + sync_chunks(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_idx, + &upper_neighbor_chunks_); + ailego_assert_abort(chunk_idx < upper_neighbor_chunks_.size(), + "invalid chunk idx"); + ailego_assert_abort(offset < upper_neighbor_chunks_[chunk_idx]->data_size(), + "invalid chunk offset"); + return std::make_pair(chunk_idx, offset); + } + + //! return pair: chunk + chunk offset + inline std::pair get_neighbor_chunk_loc(level_t level, + node_id_t id) const { + if (level == 0UL) { + uint32_t chunk_idx = id >> node_index_mask_bits_; + uint32_t offset = + (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); + + sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); + ailego_assert_abort(chunk_idx < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_abort(offset < node_chunks_[chunk_idx]->data_size(), + "invalid chunk offset"); + return std::make_pair(node_chunks_[chunk_idx].get(), offset); + } else { + auto p = get_upper_neighbor_chunk_loc(level, id); + return std::make_pair(upper_neighbor_chunks_[p.first].get(), p.second); + } + } + + //! Chunk hnsw index valid + int check_hnsw_index(const HNSWHeader *hd) const; + + size_t get_total_upper_neighbors_size(level_t level) const { + return level * upper_neighbor_size_; + } + + //! Add upper neighbor header and reserve space for upper neighbor + int add_upper_neighbor(level_t level, node_id_t id) { + if (level == 0) { + return 0; + } + Chunk::Pointer chunk; + uint64_t chunk_offset = -1UL; + size_t neighbors_size = get_total_upper_neighbors_size(level); + uint64_t chunk_index = upper_neighbor_chunks_.size() - 1UL; + if (chunk_index == -1UL || + (upper_neighbor_chunks_[chunk_index]->padding_size() < + neighbors_size)) { // no space left and need to alloc + chunk_index++; + if (ailego_unlikely(upper_neighbor_chunks_.capacity() == + upper_neighbor_chunks_.size())) { + LOG_ERROR("add upper neighbor failed for no memory quota"); + return IndexError_IndexFull; + } + auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, + chunk_index, upper_neighbor_chunk_size_); + if (ailego_unlikely(p.first != 0)) { + LOG_ERROR("Alloc data chunk failed"); + return p.first; + } + chunk = p.second; + chunk_offset = 0UL; + upper_neighbor_chunks_.emplace_back(chunk); + } else { + chunk = upper_neighbor_chunks_[chunk_index]; + chunk_offset = chunk->data_size(); + } + ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid level"); + ailego_assert_with(chunk_offset % upper_neighbor_size_ == 0, + "invalid offset"); + ailego_assert_with((chunk_offset / upper_neighbor_size_) < + (1U << upper_neighbor_mask_bits_), + "invalid offset"); + ailego_assert_with(chunk_index < (1U << (28 - upper_neighbor_mask_bits_)), + "invalid chunk index"); + UpperNeighborIndexMeta meta; + meta.level = level; + meta.index = (chunk_index << upper_neighbor_mask_bits_) | + (chunk_offset / upper_neighbor_size_); + chunk_offset += upper_neighbor_size_ * level; + if (ailego_unlikely(!upper_neighbor_index_->insert(id, meta.data))) { + LOG_ERROR("HashMap insert value failed"); + return IndexError_Runtime; + } + + if (ailego_unlikely(chunk->resize(chunk_offset) != chunk_offset)) { + LOG_ERROR("Chunk resize to %zu failed", (size_t)chunk_offset); + return IndexError_Runtime; + } + + return 0; + } + + size_t estimate_doc_capacity() const { + return node_chunks_.capacity() * node_cnt_per_chunk_; + } + + int init_chunk_params(size_t max_index_size, bool huge_page) { + node_cnt_per_chunk_ = std::max(1, chunk_size_ / node_size()); + //! align node cnt per chunk to pow of 2 + node_index_mask_bits_ = std::ceil(std::log2(node_cnt_per_chunk_)); + node_cnt_per_chunk_ = 1UL << node_index_mask_bits_; + if (huge_page) { + chunk_size_ = AlignHugePageSize(node_cnt_per_chunk_ * node_size()); + } else { + chunk_size_ = AlignPageSize(node_cnt_per_chunk_ * node_size()); + } + node_index_mask_ = node_cnt_per_chunk_ - 1; + + if (max_index_size == 0UL) { + max_index_size_ = chunk_size_ * kDefaultMaxChunkCnt; + } else { + max_index_size_ = max_index_size; + } + + //! To get a balanced upper neighbor chunk size. + //! If the upper chunk size is equal to node chunk size, it may waste + //! upper neighbor chunk space; if the upper neighbor chunk size is too + //! small, the will need large upper neighbor chunks index space. So to + //! get a balanced ratio be sqrt of the node/neighbor size ratio + float ratio = + std::sqrt(node_size() * scaling_factor() * 1.0f / upper_neighbor_size_); + if (huge_page) { + upper_neighbor_chunk_size_ = AlignHugePageSize( + std::max(get_total_upper_neighbors_size(kMaxGraphLayers), + static_cast(chunk_size_ / ratio))); + } else { + upper_neighbor_chunk_size_ = AlignPageSize( + std::max(get_total_upper_neighbors_size(kMaxGraphLayers), + static_cast(chunk_size_ / ratio))); + } + upper_neighbor_mask_bits_ = + std::ceil(std::log2(upper_neighbor_chunk_size_ / upper_neighbor_size_)); + upper_neighbor_mask_ = (1 << upper_neighbor_mask_bits_) - 1; + + size_t max_node_chunk_cnt = std::ceil(max_index_size_ / chunk_size_); + size_t max_upper_chunk_cnt = std::ceil( + (max_node_chunk_cnt * node_cnt_per_chunk_ * 1.0f / scaling_factor()) / + (upper_neighbor_chunk_size_ / upper_neighbor_size_)); + max_upper_chunk_cnt = + max_upper_chunk_cnt + std::ceil(max_upper_chunk_cnt / scaling_factor()); + + //! reserve space to avoid memmove in chunks vector emplace chunk, so + //! as to lock-free in reading chunk + node_chunks_.reserve(max_node_chunk_cnt); + upper_neighbor_chunks_.reserve(max_upper_chunk_cnt); + + LOG_DEBUG( + "Settings: nodeSize=%zu chunkSize=%u upperNeighborSize=%u " + "upperNeighborChunkSize=%u " + "nodeCntPerChunk=%u maxChunkCnt=%zu maxNeighborChunkCnt=%zu " + "maxIndexSize=%zu ratio=%.3f", + node_size(), chunk_size_, upper_neighbor_size_, + upper_neighbor_chunk_size_, node_cnt_per_chunk_, max_node_chunk_cnt, + max_upper_chunk_cnt, max_index_size_, ratio); + + return 0; + } + + //! Init node chunk and neighbor chunks + int init_chunks(const Chunk::Pointer &header_chunk); + + int flush_header(void) { + if (!broker_->dirty()) { + // do not need to flush + return 0; + } + auto header_chunk = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_HEADER, + ChunkBroker::kDefaultChunkSeqId); + if (ailego_unlikely(!header_chunk)) { + LOG_ERROR("get header chunk failed"); + return IndexError_Runtime; + } + size_t size = header_chunk->write(0UL, &header(), header_size()); + if (ailego_unlikely(size != header_size())) { + LOG_ERROR("Write header chunk failed"); + return IndexError_WriteData; + } + + return 0; + } + + private: + HnswStreamerEntityNew(const HnswStreamerEntityNew &) = delete; + HnswStreamerEntityNew &operator=(const HnswStreamerEntityNew &) = delete; + static constexpr uint64_t kUpperHashMemoryInflateRatio = 2.0f; + + private: + IndexStreamer::Stats &stats_; + HNSWHeader header_{}; + std::mutex mutex_{}; + size_t max_index_size_{0UL}; + uint32_t chunk_size_{kDefaultChunkSize}; + uint32_t upper_neighbor_chunk_size_{kDefaultChunkSize}; + uint32_t node_index_mask_bits_{0U}; + uint32_t node_cnt_per_chunk_{0U}; + uint32_t node_index_mask_{0U}; + uint32_t neighbor_size_{0U}; + uint32_t upper_neighbor_size_{0U}; + //! UpperNeighborIndex.index composite chunkIdx and offset in chunk by the + //! following mask + uint32_t upper_neighbor_mask_bits_{0U}; + uint32_t upper_neighbor_mask_{0U}; + bool filter_same_key_{false}; + bool get_vector_enabled_{false}; + bool use_key_info_map_{true}; + + NIHashMapPointer upper_neighbor_index_{}; + + mutable std::shared_ptr keys_map_lock_{}; + HashMapPointer keys_map_{}; + + //! the chunks will be changed in searcher, so need mutable + //! data chunk include: vector, key, level 0 neighbors + mutable std::vector node_chunks_{}; + + //! upper neighbor chunk inlude: UpperNeighborHeader + (1~level) neighbors + mutable std::vector upper_neighbor_chunks_{}; + + ChunkBroker::Pointer broker_{}; // chunk broker +}; + +} // namespace core +} // namespace zvec \ No newline at end of file From 7a6256781992f9f1a2eba116da3ae1173ec703b0 Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Tue, 10 Mar 2026 21:12:11 +0800 Subject: [PATCH 05/34] replace hnsw_entity --- src/core/algorithm/hnsw/hnsw_algorithm.cc | 2 +- src/core/algorithm/hnsw/hnsw_algorithm.h | 6 +- src/core/algorithm/hnsw/hnsw_context.cc | 6 +- src/core/algorithm/hnsw/hnsw_context.h | 18 +++--- .../algorithm/hnsw/hnsw_dist_calculator.h | 14 ++-- src/core/algorithm/hnsw/hnsw_index_provider.h | 10 +-- src/core/algorithm/hnsw/hnsw_streamer.cc | 64 +++++++++---------- src/core/algorithm/hnsw/hnsw_streamer.h | 4 +- 8 files changed, 62 insertions(+), 62 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_algorithm.cc b/src/core/algorithm/hnsw/hnsw_algorithm.cc index fa553f55..e5561544 100644 --- a/src/core/algorithm/hnsw/hnsw_algorithm.cc +++ b/src/core/algorithm/hnsw/hnsw_algorithm.cc @@ -20,7 +20,7 @@ namespace zvec { namespace core { -HnswAlgorithm::HnswAlgorithm(HnswEntity &entity) +HnswAlgorithm::HnswAlgorithm(HnswStreamerEntityNew &entity) : entity_(entity), mt_(std::chrono::system_clock::now().time_since_epoch().count()), lock_pool_(kLockCnt) {} diff --git a/src/core/algorithm/hnsw/hnsw_algorithm.h b/src/core/algorithm/hnsw/hnsw_algorithm.h index 886d870c..e699477b 100644 --- a/src/core/algorithm/hnsw/hnsw_algorithm.h +++ b/src/core/algorithm/hnsw/hnsw_algorithm.h @@ -17,7 +17,7 @@ #include #include "hnsw_context.h" #include "hnsw_dist_calculator.h" -#include "hnsw_entity.h" +#include "hnsw_streamer_entity_new.h" namespace zvec { namespace core { @@ -29,7 +29,7 @@ class HnswAlgorithm { public: //! Constructor - explicit HnswAlgorithm(HnswEntity &entity); + explicit HnswAlgorithm(HnswStreamerEntityNew &entity); //! Destructor ~HnswAlgorithm() = default; @@ -116,7 +116,7 @@ class HnswAlgorithm { static constexpr uint32_t kLockCnt{1U << 8}; static constexpr uint32_t kLockMask{kLockCnt - 1U}; - HnswEntity &entity_; + HnswStreamerEntityNew &entity_; mutable std::mt19937 mt_{}; std::vector level_probas_{}; diff --git a/src/core/algorithm/hnsw/hnsw_context.cc b/src/core/algorithm/hnsw/hnsw_context.cc index b930e418..3ac97515 100644 --- a/src/core/algorithm/hnsw/hnsw_context.cc +++ b/src/core/algorithm/hnsw/hnsw_context.cc @@ -19,13 +19,13 @@ namespace zvec { namespace core { HnswContext::HnswContext(size_t dimension, const IndexMetric::Pointer &metric, - const HnswEntity::Pointer &entity) + const HnswStreamerEntityNew::Pointer &entity) : IndexContext(metric), entity_(entity), dc_(entity_.get(), metric, dimension) {} HnswContext::HnswContext(const IndexMetric::Pointer &metric, - const HnswEntity::Pointer &entity) + const HnswStreamerEntityNew::Pointer &entity) : IndexContext(metric), entity_(entity), dc_(entity_.get(), metric) {} HnswContext::~HnswContext() { @@ -201,7 +201,7 @@ int HnswContext::update(const ailego::Params ¶ms) { int HnswContext::update_context(ContextType type, const IndexMeta &meta, const IndexMetric::Pointer &metric, - const HnswEntity::Pointer &entity, + const HnswStreamerEntityNew::Pointer &entity, uint32_t magic_num) { uint32_t doc_cnt; diff --git a/src/core/algorithm/hnsw/hnsw_context.h b/src/core/algorithm/hnsw/hnsw_context.h index 22bcfaad..0f988baf 100644 --- a/src/core/algorithm/hnsw/hnsw_context.h +++ b/src/core/algorithm/hnsw/hnsw_context.h @@ -17,7 +17,7 @@ #include "utility/sparse_utility.h" #include "utility/visit_filter.h" #include "hnsw_dist_calculator.h" -#include "hnsw_entity.h" +#include "hnsw_streamer_entity_new.h" namespace zvec { namespace core { @@ -36,11 +36,11 @@ class HnswContext : public IndexContext { //! Construct HnswContext(size_t dimension, const IndexMetric::Pointer &metric, - const HnswEntity::Pointer &entity); + const HnswStreamerEntityNew::Pointer &entity); //! Construct HnswContext(const IndexMetric::Pointer &metric, - const HnswEntity::Pointer &entity); + const HnswStreamerEntityNew::Pointer &entity); //! Destructor virtual ~HnswContext(); @@ -114,9 +114,9 @@ class HnswContext : public IndexContext { //! Update context, the context may be shared by different searcher/streamer int update_context(ContextType type, const IndexMeta &meta, const IndexMetric::Pointer &metric, - const HnswEntity::Pointer &entity, uint32_t magic_num); + const HnswStreamerEntityNew::Pointer &entity, uint32_t magic_num); - inline const HnswEntity &get_entity() const { + inline const HnswStreamerEntityNew &get_entity() const { return *entity_; } @@ -488,7 +488,7 @@ class HnswContext : public IndexContext { constexpr static uint32_t kInvalidMgic = -1U; private: - HnswEntity::Pointer entity_; + HnswStreamerEntityNew::Pointer entity_; HnswDistCalculator dc_; IndexMetric::Pointer metric_; @@ -501,9 +501,9 @@ class HnswContext : public IndexContext { uint32_t topk_{0}; uint32_t group_topk_{0}; uint32_t filter_mode_{VisitFilter::ByteMap}; - float negative_probability_{HnswEntity::kDefaultBFNegativeProbability}; - uint32_t ef_{HnswEntity::kDefaultEf}; - float max_scan_ratio_{HnswEntity::kDefaultScanRatio}; + float negative_probability_{HnswStreamerEntityNew::kDefaultBFNegativeProbability}; + uint32_t ef_{HnswStreamerEntityNew::kDefaultEf}; + float max_scan_ratio_{HnswStreamerEntityNew::kDefaultScanRatio}; uint32_t magic_{0U}; std::vector results_{}; std::vector group_results_{}; diff --git a/src/core/algorithm/hnsw/hnsw_dist_calculator.h b/src/core/algorithm/hnsw/hnsw_dist_calculator.h index 84faba40..4f6b624e 100644 --- a/src/core/algorithm/hnsw/hnsw_dist_calculator.h +++ b/src/core/algorithm/hnsw/hnsw_dist_calculator.h @@ -14,7 +14,7 @@ #pragma once #include -#include "hnsw_entity.h" +#include "hnsw_streamer_entity_new.h" namespace zvec { namespace core { @@ -33,7 +33,7 @@ class HnswDistCalculator { public: //! Constructor - HnswDistCalculator(const HnswEntity *entity, + HnswDistCalculator(const HnswStreamerEntityNew *entity, const IndexMetric::Pointer &metric, uint32_t dim) : entity_(entity), distance_(metric->distance()), @@ -43,7 +43,7 @@ class HnswDistCalculator { compare_cnt_(0) {} //! Constructor - HnswDistCalculator(const HnswEntity *entity, + HnswDistCalculator(const HnswStreamerEntityNew *entity, const IndexMetric::Pointer &metric, uint32_t dim, const void *query) : entity_(entity), @@ -54,7 +54,7 @@ class HnswDistCalculator { compare_cnt_(0) {} //! Constructor - HnswDistCalculator(const HnswEntity *entity, + HnswDistCalculator(const HnswStreamerEntityNew *entity, const IndexMetric::Pointer &metric) : entity_(entity), distance_(metric->distance()), @@ -63,13 +63,13 @@ class HnswDistCalculator { dim_(0), compare_cnt_(0) {} - void update(const HnswEntity *entity, const IndexMetric::Pointer &metric) { + void update(const HnswStreamerEntityNew *entity, const IndexMetric::Pointer &metric) { entity_ = entity; distance_ = metric->distance(); batch_distance_ = metric->batch_distance(); } - void update(const HnswEntity *entity, const IndexMetric::Pointer &metric, + void update(const HnswStreamerEntityNew *entity, const IndexMetric::Pointer &metric, uint32_t dim) { entity_ = entity; distance_ = metric->distance(); @@ -201,7 +201,7 @@ class HnswDistCalculator { HnswDistCalculator &operator=(const HnswDistCalculator &) = delete; private: - const HnswEntity *entity_; + const HnswStreamerEntityNew *entity_; IndexMetric::MatrixDistance distance_; IndexMetric::MatrixBatchDistance batch_distance_; diff --git a/src/core/algorithm/hnsw/hnsw_index_provider.h b/src/core/algorithm/hnsw/hnsw_index_provider.h index 4a6ccaeb..b128a2c0 100644 --- a/src/core/algorithm/hnsw/hnsw_index_provider.h +++ b/src/core/algorithm/hnsw/hnsw_index_provider.h @@ -16,14 +16,14 @@ #include #include #include -#include "hnsw_entity.h" +#include "hnsw_streamer_entity_new.h" namespace zvec { namespace core { class HnswIndexProvider : public IndexProvider { public: - HnswIndexProvider(const IndexMeta &meta, const HnswEntity::Pointer &entity, + HnswIndexProvider(const IndexMeta &meta, const HnswStreamerEntityNew::Pointer &entity, const std::string &owner) : meta_(meta), entity_(entity), owner_class_(owner) {} @@ -76,7 +76,7 @@ class HnswIndexProvider : public IndexProvider { private: class Iterator : public IndexProvider::Iterator { public: - Iterator(const HnswEntity::Pointer &entity) + Iterator(const HnswStreamerEntityNew::Pointer &entity) : entity_(entity), cur_id_(0U) {} //! Retrieve pointer of data @@ -119,13 +119,13 @@ class HnswIndexProvider : public IndexProvider { } private: - const HnswEntity::Pointer entity_; + const HnswStreamerEntityNew::Pointer entity_; node_id_t cur_id_; }; private: const IndexMeta &meta_; - const HnswEntity::Pointer entity_; + const HnswStreamerEntityNew::Pointer entity_; const std::string owner_class_; }; diff --git a/src/core/algorithm/hnsw/hnsw_streamer.cc b/src/core/algorithm/hnsw/hnsw_streamer.cc index 5804c7d0..057a804b 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer.cc @@ -35,16 +35,16 @@ HnswStreamer::~HnswStreamer() { int HnswStreamer::init(const IndexMeta &imeta, const ailego::Params ¶ms) { meta_ = imeta; - meta_.set_streamer("HnswStreamer", HnswEntity::kRevision, params); + meta_.set_streamer("HnswStreamer", HnswStreamerEntityNew::kRevision, params); params.get(PARAM_HNSW_STREAMER_MAX_INDEX_SIZE, &max_index_size_); params.get(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, &upper_max_neighbor_cnt_); - float multiplier = HnswEntity::kDefaultL0MaxNeighborCntMultiplier; + float multiplier = HnswStreamerEntityNew::kDefaultL0MaxNeighborCntMultiplier; params.get(PARAM_HNSW_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER, &multiplier); l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_; - multiplier = HnswEntity::kDefaultNeighborPruneMultiplier; + multiplier = HnswStreamerEntityNew::kDefaultNeighborPruneMultiplier; params.get(PARAM_HNSW_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier); size_t prune_cnt = multiplier * upper_max_neighbor_cnt_; scaling_factor_ = upper_max_neighbor_cnt_; @@ -78,30 +78,30 @@ int HnswStreamer::init(const IndexMeta &imeta, const ailego::Params ¶ms) { return IndexError_InvalidArgument; } else if (docs_soft_limit_ == 0UL) { docs_soft_limit_ = - docs_hard_limit_ * HnswEntity::kDefaultDocsSoftLimitRatio; + docs_hard_limit_ * HnswStreamerEntityNew::kDefaultDocsSoftLimitRatio; } if (ef_ == 0U) { - ef_ = HnswEntity::kDefaultEf; + ef_ = HnswStreamerEntityNew::kDefaultEf; } if (ef_construction_ == 0U) { - ef_construction_ = HnswEntity::kDefaultEfConstruction; + ef_construction_ = HnswStreamerEntityNew::kDefaultEfConstruction; } if (upper_max_neighbor_cnt_ == 0U) { - upper_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt; + upper_max_neighbor_cnt_ = HnswStreamerEntityNew::kDefaultUpperMaxNeighborCnt; } - if (upper_max_neighbor_cnt_ > HnswEntity::kMaxNeighborCnt) { + if (upper_max_neighbor_cnt_ > HnswStreamerEntityNew::kMaxNeighborCnt) { LOG_ERROR("[%s] must be in range (0,%d)", PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT.c_str(), - HnswEntity::kMaxNeighborCnt); + HnswStreamerEntityNew::kMaxNeighborCnt); return IndexError_InvalidArgument; } if (l0_max_neighbor_cnt_ == 0U) { - l0_max_neighbor_cnt_ = HnswEntity::kDefaultL0MaxNeighborCnt; + l0_max_neighbor_cnt_ = HnswStreamerEntityNew::kDefaultL0MaxNeighborCnt; } - if (l0_max_neighbor_cnt_ > HnswEntity::kMaxNeighborCnt) { + if (l0_max_neighbor_cnt_ > HnswStreamerEntityNew::kMaxNeighborCnt) { LOG_ERROR("MaxL0NeighborCnt must be in range (0,%d)", - HnswEntity::kMaxNeighborCnt); + HnswStreamerEntityNew::kMaxNeighborCnt); return IndexError_InvalidArgument; } if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) { @@ -119,7 +119,7 @@ int HnswStreamer::init(const IndexMeta &imeta, const ailego::Params ¶ms) { } if (scaling_factor_ == 0U) { - scaling_factor_ = HnswEntity::kDefaultScalingFactor; + scaling_factor_ = HnswStreamerEntityNew::kDefaultScalingFactor; } if (scaling_factor_ < 5 || scaling_factor_ > 1000) { LOG_ERROR("[%s] must be in range [5,1000]", @@ -144,11 +144,11 @@ int HnswStreamer::init(const IndexMeta &imeta, const ailego::Params ¶ms) { prune_cnt = upper_max_neighbor_cnt_; } if (chunk_size_ == 0UL) { - chunk_size_ = HnswEntity::kDefaultChunkSize; + chunk_size_ = HnswStreamerEntityNew::kDefaultChunkSize; } - if (chunk_size_ > HnswEntity::kMaxChunkSize) { + if (chunk_size_ > HnswStreamerEntityNew::kMaxChunkSize) { LOG_ERROR("[%s] must be < %zu", PARAM_HNSW_STREAMER_CHUNK_SIZE.c_str(), - HnswEntity::kMaxChunkSize); + HnswStreamerEntityNew::kMaxChunkSize); return IndexError_InvalidArgument; } @@ -215,20 +215,20 @@ int HnswStreamer::cleanup(void) { } max_index_size_ = 0UL; - docs_hard_limit_ = HnswEntity::kDefaultDocsHardLimit; + docs_hard_limit_ = HnswStreamerEntityNew::kDefaultDocsHardLimit; docs_soft_limit_ = 0UL; - upper_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt; - l0_max_neighbor_cnt_ = HnswEntity::kDefaultL0MaxNeighborCnt; - ef_ = HnswEntity::kDefaultEf; - ef_construction_ = HnswEntity::kDefaultEfConstruction; + upper_max_neighbor_cnt_ = HnswStreamerEntityNew::kDefaultUpperMaxNeighborCnt; + l0_max_neighbor_cnt_ = HnswStreamerEntityNew::kDefaultL0MaxNeighborCnt; + ef_ = HnswStreamerEntityNew::kDefaultEf; + ef_construction_ = HnswStreamerEntityNew::kDefaultEfConstruction; bf_enabled_ = false; - scaling_factor_ = HnswEntity::kDefaultScalingFactor; - bruteforce_threshold_ = HnswEntity::kDefaultBruteForceThreshold; - max_scan_limit_ = HnswEntity::kDefaultMaxScanLimit; - min_scan_limit_ = HnswEntity::kDefaultMinScanLimit; - chunk_size_ = HnswEntity::kDefaultChunkSize; - bf_negative_prob_ = HnswEntity::kDefaultBFNegativeProbability; - max_scan_ratio_ = HnswEntity::kDefaultScanRatio; + scaling_factor_ = HnswStreamerEntityNew::kDefaultScalingFactor; + bruteforce_threshold_ = HnswStreamerEntityNew::kDefaultBruteForceThreshold; + max_scan_limit_ = HnswStreamerEntityNew::kDefaultMaxScanLimit; + min_scan_limit_ = HnswStreamerEntityNew::kDefaultMinScanLimit; + chunk_size_ = HnswStreamerEntityNew::kDefaultChunkSize; + bf_negative_prob_ = HnswStreamerEntityNew::kDefaultBFNegativeProbability; + max_scan_ratio_ = HnswStreamerEntityNew::kDefaultScanRatio; state_ = STATE_INIT; check_crc_enabled_ = false; filter_same_key_ = false; @@ -342,7 +342,7 @@ int HnswStreamer::dump(const IndexDumper::Pointer &dumper) { shared_mutex_.lock(); AILEGO_DEFER([&]() { shared_mutex_.unlock(); }); - meta_.set_searcher("HnswSearcher", HnswEntity::kRevision, ailego::Params()); + meta_.set_searcher("HnswSearcher", HnswStreamerEntityNew::kRevision, ailego::Params()); int ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); if (ret != 0) { @@ -358,7 +358,7 @@ IndexStreamer::Context::Pointer HnswStreamer::create_context(void) const { return Context::Pointer(); } - HnswEntity::Pointer entity = entity_.clone(); + HnswStreamerEntityNew::Pointer entity = entity_.clone(); if (ailego_unlikely(!entity)) { LOG_ERROR("CreateContext clone init failed"); return Context::Pointer(); @@ -401,7 +401,7 @@ IndexProvider::Pointer HnswStreamer::create_provider(void) const { auto entity = entity_.clone(); if (ailego_unlikely(!entity)) { - LOG_ERROR("Clone HnswEntity failed"); + LOG_ERROR("Clone HnswStreamerEntityNew failed"); return nullptr; } return Provider::Pointer( @@ -409,7 +409,7 @@ IndexProvider::Pointer HnswStreamer::create_provider(void) const { } int HnswStreamer::update_context(HnswContext *ctx) const { - const HnswEntity::Pointer entity = entity_.clone(); + const HnswStreamerEntityNew::Pointer entity = entity_.clone(); if (!entity) { LOG_ERROR("Failed to clone search context entity"); return IndexError_Runtime; diff --git a/src/core/algorithm/hnsw/hnsw_streamer.h b/src/core/algorithm/hnsw/hnsw_streamer.h index b81106da..9613533e 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer.h +++ b/src/core/algorithm/hnsw/hnsw_streamer.h @@ -16,7 +16,7 @@ #include #include #include "hnsw_algorithm.h" -#include "hnsw_streamer_entity.h" +#include "hnsw_streamer_entity_new.h" namespace zvec { namespace core { @@ -181,7 +181,7 @@ class HnswStreamer : public IndexStreamer { } }; - HnswStreamerEntity entity_; + HnswStreamerEntityNew entity_; HnswAlgorithm::UPointer alg_; IndexMeta meta_{}; IndexMetric::Pointer metric_{}; From 8eccbac8480d2dcd855347ba1fb47b23c3894d02 Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Wed, 11 Mar 2026 18:02:58 +0800 Subject: [PATCH 06/34] rm old entity --- .../algorithm/hnsw/hnsw_streamer_entity.cc | 701 ------------------ .../algorithm/hnsw/hnsw_streamer_entity.h | 515 ------------- 2 files changed, 1216 deletions(-) delete mode 100644 src/core/algorithm/hnsw/hnsw_streamer_entity.cc delete mode 100644 src/core/algorithm/hnsw/hnsw_streamer_entity.h diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity.cc deleted file mode 100644 index 71e2e477..00000000 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity.cc +++ /dev/null @@ -1,701 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "hnsw_streamer_entity.h" -#include - -// #define DEBUG_PRINT - -namespace zvec { -namespace core { - -HnswStreamerEntity::HnswStreamerEntity(IndexStreamer::Stats &stats) - : stats_(stats) {} - -HnswStreamerEntity::~HnswStreamerEntity() {} - -int HnswStreamerEntity::init(size_t max_doc_cnt) { - if (std::pow(scaling_factor(), kMaxGraphLayers) < max_doc_cnt) { - LOG_ERROR("scalingFactor=%zu is too small", scaling_factor()); - return IndexError_InvalidArgument; - } - - std::lock_guard lock(mutex_); - broker_ = std::make_shared(stats_); - upper_neighbor_index_ = std::make_shared(); - keys_map_lock_ = std::make_shared(); - keys_map_ = std::make_shared>(); - if (!keys_map_ || !upper_neighbor_index_ || !broker_ || !keys_map_lock_) { - LOG_ERROR("HnswStreamerEntity new object failed"); - return IndexError_NoMemory; - } - keys_map_->set_empty_key(kInvalidKey); - - neighbor_size_ = neighbors_size(); - upper_neighbor_size_ = upper_neighbors_size(); - - //! vector + key + level 0 neighbors - size_t size = vector_size() + sizeof(key_t) + neighbor_size_; - - size = AlignSize(size); - set_node_size(size); - return 0; -} - -int HnswStreamerEntity::cleanup() { - std::lock_guard lock(mutex_); - mutable_header()->clear(); - chunk_size_ = kDefaultChunkSize; - node_index_mask_bits_ = 0U; - node_index_mask_ = 0U; - node_cnt_per_chunk_ = 0U; - neighbor_size_ = 0U; - upper_neighbor_size_ = 0U; - if (upper_neighbor_index_) { - upper_neighbor_index_->cleanup(); - } - if (keys_map_) { - keys_map_->clear(); - } - node_chunks_.clear(); - upper_neighbor_chunks_.clear(); - filter_same_key_ = false; - get_vector_enabled_ = false; - broker_.reset(); - - return 0; -} - -int HnswStreamerEntity::update_neighbors( - level_t level, node_id_t id, - const std::vector> &neighbors) { - std::vector buffer(neighbor_size_); - NeighborsHeader *hd = reinterpret_cast(buffer.data()); - hd->neighbor_cnt = neighbors.size(); - size_t i = 0; - for (; i < neighbors.size(); ++i) { - hd->neighbors[i] = neighbors[i].first; - } - - auto loc = get_neighbor_chunk_loc(level, id); - size_t size = reinterpret_cast(&hd->neighbors[i]) - &buffer[0]; - size_t ret = loc.first->write(loc.second, hd, size); - if (ailego_unlikely(ret != size)) { - LOG_ERROR("Write neighbor header failed, ret=%zu", ret); - - return IndexError_Runtime; - } - - return 0; -} - -const Neighbors HnswStreamerEntity::get_neighbors(level_t level, - node_id_t id) const { - Chunk *chunk = nullptr; - size_t offset = 0UL; - size_t neighbor_size = neighbor_size_; - if (level == 0UL) { - uint32_t chunk_idx = id >> node_index_mask_bits_; - offset = - (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); - - sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); - ailego_assert_with(chunk_idx < node_chunks_.size(), "invalid chunk idx"); - chunk = node_chunks_[chunk_idx].get(); - } else { - auto p = get_upper_neighbor_chunk_loc(level, id); - chunk = upper_neighbor_chunks_[p.first].get(); - offset = p.second; - neighbor_size = upper_neighbor_size_; - } - - ailego_assert_with(offset < chunk->data_size(), "invalid chunk offset"); - IndexStorage::MemoryBlock neighbor_block; - size_t size = chunk->read(offset, neighbor_block, neighbor_size); - if (ailego_unlikely(size != neighbor_size)) { - LOG_ERROR("Read neighbor header failed, ret=%zu", size); - return Neighbors(); - } - return Neighbors(std::move(neighbor_block)); -} - -//! Get vector data by key -const void *HnswStreamerEntity::get_vector(node_id_t id) const { - auto loc = get_vector_chunk_loc(id); - const void *vec = nullptr; - ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); - ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), - "invalid chunk offset"); - - size_t read_size = vector_size(); - - size_t ret = node_chunks_[loc.first]->read(loc.second, &vec, read_size); - if (ailego_unlikely(ret != read_size)) { - LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", - loc.second, read_size, ret); - } - - return vec; -} - -int HnswStreamerEntity::get_vector(const node_id_t *ids, uint32_t count, - const void **vecs) const { - for (auto i = 0U; i < count; ++i) { - auto loc = get_vector_chunk_loc(ids[i]); - ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); - ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), - "invalid chunk offset"); - - size_t read_size = vector_size(); - - size_t ret = node_chunks_[loc.first]->read(loc.second, &vecs[i], read_size); - if (ailego_unlikely(ret != read_size)) { - LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", - loc.second, read_size, ret); - return IndexError_ReadData; - } - } - return 0; -} - -int HnswStreamerEntity::get_vector(const node_id_t id, - IndexStorage::MemoryBlock &block) const { - auto loc = get_vector_chunk_loc(id); - ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); - ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), - "invalid chunk offset"); - - size_t read_size = vector_size(); - - size_t ret = node_chunks_[loc.first]->read(loc.second, block, read_size); - if (ailego_unlikely(ret != read_size)) { - LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", - loc.second, read_size, ret); - return IndexError_ReadData; - } - return 0; -} - -int HnswStreamerEntity::get_vector( - const node_id_t *ids, uint32_t count, - std::vector &vec_blocks) const { - vec_blocks.resize(count); - for (auto i = 0U; i < count; ++i) { - auto loc = get_vector_chunk_loc(ids[i]); - ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); - ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), - "invalid chunk offset"); - - size_t read_size = vector_size(); - - size_t ret = - node_chunks_[loc.first]->read(loc.second, vec_blocks[i], read_size); - if (ailego_unlikely(ret != read_size)) { - LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", - loc.second, read_size, ret); - return IndexError_ReadData; - } - } - return 0; -} - -key_t HnswStreamerEntity::get_key(node_id_t id) const { - if (use_key_info_map_) { - auto loc = get_key_chunk_loc(id); - IndexStorage::MemoryBlock key_block; - ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); - ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), - "invalid chunk offset"); - size_t ret = - node_chunks_[loc.first]->read(loc.second, key_block, sizeof(key_t)); - if (ailego_unlikely(ret != sizeof(key_t))) { - LOG_ERROR("Read vector failed, ret=%zu", ret); - return kInvalidKey; - } - - return *reinterpret_cast(key_block.data()); - } else { - return id; - } -} - -void HnswStreamerEntity::add_neighbor(level_t level, node_id_t id, - uint32_t size, node_id_t neighbor_id) { - auto loc = get_neighbor_chunk_loc(level, id); - size_t offset = - loc.second + sizeof(NeighborsHeader) + size * sizeof(node_id_t); - ailego_assert_with(size < neighbor_cnt(level), "invalid neighbor size"); - ailego_assert_with(offset < loc.first->data_size(), "invalid chunk offset"); - size_t ret = loc.first->write(offset, &neighbor_id, sizeof(node_id_t)); - if (ailego_unlikely(ret != sizeof(node_id_t))) { - LOG_ERROR("Write neighbor id failed, ret=%zu", ret); - return; - } - - uint32_t neighbors = size + 1; - ret = loc.first->write(loc.second, &neighbors, sizeof(uint32_t)); - if (ailego_unlikely(ret != sizeof(uint32_t))) { - LOG_ERROR("Write neighbor cnt failed, ret=%zu", ret); - } - - return; -} - -int HnswStreamerEntity::init_chunks(const Chunk::Pointer &header_chunk) { - if (header_chunk->data_size() < header_size()) { - LOG_ERROR("Invalid header chunk size"); - return IndexError_InvalidFormat; - } - IndexStorage::MemoryBlock header_block; - size_t size = header_chunk->read(0UL, header_block, header_size()); - if (ailego_unlikely(size != header_size())) { - LOG_ERROR("Read header chunk failed"); - return IndexError_ReadData; - } - *mutable_header() = - *reinterpret_cast(header_block.data()); - - int ret = check_hnsw_index(&header()); - if (ret != 0) { - broker_->close(); - return ret; - } - - node_chunks_.resize(broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_NODE)); - for (auto seq = 0UL; seq < node_chunks_.size(); ++seq) { - node_chunks_[seq] = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_NODE, seq); - if (!node_chunks_[seq]) { - LOG_ERROR("Missing hnsw streamer data chunk %zu th of %zu", seq, - node_chunks_.size()); - return IndexError_InvalidFormat; - } - } - - upper_neighbor_chunks_.resize( - broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR)); - for (auto seq = 0UL; seq < upper_neighbor_chunks_.size(); ++seq) { - upper_neighbor_chunks_[seq] = - broker_->get_chunk(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, seq); - if (!upper_neighbor_chunks_[seq]) { - LOG_ERROR("Missing hnsw streamer index chunk %zu th of %zu", seq, - upper_neighbor_chunks_.size()); - return IndexError_InvalidFormat; - } - } - - return 0; -} - -int HnswStreamerEntity::open(IndexStorage::Pointer stg, uint64_t max_index_size, - bool check_crc) { - std::lock_guard lock(mutex_); - bool huge_page = stg->isHugePage(); - LOG_DEBUG("huge_page: %d", (int)huge_page); - int ret = init_chunk_params(max_index_size, huge_page); - if (ailego_unlikely(ret != 0)) { - LOG_ERROR("init_chunk_params failed for %s", IndexError::What(ret)); - return ret; - } - ret = broker_->open(std::move(stg), max_index_size_, chunk_size_, check_crc); - if (ailego_unlikely(ret != 0)) { - LOG_ERROR("Open index failed for %s", IndexError::What(ret)); - return ret; - } - ret = upper_neighbor_index_->init(broker_, upper_neighbor_chunk_size_, - scaling_factor(), estimate_doc_capacity(), - kUpperHashMemoryInflateRatio); - if (ailego_unlikely(ret != 0)) { - LOG_ERROR("Init neighbor hash map failed"); - return ret; - } - - //! init header - auto header_chunk = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_HEADER, - ChunkBroker::kDefaultChunkSeqId); - if (!header_chunk) { // open empty index, create one - auto p = - broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_HEADER, - ChunkBroker::kDefaultChunkSeqId, header_size()); - if (ailego_unlikely(p.first != 0)) { - LOG_ERROR("Alloc header chunk failed"); - return p.first; - } - size_t size = p.second->write(0UL, &header(), header_size()); - if (ailego_unlikely(size != header_size())) { - LOG_ERROR("Write header chunk failed"); - return IndexError_WriteData; - } - return 0; - } - - //! Open an exist hnsw index - ret = init_chunks(header_chunk); - if (ailego_unlikely(ret != 0)) { - return ret; - } - - //! total docs including features wrote in index but neighbors may not ready - node_id_t total_vecs = 0; - if (node_chunks_.size() > 0) { - size_t last_idx = node_chunks_.size() - 1; - auto last_chunk = node_chunks_[last_idx]; - if (last_chunk->data_size() % node_size()) { - LOG_WARN("The index may broken"); - return IndexError_InvalidFormat; - } - total_vecs = last_idx * node_cnt_per_chunk_ + - node_chunks_[last_idx]->data_size() / node_size(); - } - - LOG_INFO( - "Open index, l0NeighborCnt=%zu upperNeighborCnt=%zu " - "efConstruction=%zu curDocCnt=%u totalVecs=%u maxLevel=%u", - l0_neighbor_cnt(), upper_neighbor_cnt(), ef_construction(), doc_cnt(), - total_vecs, cur_max_level()); - //! try to correct the docCnt if index not fully flushed - if (doc_cnt() != total_vecs) { - LOG_WARN("Index closed abnormally, using totalVecs as curDocCnt"); - *mutable_doc_cnt() = total_vecs; - } - if (filter_same_key_ || get_vector_enabled_) { - if (use_key_info_map_) { - for (node_id_t id = 0U; id < doc_cnt(); ++id) { - if (get_key(id) == kInvalidKey) { - continue; - } - (*keys_map_)[get_key(id)] = id; - } - } - } - - stats_.set_loaded_count(doc_cnt()); - - return 0; -} - -int HnswStreamerEntity::close() { - LOG_DEBUG("close index"); - - std::lock_guard lock(mutex_); - flush_header(); - mutable_header()->reset(); - upper_neighbor_index_->cleanup(); - keys_map_->clear(); - header_.clear(); - node_chunks_.clear(); - upper_neighbor_chunks_.clear(); - - return broker_->close(); -} - -int HnswStreamerEntity::flush(uint64_t checkpoint) { - LOG_INFO("Flush index, curDocs=%u", doc_cnt()); - - std::lock_guard lock(mutex_); - flush_header(); - int ret = broker_->flush(checkpoint); - if (ret != 0) { - return ret; - } - - return 0; -} - -int HnswStreamerEntity::dump(const IndexDumper::Pointer &dumper) { - LOG_INFO("Dump index, curDocs=%u", doc_cnt()); - - //! sort by keys, to support get_vector by key in searcher - std::vector keys(doc_cnt()); - for (node_id_t i = 0; i < doc_cnt(); ++i) { - keys[i] = get_key(i); - } - - //! dump neighbors - auto get_level = [&](node_id_t id) { - auto it = upper_neighbor_index_->find(id); - if (it == upper_neighbor_index_->end()) { - return 0U; - }; - auto meta = reinterpret_cast(&it->second); - return meta->level; - }; - auto ret = dump_segments(dumper, keys.data(), get_level); - if (ailego_unlikely(ret < 0)) { - return ret; - } - *stats_.mutable_dumped_size() += ret; - - return 0; -} - -int HnswStreamerEntity::check_hnsw_index(const HNSWHeader *hd) const { - if (l0_neighbor_cnt() != hd->l0_neighbor_cnt() || - upper_neighbor_cnt() != hd->upper_neighbor_cnt()) { - LOG_ERROR("Param neighbor cnt: %zu:%zu mismatch index previous %zu:%zu", - l0_neighbor_cnt(), upper_neighbor_cnt(), hd->l0_neighbor_cnt(), - hd->upper_neighbor_cnt()); - return IndexError_Mismatch; - } - if (vector_size() != hd->vector_size()) { - LOG_ERROR("vector size %zu mismatch index previous %zu", vector_size(), - hd->vector_size()); - return IndexError_Mismatch; - } - if (ef_construction() != hd->ef_construction()) { - LOG_WARN("Param efConstruction %zu mismatch index previous %zu", - ef_construction(), hd->ef_construction()); - } - if (scaling_factor() != hd->scaling_factor()) { - LOG_WARN("Param scalingFactor %zu mismatch index previous %zu", - scaling_factor(), hd->scaling_factor()); - return IndexError_Mismatch; - } - if (prune_cnt() != hd->neighbor_prune_cnt()) { - LOG_WARN("Param pruneCnt %zu mismatch index previous %zu", prune_cnt(), - hd->neighbor_prune_cnt()); - return IndexError_Mismatch; - } - if ((hd->entry_point() != kInvalidNodeId && - hd->entry_point() >= hd->doc_cnt()) || - (hd->entry_point() == kInvalidNodeId && hd->doc_cnt() > 0U)) { - LOG_WARN("Invalid entryPoint %u, docCnt %u", hd->entry_point(), - hd->doc_cnt()); - return IndexError_InvalidFormat; - } - if (hd->entry_point() == kInvalidNodeId && - broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_NODE) > 0) { - LOG_WARN("The index is broken, maybe it haven't flush"); - return IndexError_InvalidFormat; - } - - return 0; -} - -int HnswStreamerEntity::add_vector(level_t level, key_t key, const void *vec, - node_id_t *id) { - Chunk::Pointer node_chunk; - size_t chunk_offset = -1UL; - - std::lock_guard lock(mutex_); - // duplicate check - if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { - LOG_WARN("Try to add duplicate key, ignore it"); - return IndexError_Duplicate; - } - - node_id_t local_id = static_cast(doc_cnt()); - uint32_t chunk_index = node_chunks_.size() - 1U; - if (chunk_index == -1U || - (node_chunks_[chunk_index]->data_size() >= - node_cnt_per_chunk_ * node_size())) { // no space left and need to alloc - if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { - LOG_ERROR("add vector failed for no memory quota"); - return IndexError_IndexFull; - } - chunk_index++; - auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_NODE, chunk_index, - chunk_size_); - if (ailego_unlikely(p.first != 0)) { - LOG_ERROR("Alloc data chunk failed"); - return p.first; - } - node_chunk = p.second; - chunk_offset = 0UL; - node_chunks_.emplace_back(node_chunk); - } else { - node_chunk = node_chunks_[chunk_index]; - chunk_offset = node_chunk->data_size(); - } - - size_t size = node_chunk->write(chunk_offset, vec, vector_size()); - if (ailego_unlikely(size != vector_size())) { - LOG_ERROR("Chunk write vec failed, ret=%zu", size); - return IndexError_WriteData; - } - size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); - if (ailego_unlikely(size != sizeof(key_t))) { - LOG_ERROR("Chunk write vec failed, ret=%zu", size); - return IndexError_WriteData; - } - //! level 0 neighbors is inited to zero by default - - int ret = add_upper_neighbor(level, local_id); - if (ret != 0) { - return ret; - } - - chunk_offset += node_size(); - if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { - LOG_ERROR("Chunk resize to %zu failed", chunk_offset); - return IndexError_Runtime; - } - if (filter_same_key_ || get_vector_enabled_) { - if (use_key_info_map_) { - keys_map_lock_->lock(); - (*keys_map_)[key] = local_id; - keys_map_lock_->unlock(); - } - } - - *mutable_doc_cnt() += 1; - broker_->mark_dirty(); - *id = local_id; - - return 0; -} - -int HnswStreamerEntity::add_vector_with_id(level_t level, node_id_t id, - const void *vec) { - Chunk::Pointer node_chunk; - size_t chunk_offset = -1UL; - key_t key = id; - - std::lock_guard lock(mutex_); - - // duplicate check - if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { - LOG_WARN("Try to add duplicate key, ignore it"); - return IndexError_Duplicate; - } - - // set node_chunk & chunk_offset if succeed - auto func_get_node_chunk_and_offset = [&](node_id_t node_id) -> int { - uint32_t chunk_index = node_id >> node_index_mask_bits_; - ailego_assert_with(chunk_index <= node_chunks_.size(), "invalid chunk idx"); - // belongs to next chunk - if (chunk_index == node_chunks_.size()) { - if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { - LOG_ERROR("add vector failed for no memory quota"); - return IndexError_IndexFull; - } - auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_NODE, chunk_index, - chunk_size_); - if (ailego_unlikely(p.first != 0)) { - LOG_ERROR("Alloc data chunk failed"); - return p.first; - } - node_chunk = p.second; - node_chunks_.emplace_back(node_chunk); - } - - node_chunk = node_chunks_[chunk_index]; - chunk_offset = (node_id & node_index_mask_) * node_size(); - return 0; - }; - - for (size_t start_id = doc_cnt(); start_id < id; ++start_id) { - if (auto ret = func_get_node_chunk_and_offset(start_id); ret != 0) { - LOG_ERROR("func_get_node_chunk_and_offset failed"); - return ret; - } - size_t size = node_chunk->write(chunk_offset + vector_size(), &kInvalidKey, - sizeof(key_t)); - if (ailego_unlikely(size != sizeof(key_t))) { - LOG_ERROR("Chunk write key failed, ret=%zu", size); - return IndexError_WriteData; - } - - chunk_offset += node_size(); - if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { - LOG_ERROR("Chunk resize to %zu failed", chunk_offset); - return IndexError_Runtime; - } - } - - if (auto ret = func_get_node_chunk_and_offset(id); ret != 0) { - LOG_ERROR("func_get_node_chunk_and_offset failed"); - return ret; - } - - size_t size = node_chunk->write(chunk_offset, vec, vector_size()); - if (ailego_unlikely(size != vector_size())) { - LOG_ERROR("Chunk write vec failed, ret=%zu", size); - return IndexError_WriteData; - } - - size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); - if (ailego_unlikely(size != sizeof(key_t))) { - LOG_ERROR("Chunk write vec failed, ret=%zu", size); - return IndexError_WriteData; - } - //! level 0 neighbors is inited to zero by default - - int ret = add_upper_neighbor(level, id); - if (ret != 0) { - return ret; - } - - if (*mutable_doc_cnt() <= id) { - *mutable_doc_cnt() = id + 1; - chunk_offset += node_size(); - if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { - LOG_ERROR("Chunk resize to %zu failed", chunk_offset); - return IndexError_Runtime; - } - } - - if (filter_same_key_ || get_vector_enabled_) { - if (use_key_info_map_) { - keys_map_lock_->lock(); - (*keys_map_)[key] = id; - keys_map_lock_->unlock(); - } - } - - broker_->mark_dirty(); - - return 0; -} - -void HnswStreamerEntity::update_ep_and_level(node_id_t ep, level_t level) { - HnswEntity::update_ep_and_level(ep, level); - flush_header(); - - return; -} - -const HnswEntity::Pointer HnswStreamerEntity::clone() const { - std::vector node_chunks; - node_chunks.reserve(node_chunks_.size()); - for (size_t i = 0UL; i < node_chunks_.size(); ++i) { - node_chunks.emplace_back(node_chunks_[i]->clone()); - if (ailego_unlikely(!node_chunks[i])) { - LOG_ERROR("HnswStreamerEntity get chunk failed in clone"); - return HnswEntity::Pointer(); - } - } - - std::vector upper_neighbor_chunks; - upper_neighbor_chunks.reserve(upper_neighbor_chunks_.size()); - for (size_t i = 0UL; i < upper_neighbor_chunks_.size(); ++i) { - upper_neighbor_chunks.emplace_back(upper_neighbor_chunks_[i]->clone()); - if (ailego_unlikely(!upper_neighbor_chunks[i])) { - LOG_ERROR("HnswStreamerEntity get chunk failed in clone"); - return HnswEntity::Pointer(); - } - } - - HnswStreamerEntity *entity = new (std::nothrow) HnswStreamerEntity( - stats_, header(), chunk_size_, node_index_mask_bits_, - upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, - upper_neighbor_index_, keys_map_lock_, keys_map_, use_key_info_map_, - std::move(node_chunks), std::move(upper_neighbor_chunks), broker_); - if (ailego_unlikely(!entity)) { - LOG_ERROR("HnswStreamerEntity new failed"); - } - return HnswEntity::Pointer(entity); -} - -} // namespace core -} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity.h b/src/core/algorithm/hnsw/hnsw_streamer_entity.h deleted file mode 100644 index 1a01b141..00000000 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity.h +++ /dev/null @@ -1,515 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "hnsw_chunk.h" -#include "hnsw_entity.h" -#include "hnsw_index_hash.h" -#include "hnsw_params.h" - -namespace zvec { -namespace core { - -//! HnswStreamerEntity manage vector data, pkey, and node's neighbors -class HnswStreamerEntity : public HnswEntity { - public: - //! Cleanup - //! return 0 on success, or errCode in failure - virtual int cleanup() override; - - //! Make a copy of streamer entity, to support thread-safe operation. - //! The segment in container cannot be read concurrenly - virtual const HnswEntity::Pointer clone() const override; - - //! Get primary key of the node id - virtual key_t get_key(node_id_t id) const override; - - //! Get vector feature data by key - virtual const void *get_vector(node_id_t id) const override; - - //! Get vectors feature data by local ids - virtual int get_vector(const node_id_t *ids, uint32_t count, - const void **vecs) const override; - - virtual int get_vector(const node_id_t id, - IndexStorage::MemoryBlock &block) const override; - - virtual int get_vector( - const node_id_t *ids, uint32_t count, - std::vector &vec_blocks) const override; - - //! Get the node id's neighbors on graph level - //! Note: the neighbors cannot be modified, using the following - //! method to get WritableNeighbors if want to - virtual const Neighbors get_neighbors(level_t level, - node_id_t id) const override; - - //! Add vector and key to hnsw entity, and local id will be saved in id - virtual int add_vector(level_t level, key_t key, const void *vec, - node_id_t *id) override; - - //! Add vector and id to hnsw entity - virtual int add_vector_with_id(level_t level, node_id_t id, - const void *vec) override; - - virtual int update_neighbors( - level_t level, node_id_t id, - const std::vector> &neighbors) override; - - //! Append neighbor_id to node id neighbors on level - //! Notice: the caller must be ensure the neighbors not full - virtual void add_neighbor(level_t level, node_id_t id, uint32_t size, - node_id_t neighbor_id) override; - - //! Dump index by dumper - virtual int dump(const IndexDumper::Pointer &dumper) override; - - virtual void update_ep_and_level(node_id_t ep, level_t level) override; - - void set_use_key_info_map(bool use_id_map) { - use_key_info_map_ = use_id_map; - LOG_DEBUG("use_key_info_map_: %d", (int)use_key_info_map_); - } - - public: - //! Constructor - HnswStreamerEntity(IndexStreamer::Stats &stats); - - //! Destructor - ~HnswStreamerEntity(); - - //! Get vector feature data by key - virtual const void *get_vector_by_key(key_t key) const override { - auto id = get_id(key); - return id == kInvalidNodeId ? nullptr : get_vector(id); - } - - virtual int get_vector_by_key( - const key_t key, IndexStorage::MemoryBlock &block) const override { - auto id = get_id(key); - if (id != kInvalidNodeId) { - return get_vector(id, block); - } else { - return IndexError_InvalidArgument; - } - } - - //! Init entity - int init(size_t max_doc_cnt); - - //! Flush graph entity to disk - //! return 0 on success, or errCode in failure - int flush(uint64_t checkpoint); - - //! Open entity from storage - //! return 0 on success, or errCode in failure - int open(IndexStorage::Pointer stg, uint64_t max_index_size, bool check_crc); - - //! Close entity - //! return 0 on success, or errCode in failure - int close(); - - //! Set meta information from entity - int set_index_meta(const IndexMeta &meta) const { - return IndexHelper::SerializeToStorage(meta, broker_->storage().get()); - } - - //! Get meta information from entity - int get_index_meta(IndexMeta *meta) const { - return IndexHelper::DeserializeFromStorage(broker_->storage().get(), meta); - } - - //! Set params: chunk size - inline void set_chunk_size(size_t val) { - chunk_size_ = val; - } - - //! Set params - inline void set_filter_same_key(bool val) { - filter_same_key_ = val; - } - - //! Set params - inline void set_get_vector(bool val) { - get_vector_enabled_ = val; - } - - //! Get vector local id by key - inline node_id_t get_id(key_t key) const { - if (use_key_info_map_) { - keys_map_lock_->lock_shared(); - auto it = keys_map_->find(key); - keys_map_lock_->unlock_shared(); - return it == keys_map_->end() ? kInvalidNodeId : it->second; - } else { - return key; - } - } - - void print_key_map() const { - std::cout << "key map begins" << std::endl; - - auto iter = keys_map_->begin(); - while (iter != keys_map_->end()) { - std::cout << "key: " << iter->first << ", id: " << iter->second - << std::endl; - ; - iter++; - } - - std::cout << "key map ends" << std::endl; - } - - //! Get l0 neighbors size - inline size_t neighbors_size() const { - return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); - } - - //! Get neighbors size for level > 0 - inline size_t upper_neighbors_size() const { - return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); - } - - - private: - union UpperNeighborIndexMeta { - struct { - uint32_t level : 4; - uint32_t index : 28; // index is composite type: chunk idx, and the - // N th neighbors in chunk, they two composite - // the 28 bits location - }; - uint32_t data; - }; - - template - using HashMap = google::dense_hash_map>; - template - using HashMapPointer = std::shared_ptr>; - - template - using HashSet = google::dense_hash_set>; - template - using HashSetPointer = std::shared_ptr>; - - //! upper neighbor index hashmap - using NIHashMap = HnswIndexHashMap; - using NIHashMapPointer = std::shared_ptr; - - //! Private construct, only be called by clone method - HnswStreamerEntity(IndexStreamer::Stats &stats, const HNSWHeader &hd, - size_t chunk_size, uint32_t node_index_mask_bits, - uint32_t upper_neighbor_mask_bits, bool filter_same_key, - bool get_vector_enabled, - const NIHashMapPointer &upper_neighbor_index, - std::shared_ptr &keys_map_lock, - const HashMapPointer &keys_map, - bool use_key_info_map, - std::vector &&node_chunks, - std::vector &&upper_neighbor_chunks, - const ChunkBroker::Pointer &broker) - : stats_(stats), - chunk_size_(chunk_size), - node_index_mask_bits_(node_index_mask_bits), - node_cnt_per_chunk_(1UL << node_index_mask_bits_), - node_index_mask_(node_cnt_per_chunk_ - 1), - upper_neighbor_mask_bits_(upper_neighbor_mask_bits), - upper_neighbor_mask_((1U << upper_neighbor_mask_bits_) - 1), - filter_same_key_(filter_same_key), - get_vector_enabled_(get_vector_enabled), - use_key_info_map_(use_key_info_map), - upper_neighbor_index_(upper_neighbor_index), - keys_map_lock_(keys_map_lock), - keys_map_(keys_map), - node_chunks_(std::move(node_chunks)), - upper_neighbor_chunks_(std::move(upper_neighbor_chunks)), - broker_(broker) { - *mutable_header() = hd; - - neighbor_size_ = neighbors_size(); - upper_neighbor_size_ = upper_neighbors_size(); - } - - //! Called only in searching procedure per context, so no need to lock - void sync_chunks(ChunkBroker::CHUNK_TYPE type, size_t idx, - std::vector *chunks) const { - if (ailego_likely(idx < chunks->size())) { - return; - } - for (size_t i = chunks->size(); i <= idx; ++i) { - auto chunk = broker_->get_chunk(type, i); - // the storage can ensure get chunk will success after the first get - ailego_assert_with(!!chunk, "get chunk failed"); - chunks->emplace_back(std::move(chunk)); - } - } - - //! return pair: chunk index + chunk offset - inline std::pair get_vector_chunk_loc( - node_id_t id) const { - uint32_t chunk_idx = id >> node_index_mask_bits_; - uint32_t offset = (id & node_index_mask_) * node_size(); - - sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); - return std::make_pair(chunk_idx, offset); - } - - //! return pair: chunk index + chunk offset - inline std::pair get_key_chunk_loc(node_id_t id) const { - uint32_t chunk_idx = id >> node_index_mask_bits_; - uint32_t offset = (id & node_index_mask_) * node_size() + vector_size(); - - sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); - return std::make_pair(chunk_idx, offset); - } - - inline std::pair get_upper_neighbor_chunk_loc( - level_t level, node_id_t id) const { - auto it = upper_neighbor_index_->find(id); - ailego_assert_abort(it != upper_neighbor_index_->end(), - "Get upper neighbor header failed"); - auto meta = reinterpret_cast(&it->second); - uint32_t chunk_idx = (meta->index) >> upper_neighbor_mask_bits_; - uint32_t offset = (((meta->index) & upper_neighbor_mask_) + level - 1) * - upper_neighbor_size_; - sync_chunks(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_idx, - &upper_neighbor_chunks_); - ailego_assert_abort(chunk_idx < upper_neighbor_chunks_.size(), - "invalid chunk idx"); - ailego_assert_abort(offset < upper_neighbor_chunks_[chunk_idx]->data_size(), - "invalid chunk offset"); - return std::make_pair(chunk_idx, offset); - } - - //! return pair: chunk + chunk offset - inline std::pair get_neighbor_chunk_loc(level_t level, - node_id_t id) const { - if (level == 0UL) { - uint32_t chunk_idx = id >> node_index_mask_bits_; - uint32_t offset = - (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); - - sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); - ailego_assert_abort(chunk_idx < node_chunks_.size(), "invalid chunk idx"); - ailego_assert_abort(offset < node_chunks_[chunk_idx]->data_size(), - "invalid chunk offset"); - return std::make_pair(node_chunks_[chunk_idx].get(), offset); - } else { - auto p = get_upper_neighbor_chunk_loc(level, id); - return std::make_pair(upper_neighbor_chunks_[p.first].get(), p.second); - } - } - - //! Chunk hnsw index valid - int check_hnsw_index(const HNSWHeader *hd) const; - - size_t get_total_upper_neighbors_size(level_t level) const { - return level * upper_neighbor_size_; - } - - //! Add upper neighbor header and reserve space for upper neighbor - int add_upper_neighbor(level_t level, node_id_t id) { - if (level == 0) { - return 0; - } - Chunk::Pointer chunk; - uint64_t chunk_offset = -1UL; - size_t neighbors_size = get_total_upper_neighbors_size(level); - uint64_t chunk_index = upper_neighbor_chunks_.size() - 1UL; - if (chunk_index == -1UL || - (upper_neighbor_chunks_[chunk_index]->padding_size() < - neighbors_size)) { // no space left and need to alloc - chunk_index++; - if (ailego_unlikely(upper_neighbor_chunks_.capacity() == - upper_neighbor_chunks_.size())) { - LOG_ERROR("add upper neighbor failed for no memory quota"); - return IndexError_IndexFull; - } - auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, - chunk_index, upper_neighbor_chunk_size_); - if (ailego_unlikely(p.first != 0)) { - LOG_ERROR("Alloc data chunk failed"); - return p.first; - } - chunk = p.second; - chunk_offset = 0UL; - upper_neighbor_chunks_.emplace_back(chunk); - } else { - chunk = upper_neighbor_chunks_[chunk_index]; - chunk_offset = chunk->data_size(); - } - ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid level"); - ailego_assert_with(chunk_offset % upper_neighbor_size_ == 0, - "invalid offset"); - ailego_assert_with((chunk_offset / upper_neighbor_size_) < - (1U << upper_neighbor_mask_bits_), - "invalid offset"); - ailego_assert_with(chunk_index < (1U << (28 - upper_neighbor_mask_bits_)), - "invalid chunk index"); - UpperNeighborIndexMeta meta; - meta.level = level; - meta.index = (chunk_index << upper_neighbor_mask_bits_) | - (chunk_offset / upper_neighbor_size_); - chunk_offset += upper_neighbor_size_ * level; - if (ailego_unlikely(!upper_neighbor_index_->insert(id, meta.data))) { - LOG_ERROR("HashMap insert value failed"); - return IndexError_Runtime; - } - - if (ailego_unlikely(chunk->resize(chunk_offset) != chunk_offset)) { - LOG_ERROR("Chunk resize to %zu failed", (size_t)chunk_offset); - return IndexError_Runtime; - } - - return 0; - } - - size_t estimate_doc_capacity() const { - return node_chunks_.capacity() * node_cnt_per_chunk_; - } - - int init_chunk_params(size_t max_index_size, bool huge_page) { - node_cnt_per_chunk_ = std::max(1, chunk_size_ / node_size()); - //! align node cnt per chunk to pow of 2 - node_index_mask_bits_ = std::ceil(std::log2(node_cnt_per_chunk_)); - node_cnt_per_chunk_ = 1UL << node_index_mask_bits_; - if (huge_page) { - chunk_size_ = AlignHugePageSize(node_cnt_per_chunk_ * node_size()); - } else { - chunk_size_ = AlignPageSize(node_cnt_per_chunk_ * node_size()); - } - node_index_mask_ = node_cnt_per_chunk_ - 1; - - if (max_index_size == 0UL) { - max_index_size_ = chunk_size_ * kDefaultMaxChunkCnt; - } else { - max_index_size_ = max_index_size; - } - - //! To get a balanced upper neighbor chunk size. - //! If the upper chunk size is equal to node chunk size, it may waste - //! upper neighbor chunk space; if the upper neighbor chunk size is too - //! small, the will need large upper neighbor chunks index space. So to - //! get a balanced ratio be sqrt of the node/neighbor size ratio - float ratio = - std::sqrt(node_size() * scaling_factor() * 1.0f / upper_neighbor_size_); - if (huge_page) { - upper_neighbor_chunk_size_ = AlignHugePageSize( - std::max(get_total_upper_neighbors_size(kMaxGraphLayers), - static_cast(chunk_size_ / ratio))); - } else { - upper_neighbor_chunk_size_ = AlignPageSize( - std::max(get_total_upper_neighbors_size(kMaxGraphLayers), - static_cast(chunk_size_ / ratio))); - } - upper_neighbor_mask_bits_ = - std::ceil(std::log2(upper_neighbor_chunk_size_ / upper_neighbor_size_)); - upper_neighbor_mask_ = (1 << upper_neighbor_mask_bits_) - 1; - - size_t max_node_chunk_cnt = std::ceil(max_index_size_ / chunk_size_); - size_t max_upper_chunk_cnt = std::ceil( - (max_node_chunk_cnt * node_cnt_per_chunk_ * 1.0f / scaling_factor()) / - (upper_neighbor_chunk_size_ / upper_neighbor_size_)); - max_upper_chunk_cnt = - max_upper_chunk_cnt + std::ceil(max_upper_chunk_cnt / scaling_factor()); - - //! reserve space to avoid memmove in chunks vector emplace chunk, so - //! as to lock-free in reading chunk - node_chunks_.reserve(max_node_chunk_cnt); - upper_neighbor_chunks_.reserve(max_upper_chunk_cnt); - - LOG_DEBUG( - "Settings: nodeSize=%zu chunkSize=%u upperNeighborSize=%u " - "upperNeighborChunkSize=%u " - "nodeCntPerChunk=%u maxChunkCnt=%zu maxNeighborChunkCnt=%zu " - "maxIndexSize=%zu ratio=%.3f", - node_size(), chunk_size_, upper_neighbor_size_, - upper_neighbor_chunk_size_, node_cnt_per_chunk_, max_node_chunk_cnt, - max_upper_chunk_cnt, max_index_size_, ratio); - - return 0; - } - - //! Init node chunk and neighbor chunks - int init_chunks(const Chunk::Pointer &header_chunk); - - int flush_header(void) { - if (!broker_->dirty()) { - // do not need to flush - return 0; - } - auto header_chunk = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_HEADER, - ChunkBroker::kDefaultChunkSeqId); - if (ailego_unlikely(!header_chunk)) { - LOG_ERROR("get header chunk failed"); - return IndexError_Runtime; - } - size_t size = header_chunk->write(0UL, &header(), header_size()); - if (ailego_unlikely(size != header_size())) { - LOG_ERROR("Write header chunk failed"); - return IndexError_WriteData; - } - - return 0; - } - - private: - HnswStreamerEntity(const HnswStreamerEntity &) = delete; - HnswStreamerEntity &operator=(const HnswStreamerEntity &) = delete; - static constexpr uint64_t kUpperHashMemoryInflateRatio = 2.0f; - - private: - IndexStreamer::Stats &stats_; - HNSWHeader header_{}; - std::mutex mutex_{}; - size_t max_index_size_{0UL}; - uint32_t chunk_size_{kDefaultChunkSize}; - uint32_t upper_neighbor_chunk_size_{kDefaultChunkSize}; - uint32_t node_index_mask_bits_{0U}; - uint32_t node_cnt_per_chunk_{0U}; - uint32_t node_index_mask_{0U}; - uint32_t neighbor_size_{0U}; - uint32_t upper_neighbor_size_{0U}; - //! UpperNeighborIndex.index composite chunkIdx and offset in chunk by the - //! following mask - uint32_t upper_neighbor_mask_bits_{0U}; - uint32_t upper_neighbor_mask_{0U}; - bool filter_same_key_{false}; - bool get_vector_enabled_{false}; - bool use_key_info_map_{true}; - - NIHashMapPointer upper_neighbor_index_{}; - - mutable std::shared_ptr keys_map_lock_{}; - HashMapPointer keys_map_{}; - - //! the chunks will be changed in searcher, so need mutable - //! data chunk include: vector, key, level 0 neighbors - mutable std::vector node_chunks_{}; - - //! upper neighbor chunk inlude: UpperNeighborHeader + (1~level) neighbors - mutable std::vector upper_neighbor_chunks_{}; - - ChunkBroker::Pointer broker_{}; // chunk broker -}; - -} // namespace core -} // namespace zvec \ No newline at end of file From b67277fd25ca9410d0b84ac41d98983532fca8b4 Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Wed, 11 Mar 2026 20:28:57 +0800 Subject: [PATCH 07/34] fix --- .../hnsw/hnsw_streamer_entity_new.cc | 4 +- .../algorithm/hnsw/hnsw_streamer_entity_new.h | 51 ++++++++++--------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc index 5d8fc439..616c4258 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc @@ -718,8 +718,8 @@ int HnswStreamerEntityNew::add_vector_with_id(level_t level, node_id_t id, } void HnswStreamerEntityNew::update_ep_and_level(node_id_t ep, level_t level) { - header_.hnsw.entry_point = ep; - header_.hnsw.max_level = level; + base_header_.hnsw.entry_point = ep; + base_header_.hnsw.max_level = level; flush_header(); return; diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h index cf19b63d..40a2d3e3 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h @@ -105,95 +105,95 @@ class HnswStreamerEntityNew { //! Get max neighbor size of graph level inline size_t neighbor_cnt(level_t level) const { - return level == 0 ? header_.graph.l0_neighbor_count - : header_.hnsw.upper_neighbor_count; + return level == 0 ? base_header_.graph.l0_neighbor_count + : base_header_.hnsw.upper_neighbor_count; } //! get max neighbor size of graph level 0 inline size_t l0_neighbor_cnt() const { - return header_.graph.l0_neighbor_count; + return base_header_.graph.l0_neighbor_count; } //! get min neighbor size of graph inline size_t min_neighbor_cnt() const { - return header_.graph.min_neighbor_count; + return base_header_.graph.min_neighbor_count; } //! get upper neighbor size of graph level other than 0 inline size_t upper_neighbor_cnt() const { - return header_.hnsw.upper_neighbor_count; + return base_header_.hnsw.upper_neighbor_count; } //! Get current total doc of the hnsw graph inline node_id_t *mutable_doc_cnt() { - return &header_.graph.doc_count; + return &base_header_.graph.doc_count; } inline node_id_t doc_cnt() const { - return header_.graph.doc_count; + return base_header_.graph.doc_count; } //! Get hnsw graph scaling params inline size_t scaling_factor() const { - return header_.hnsw.scaling_factor; + return base_header_.hnsw.scaling_factor; } //! Get prune_size inline size_t prune_cnt() const { - return header_.graph.prune_neighbor_count; + return base_header_.graph.prune_neighbor_count; } //! Current entity of top level graph inline node_id_t entry_point() const { - return header_.hnsw.entry_point; + return base_header_.hnsw.entry_point; } //! Current max graph level inline level_t cur_max_level() const { - return header_.hnsw.max_level; + return base_header_.hnsw.max_level; } //! Retrieve index vector size size_t vector_size() const { - return header_.graph.vector_size; + return base_header_.graph.vector_size; } //! Retrieve node size size_t node_size() const { - return header_.graph.node_size; + return base_header_.graph.node_size; } //! Retrieve ef constuction size_t ef_construction() const { - return header_.graph.ef_construction; + return base_header_.graph.ef_construction; } void set_vector_size(size_t size) { - header_.graph.vector_size = size; + base_header_.graph.vector_size = size; } void set_prune_cnt(size_t v) { - header_.graph.prune_neighbor_count = v; + base_header_.graph.prune_neighbor_count = v; } void set_scaling_factor(size_t val) { - header_.hnsw.scaling_factor = val; + base_header_.hnsw.scaling_factor = val; } void set_l0_neighbor_cnt(size_t cnt) { - header_.graph.l0_neighbor_count = cnt; + base_header_.graph.l0_neighbor_count = cnt; } void set_min_neighbor_cnt(size_t cnt) { - header_.graph.min_neighbor_count = cnt; + base_header_.graph.min_neighbor_count = cnt; } void set_upper_neighbor_cnt(size_t cnt) { - header_.hnsw.upper_neighbor_count = cnt; + base_header_.hnsw.upper_neighbor_count = cnt; } void set_ef_construction(size_t ef) { - header_.graph.ef_construction = ef; + base_header_.graph.ef_construction = ef; } static int CalcAndAddPadding(const IndexDumper::Pointer &dumper, @@ -201,19 +201,19 @@ class HnswStreamerEntityNew { protected: inline const HNSWHeader &header() const { - return header_; + return base_header_; } inline HNSWHeader *mutable_header() { - return &header_; + return &base_header_; } inline size_t header_size() const { - return sizeof(header_); + return sizeof(base_header_); } void set_node_size(size_t size) { - header_.graph.node_size = size; + base_header_.graph.node_size = size; } //! Dump all segment by dumper @@ -707,6 +707,7 @@ class HnswStreamerEntityNew { private: IndexStreamer::Stats &stats_; + HNSWHeader base_header_{}; HNSWHeader header_{}; std::mutex mutex_{}; size_t max_index_size_{0UL}; From 6b4eea8388e6a01b3202c88dcb9eb506690394bc Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Wed, 11 Mar 2026 22:10:44 +0800 Subject: [PATCH 08/34] upd entity --- src/core/algorithm/hnsw/hnsw_algorithm.cc | 7 +- src/core/algorithm/hnsw/hnsw_context.h | 10 +- .../algorithm/hnsw/hnsw_dist_calculator.h | 15 +- src/core/algorithm/hnsw/hnsw_index_provider.h | 5 +- src/core/algorithm/hnsw/hnsw_streamer.h | 4 +- .../hnsw/hnsw_streamer_entity_new.cc | 79 +++++++--- .../algorithm/hnsw/hnsw_streamer_entity_new.h | 135 +++++++++--------- 7 files changed, 151 insertions(+), 104 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_algorithm.cc b/src/core/algorithm/hnsw/hnsw_algorithm.cc index e5561544..83d286da 100644 --- a/src/core/algorithm/hnsw/hnsw_algorithm.cc +++ b/src/core/algorithm/hnsw/hnsw_algorithm.cc @@ -123,7 +123,7 @@ void HnswAlgorithm::select_entry_point(level_t level, node_id_t *entry_point, } std::vector neighbor_vec_blocks; - int ret = entity.get_vector(&neighbors[0], size, neighbor_vec_blocks); + int ret = entity.get_vector_new(&neighbors[0], size, neighbor_vec_blocks); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } @@ -232,7 +232,8 @@ void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, } std::vector neighbor_vec_blocks; - int ret = entity.get_vector(neighbor_ids.data(), size, neighbor_vec_blocks); + int ret = + entity.get_vector_new(neighbor_ids.data(), size, neighbor_vec_blocks); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } @@ -356,7 +357,7 @@ void HnswAlgorithm::expand_neighbors_by_group(TopkHeap &topk, std::vector neighbor_vec_blocks; int ret = - entity.get_vector(neighbor_ids.data(), size, neighbor_vec_blocks); + entity.get_vector_new(neighbor_ids.data(), size, neighbor_vec_blocks); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } diff --git a/src/core/algorithm/hnsw/hnsw_context.h b/src/core/algorithm/hnsw/hnsw_context.h index 0f988baf..684eafe2 100644 --- a/src/core/algorithm/hnsw/hnsw_context.h +++ b/src/core/algorithm/hnsw/hnsw_context.h @@ -114,7 +114,8 @@ class HnswContext : public IndexContext { //! Update context, the context may be shared by different searcher/streamer int update_context(ContextType type, const IndexMeta &meta, const IndexMetric::Pointer &metric, - const HnswStreamerEntityNew::Pointer &entity, uint32_t magic_num); + const HnswStreamerEntityNew::Pointer &entity, + uint32_t magic_num); inline const HnswStreamerEntityNew &get_entity() const { return *entity_; @@ -175,7 +176,7 @@ class HnswContext : public IndexContext { node_id_t id = topk_heap_[i].first; if (fetch_vector_) { results_[idx].emplace_back(entity_->get_key(id), score, id, - entity_->get_vector(id)); + entity_->get_vector_new(id)); } else { results_[idx].emplace_back(entity_->get_key(id), score, id); } @@ -238,7 +239,7 @@ class HnswContext : public IndexContext { if (fetch_vector_) { group_results_[idx][i].mutable_docs()->emplace_back( - entity_->get_key(id), score, id, entity_->get_vector(id)); + entity_->get_key(id), score, id, entity_->get_vector_new(id)); } else { group_results_[idx][i].mutable_docs()->emplace_back( entity_->get_key(id), score, id); @@ -501,7 +502,8 @@ class HnswContext : public IndexContext { uint32_t topk_{0}; uint32_t group_topk_{0}; uint32_t filter_mode_{VisitFilter::ByteMap}; - float negative_probability_{HnswStreamerEntityNew::kDefaultBFNegativeProbability}; + float negative_probability_{ + HnswStreamerEntityNew::kDefaultBFNegativeProbability}; uint32_t ef_{HnswStreamerEntityNew::kDefaultEf}; float max_scan_ratio_{HnswStreamerEntityNew::kDefaultScanRatio}; uint32_t magic_{0U}; diff --git a/src/core/algorithm/hnsw/hnsw_dist_calculator.h b/src/core/algorithm/hnsw/hnsw_dist_calculator.h index 4f6b624e..84bc255e 100644 --- a/src/core/algorithm/hnsw/hnsw_dist_calculator.h +++ b/src/core/algorithm/hnsw/hnsw_dist_calculator.h @@ -63,14 +63,15 @@ class HnswDistCalculator { dim_(0), compare_cnt_(0) {} - void update(const HnswStreamerEntityNew *entity, const IndexMetric::Pointer &metric) { + void update(const HnswStreamerEntityNew *entity, + const IndexMetric::Pointer &metric) { entity_ = entity; distance_ = metric->distance(); batch_distance_ = metric->batch_distance(); } - void update(const HnswStreamerEntityNew *entity, const IndexMetric::Pointer &metric, - uint32_t dim) { + void update(const HnswStreamerEntityNew *entity, + const IndexMetric::Pointer &metric, uint32_t dim) { entity_ = entity; distance_ = metric->distance(); batch_distance_ = metric->batch_distance(); @@ -116,7 +117,7 @@ class HnswDistCalculator { inline dist_t dist(node_id_t id) { compare_cnt_++; - const void *feat = entity_->get_vector(id); + const void *feat = entity_->get_vector_new(id); if (ailego_unlikely(feat == nullptr)) { LOG_ERROR("Get nullptr vector, id=%u", id); error_ = true; @@ -130,8 +131,8 @@ class HnswDistCalculator { inline dist_t dist(node_id_t lhs, node_id_t rhs) { compare_cnt_++; - const void *feat = entity_->get_vector(lhs); - const void *query = entity_->get_vector(rhs); + const void *feat = entity_->get_vector_new(lhs); + const void *query = entity_->get_vector_new(rhs); if (ailego_unlikely(feat == nullptr || query == nullptr)) { LOG_ERROR("Get nullptr vector"); error_ = true; @@ -162,7 +163,7 @@ class HnswDistCalculator { inline dist_t batch_dist(node_id_t id) { compare_cnt_++; - const void *feat = entity_->get_vector(id); + const void *feat = entity_->get_vector_new(id); if (ailego_unlikely(feat == nullptr)) { LOG_ERROR("Get nullptr vector, id=%u", id); error_ = true; diff --git a/src/core/algorithm/hnsw/hnsw_index_provider.h b/src/core/algorithm/hnsw/hnsw_index_provider.h index b128a2c0..064934ce 100644 --- a/src/core/algorithm/hnsw/hnsw_index_provider.h +++ b/src/core/algorithm/hnsw/hnsw_index_provider.h @@ -23,7 +23,8 @@ namespace core { class HnswIndexProvider : public IndexProvider { public: - HnswIndexProvider(const IndexMeta &meta, const HnswStreamerEntityNew::Pointer &entity, + HnswIndexProvider(const IndexMeta &meta, + const HnswStreamerEntityNew::Pointer &entity, const std::string &owner) : meta_(meta), entity_(entity), owner_class_(owner) {} @@ -83,7 +84,7 @@ class HnswIndexProvider : public IndexProvider { //! NOTICE: the vec feature will be changed after iterating to next, so //! the caller need to keep a copy of it before iterator to next vector virtual const void *data(void) const override { - return entity_->get_vector(cur_id_); + return entity_->get_vector_new(cur_id_); } //! Test if the iterator is valid diff --git a/src/core/algorithm/hnsw/hnsw_streamer.h b/src/core/algorithm/hnsw/hnsw_streamer.h index 9613533e..48f377e6 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer.h +++ b/src/core/algorithm/hnsw/hnsw_streamer.h @@ -98,12 +98,12 @@ class HnswStreamer : public IndexStreamer { //! Fetch vector by id virtual const void *get_vector_by_id(uint32_t id) const override { - return entity_.get_vector(id); + return entity_.get_vector_new(id); } virtual int get_vector_by_id( const uint32_t id, IndexStorage::MemoryBlock &block) const override { - return entity_.get_vector(id, block); + return entity_.get_vector_new(id, block); } //! Open index from file path diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc index 616c4258..451cf4b3 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc @@ -21,18 +21,24 @@ namespace zvec { namespace core { const std::string HnswStreamerEntityNew::kGraphHeaderSegmentId = "graph.header"; -const std::string HnswStreamerEntityNew::kGraphFeaturesSegmentId = "graph.features"; +const std::string HnswStreamerEntityNew::kGraphFeaturesSegmentId = + "graph.features"; const std::string HnswStreamerEntityNew::kGraphKeysSegmentId = "graph.keys"; -const std::string HnswStreamerEntityNew::kGraphNeighborsSegmentId = "graph.neighbors"; -const std::string HnswStreamerEntityNew::kGraphOffsetsSegmentId = "graph.offsets"; -const std::string HnswStreamerEntityNew::kGraphMappingSegmentId = "graph.mapping"; +const std::string HnswStreamerEntityNew::kGraphNeighborsSegmentId = + "graph.neighbors"; +const std::string HnswStreamerEntityNew::kGraphOffsetsSegmentId = + "graph.offsets"; +const std::string HnswStreamerEntityNew::kGraphMappingSegmentId = + "graph.mapping"; const std::string HnswStreamerEntityNew::kHnswHeaderSegmentId = "hnsw.header"; -const std::string HnswStreamerEntityNew::kHnswNeighborsSegmentId = "hnsw.neighbors"; +const std::string HnswStreamerEntityNew::kHnswNeighborsSegmentId = + "hnsw.neighbors"; const std::string HnswStreamerEntityNew::kHnswOffsetsSegmentId = "hnsw.offsets"; int64_t HnswStreamerEntityNew::dump_segment(const IndexDumper::Pointer &dumper, - const std::string &segment_id, - const void *data, size_t size) const { + const std::string &segment_id, + const void *data, + size_t size) const { size_t len = dumper->write(data, size); if (len != size) { LOG_ERROR("Dump segment %s data failed, expect: %lu, actual: %lu", @@ -60,7 +66,7 @@ int64_t HnswStreamerEntityNew::dump_segment(const IndexDumper::Pointer &dumper, } int64_t HnswStreamerEntityNew::dump_header(const IndexDumper::Pointer &dumper, - const HNSWHeader &hd) const { + const HNSWHeader &hd) const { //! dump basic graph header. header is aligned and does not need padding int64_t graph_hd_size = dump_segment(dumper, kGraphHeaderSegmentId, &hd.graph, hd.graph.size); @@ -160,7 +166,7 @@ int HnswStreamerEntityNew::update_neighbors( } const Neighbors HnswStreamerEntityNew::get_neighbors(level_t level, - node_id_t id) const { + node_id_t id) const { Chunk *chunk = nullptr; size_t offset = 0UL; size_t neighbor_size = neighbor_size_; @@ -208,8 +214,13 @@ const void *HnswStreamerEntityNew::get_vector(node_id_t id) const { return vec; } +const void *HnswStreamerEntityNew::get_vector_new(node_id_t id) const { + return vector_value_.data() + vector_size() * id; + // return get_vector(id); +} + int HnswStreamerEntityNew::get_vector(const node_id_t *ids, uint32_t count, - const void **vecs) const { + const void **vecs) const { for (auto i = 0U; i < count; ++i) { auto loc = get_vector_chunk_loc(ids[i]); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); @@ -229,7 +240,7 @@ int HnswStreamerEntityNew::get_vector(const node_id_t *ids, uint32_t count, } int HnswStreamerEntityNew::get_vector(const node_id_t id, - IndexStorage::MemoryBlock &block) const { + IndexStorage::MemoryBlock &block) const { auto loc = get_vector_chunk_loc(id); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), @@ -246,6 +257,26 @@ int HnswStreamerEntityNew::get_vector(const node_id_t id, return 0; } +int HnswStreamerEntityNew::get_vector_new( + const node_id_t id, IndexStorage::MemoryBlock &block) const { + // const void *data = vector_value_.data() + vector_size() * id; + // block.reset((void *)data); + // return 0; + return get_vector(id, block); +} + +int HnswStreamerEntityNew::get_vector_new( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const { + // vec_blocks.resize(count); + // for (int i = 0; i < count; i++) { + // const void *data = vector_value_.data() + vector_size() * ids[i]; + // vec_blocks[i].reset((void *)data); + // } + // return 0; + return get_vector(ids, count, vec_blocks); +} + int HnswStreamerEntityNew::get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { @@ -290,7 +321,7 @@ key_t HnswStreamerEntityNew::get_key(node_id_t id) const { } void HnswStreamerEntityNew::add_neighbor(level_t level, node_id_t id, - uint32_t size, node_id_t neighbor_id) { + uint32_t size, node_id_t neighbor_id) { auto loc = get_neighbor_chunk_loc(level, id); size_t offset = loc.second + sizeof(NeighborsHeader) + size * sizeof(node_id_t); @@ -356,8 +387,8 @@ int HnswStreamerEntityNew::init_chunks(const Chunk::Pointer &header_chunk) { return 0; } -int HnswStreamerEntityNew::open(IndexStorage::Pointer stg, uint64_t max_index_size, - bool check_crc) { +int HnswStreamerEntityNew::open(IndexStorage::Pointer stg, + uint64_t max_index_size, bool check_crc) { std::lock_guard lock(mutex_); bool huge_page = stg->isHugePage(); LOG_DEBUG("huge_page: %d", (int)huge_page); @@ -437,7 +468,11 @@ int HnswStreamerEntityNew::open(IndexStorage::Pointer stg, uint64_t max_index_si } } } - + vector_value_.clear(); + vector_value_.reserve(vector_size() * doc_cnt()); + for (int i = 0; i < doc_cnt(); i++) { + vector_value_.append((const char *)get_vector(i), vector_size()); + } stats_.set_loaded_count(doc_cnt()); return 0; @@ -542,7 +577,7 @@ int HnswStreamerEntityNew::check_hnsw_index(const HNSWHeader *hd) const { } int HnswStreamerEntityNew::add_vector(level_t level, key_t key, const void *vec, - node_id_t *id) { + node_id_t *id) { Chunk::Pointer node_chunk; size_t chunk_offset = -1UL; @@ -615,7 +650,7 @@ int HnswStreamerEntityNew::add_vector(level_t level, key_t key, const void *vec, } int HnswStreamerEntityNew::add_vector_with_id(level_t level, node_id_t id, - const void *vec) { + const void *vec) { Chunk::Pointer node_chunk; size_t chunk_offset = -1UL; key_t key = id; @@ -750,15 +785,16 @@ const HnswStreamerEntityNew::Pointer HnswStreamerEntityNew::clone() const { stats_, header(), chunk_size_, node_index_mask_bits_, upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, upper_neighbor_index_, keys_map_lock_, keys_map_, use_key_info_map_, - std::move(node_chunks), std::move(upper_neighbor_chunks), broker_); + std::move(node_chunks), std::move(upper_neighbor_chunks), broker_, + vector_value_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswStreamerEntityNew new failed"); } return HnswStreamerEntityNew::Pointer(entity); } -int64_t HnswStreamerEntityNew::dump_mapping_segment(const IndexDumper::Pointer &dumper, - const key_t *keys) const { +int64_t HnswStreamerEntityNew::dump_mapping_segment( + const IndexDumper::Pointer &dumper, const key_t *keys) const { std::vector mapping(doc_cnt()); std::iota(mapping.begin(), mapping.end(), 0U); @@ -1019,7 +1055,8 @@ int64_t HnswStreamerEntityNew::dump_upper_neighbors( } int HnswStreamerEntityNew::CalcAndAddPadding(const IndexDumper::Pointer &dumper, - size_t data_size, size_t *padding_size) { + size_t data_size, + size_t *padding_size) { *padding_size = AlignSize(data_size) - data_size; if (*padding_size == 0) { return 0; diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h index 40a2d3e3..63c32629 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h @@ -30,7 +30,7 @@ namespace core { //! HnswStreamerEntityNew manage vector data, pkey, and node's neighbors class HnswStreamerEntityNew { - public: // override + public: // override typedef std::shared_ptr Pointer; //! Cleanup @@ -47,30 +47,32 @@ class HnswStreamerEntityNew { //! Get vector feature data by key const void *get_vector(node_id_t id) const; + const void *get_vector_new(node_id_t id) const; + //! Get vectors feature data by local ids + int get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const; + + int get_vector(const node_id_t id, IndexStorage::MemoryBlock &block) const; + int get_vector(const node_id_t *ids, uint32_t count, - const void **vecs) const; + std::vector &vec_blocks) const; - int get_vector(const node_id_t id, - IndexStorage::MemoryBlock &block) const; + int get_vector_new(const node_id_t id, + IndexStorage::MemoryBlock &block) const; - int get_vector( - const node_id_t *ids, uint32_t count, - std::vector &vec_blocks) const; + int get_vector_new(const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const; //! Get the node id's neighbors on graph level //! Note: the neighbors cannot be modified, using the following //! method to get WritableNeighbors if want to - const Neighbors get_neighbors(level_t level, - node_id_t id) const; + const Neighbors get_neighbors(level_t level, node_id_t id) const; //! Add vector and key to hnsw entity, and local id will be saved in id - int add_vector(level_t level, key_t key, const void *vec, - node_id_t *id); + int add_vector(level_t level, key_t key, const void *vec, node_id_t *id); //! Add vector and id to hnsw entity - int add_vector_with_id(level_t level, node_id_t id, - const void *vec); + int add_vector_with_id(level_t level, node_id_t id, const void *vec); int update_neighbors( level_t level, node_id_t id, @@ -79,7 +81,7 @@ class HnswStreamerEntityNew { //! Append neighbor_id to node id neighbors on level //! Notice: the caller must be ensure the neighbors not full void add_neighbor(level_t level, node_id_t id, uint32_t size, - node_id_t neighbor_id); + node_id_t neighbor_id); //! Dump index by dumper int dump(const IndexDumper::Pointer &dumper); @@ -91,8 +93,8 @@ class HnswStreamerEntityNew { return id == kInvalidNodeId ? nullptr : get_vector(id); } - int get_vector_by_key( - const key_t key, IndexStorage::MemoryBlock &block) const { + int get_vector_by_key(const key_t key, + IndexStorage::MemoryBlock &block) const { auto id = get_id(key); if (id != kInvalidNodeId) { return get_vector(id, block); @@ -101,9 +103,8 @@ class HnswStreamerEntityNew { } } - public: // hnsw entity public - - //! Get max neighbor size of graph level + public: // hnsw entity public + //! Get max neighbor size of graph level inline size_t neighbor_cnt(level_t level) const { return level == 0 ? base_header_.graph.l0_neighbor_count : base_header_.hnsw.upper_neighbor_count; @@ -291,40 +292,40 @@ class HnswStreamerEntityNew { const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const; - public: - const static std::string kGraphHeaderSegmentId; - const static std::string kGraphFeaturesSegmentId; - const static std::string kGraphKeysSegmentId; - const static std::string kGraphNeighborsSegmentId; - const static std::string kGraphOffsetsSegmentId; - const static std::string kGraphMappingSegmentId; - const static std::string kHnswHeaderSegmentId; - const static std::string kHnswNeighborsSegmentId; - const static std::string kHnswOffsetsSegmentId; - - constexpr static uint32_t kRevision = 0U; - constexpr static size_t kMaxGraphLayers = 15; - constexpr static uint32_t kDefaultEfConstruction = 500; - constexpr static uint32_t kDefaultEf = 500; - constexpr static uint32_t kDefaultUpperMaxNeighborCnt = 50; // M of HNSW - constexpr static uint32_t kDefaultL0MaxNeighborCnt = 100; - constexpr static uint32_t kMaxNeighborCnt = 65535; - constexpr static float kDefaultScanRatio = 0.1f; - constexpr static uint32_t kDefaultMinScanLimit = 10000; - constexpr static uint32_t kDefaultMaxScanLimit = - std::numeric_limits::max(); - constexpr static float kDefaultBFNegativeProbability = 0.001f; - constexpr static uint32_t kDefaultScalingFactor = 50U; - constexpr static uint32_t kDefaultBruteForceThreshold = 1000U; - constexpr static uint32_t kDefaultDocsHardLimit = 1 << 30U; // 1 billion - constexpr static float kDefaultDocsSoftLimitRatio = 0.9f; - constexpr static size_t kMaxChunkSize = 0xFFFFFFFF; - constexpr static size_t kDefaultChunkSize = 2UL * 1024UL * 1024UL; - constexpr static size_t kDefaultMaxChunkCnt = 50000UL; - constexpr static float kDefaultNeighborPruneMultiplier = - 1.0f; // prune_cnt = upper_max_neighbor_cnt * multiplier - constexpr static float kDefaultL0MaxNeighborCntMultiplier = - 2.0f; // l0_max_neighbor_cnt = upper_max_neighbor_cnt * multiplier + public: + const static std::string kGraphHeaderSegmentId; + const static std::string kGraphFeaturesSegmentId; + const static std::string kGraphKeysSegmentId; + const static std::string kGraphNeighborsSegmentId; + const static std::string kGraphOffsetsSegmentId; + const static std::string kGraphMappingSegmentId; + const static std::string kHnswHeaderSegmentId; + const static std::string kHnswNeighborsSegmentId; + const static std::string kHnswOffsetsSegmentId; + + constexpr static uint32_t kRevision = 0U; + constexpr static size_t kMaxGraphLayers = 15; + constexpr static uint32_t kDefaultEfConstruction = 500; + constexpr static uint32_t kDefaultEf = 500; + constexpr static uint32_t kDefaultUpperMaxNeighborCnt = 50; // M of HNSW + constexpr static uint32_t kDefaultL0MaxNeighborCnt = 100; + constexpr static uint32_t kMaxNeighborCnt = 65535; + constexpr static float kDefaultScanRatio = 0.1f; + constexpr static uint32_t kDefaultMinScanLimit = 10000; + constexpr static uint32_t kDefaultMaxScanLimit = + std::numeric_limits::max(); + constexpr static float kDefaultBFNegativeProbability = 0.001f; + constexpr static uint32_t kDefaultScalingFactor = 50U; + constexpr static uint32_t kDefaultBruteForceThreshold = 1000U; + constexpr static uint32_t kDefaultDocsHardLimit = 1 << 30U; // 1 billion + constexpr static float kDefaultDocsSoftLimitRatio = 0.9f; + constexpr static size_t kMaxChunkSize = 0xFFFFFFFF; + constexpr static size_t kDefaultChunkSize = 2UL * 1024UL * 1024UL; + constexpr static size_t kDefaultMaxChunkCnt = 50000UL; + constexpr static float kDefaultNeighborPruneMultiplier = + 1.0f; // prune_cnt = upper_max_neighbor_cnt * multiplier + constexpr static float kDefaultL0MaxNeighborCntMultiplier = + 2.0f; // l0_max_neighbor_cnt = upper_max_neighbor_cnt * multiplier public: //! Constructor @@ -334,7 +335,7 @@ class HnswStreamerEntityNew { ~HnswStreamerEntityNew(); //! Get vector feature data by key - + //! Init entity int init(size_t max_doc_cnt); @@ -445,16 +446,17 @@ class HnswStreamerEntityNew { //! Private construct, only be called by clone method HnswStreamerEntityNew(IndexStreamer::Stats &stats, const HNSWHeader &hd, - size_t chunk_size, uint32_t node_index_mask_bits, - uint32_t upper_neighbor_mask_bits, bool filter_same_key, - bool get_vector_enabled, - const NIHashMapPointer &upper_neighbor_index, - std::shared_ptr &keys_map_lock, - const HashMapPointer &keys_map, - bool use_key_info_map, - std::vector &&node_chunks, - std::vector &&upper_neighbor_chunks, - const ChunkBroker::Pointer &broker) + size_t chunk_size, uint32_t node_index_mask_bits, + uint32_t upper_neighbor_mask_bits, bool filter_same_key, + bool get_vector_enabled, + const NIHashMapPointer &upper_neighbor_index, + std::shared_ptr &keys_map_lock, + const HashMapPointer &keys_map, + bool use_key_info_map, + std::vector &&node_chunks, + std::vector &&upper_neighbor_chunks, + const ChunkBroker::Pointer &broker, + std::string vector_value) : stats_(stats), chunk_size_(chunk_size), node_index_mask_bits_(node_index_mask_bits), @@ -470,7 +472,8 @@ class HnswStreamerEntityNew { keys_map_(keys_map), node_chunks_(std::move(node_chunks)), upper_neighbor_chunks_(std::move(upper_neighbor_chunks)), - broker_(broker) { + broker_(broker), + vector_value_(vector_value) { *mutable_header() = hd; neighbor_size_ = neighbors_size(); @@ -739,6 +742,8 @@ class HnswStreamerEntityNew { mutable std::vector upper_neighbor_chunks_{}; ChunkBroker::Pointer broker_{}; // chunk broker + + std::string vector_value_{}; }; } // namespace core From 5aa9ed2083bc2653518861af99de5872a5899913 Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Thu, 12 Mar 2026 11:17:56 +0800 Subject: [PATCH 09/34] upd --- .../hnsw/hnsw_streamer_entity_new.cc | 32 +++++++++---------- .../algorithm/hnsw/hnsw_streamer_entity_new.h | 6 ++-- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc index 451cf4b3..d1d5dc60 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc @@ -215,7 +215,7 @@ const void *HnswStreamerEntityNew::get_vector(node_id_t id) const { } const void *HnswStreamerEntityNew::get_vector_new(node_id_t id) const { - return vector_value_.data() + vector_size() * id; + return vector_value_ptr_->data() + vector_size() * id; // return get_vector(id); } @@ -259,22 +259,22 @@ int HnswStreamerEntityNew::get_vector(const node_id_t id, int HnswStreamerEntityNew::get_vector_new( const node_id_t id, IndexStorage::MemoryBlock &block) const { - // const void *data = vector_value_.data() + vector_size() * id; - // block.reset((void *)data); - // return 0; - return get_vector(id, block); + const void *data = vector_value_ptr_->data() + vector_size() * id; + block.reset((void *)data); + return 0; + // return get_vector(id, block); } int HnswStreamerEntityNew::get_vector_new( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { - // vec_blocks.resize(count); - // for (int i = 0; i < count; i++) { - // const void *data = vector_value_.data() + vector_size() * ids[i]; - // vec_blocks[i].reset((void *)data); - // } - // return 0; - return get_vector(ids, count, vec_blocks); + vec_blocks.resize(count); + for (int i = 0; i < count; i++) { + const void *data = vector_value_ptr_->data() + vector_size() * ids[i]; + vec_blocks[i].reset((void *)data); + } + return 0; + // return get_vector(ids, count, vec_blocks); } int HnswStreamerEntityNew::get_vector( @@ -468,10 +468,10 @@ int HnswStreamerEntityNew::open(IndexStorage::Pointer stg, } } } - vector_value_.clear(); - vector_value_.reserve(vector_size() * doc_cnt()); + vector_value_ptr_ = std::make_shared(); + vector_value_ptr_->reserve(vector_size() * doc_cnt()); for (int i = 0; i < doc_cnt(); i++) { - vector_value_.append((const char *)get_vector(i), vector_size()); + vector_value_ptr_->append((const char *)get_vector(i), vector_size()); } stats_.set_loaded_count(doc_cnt()); @@ -786,7 +786,7 @@ const HnswStreamerEntityNew::Pointer HnswStreamerEntityNew::clone() const { upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, upper_neighbor_index_, keys_map_lock_, keys_map_, use_key_info_map_, std::move(node_chunks), std::move(upper_neighbor_chunks), broker_, - vector_value_); + vector_value_ptr_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswStreamerEntityNew new failed"); } diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h index 63c32629..018d7c3e 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h @@ -456,7 +456,7 @@ class HnswStreamerEntityNew { std::vector &&node_chunks, std::vector &&upper_neighbor_chunks, const ChunkBroker::Pointer &broker, - std::string vector_value) + std::shared_ptr vector_value_ptr) : stats_(stats), chunk_size_(chunk_size), node_index_mask_bits_(node_index_mask_bits), @@ -473,7 +473,7 @@ class HnswStreamerEntityNew { node_chunks_(std::move(node_chunks)), upper_neighbor_chunks_(std::move(upper_neighbor_chunks)), broker_(broker), - vector_value_(vector_value) { + vector_value_ptr_(vector_value_ptr) { *mutable_header() = hd; neighbor_size_ = neighbors_size(); @@ -743,7 +743,7 @@ class HnswStreamerEntityNew { ChunkBroker::Pointer broker_{}; // chunk broker - std::string vector_value_{}; + std::shared_ptr vector_value_ptr_{}; }; } // namespace core From 4e87aa88ddf67bc32f5e9bb09dd8d2da2ab89bff Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Thu, 12 Mar 2026 16:25:32 +0800 Subject: [PATCH 10/34] upd get neighbor --- src/core/algorithm/hnsw/hnsw_algorithm.cc | 8 +++---- src/core/algorithm/hnsw/hnsw_streamer.cc | 2 +- .../hnsw/hnsw_streamer_entity_new.cc | 21 ++++++++++++++++++- .../algorithm/hnsw/hnsw_streamer_entity_new.h | 8 +++++-- 4 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_algorithm.cc b/src/core/algorithm/hnsw/hnsw_algorithm.cc index 83d286da..50cd9b64 100644 --- a/src/core/algorithm/hnsw/hnsw_algorithm.cc +++ b/src/core/algorithm/hnsw/hnsw_algorithm.cc @@ -113,7 +113,7 @@ void HnswAlgorithm::select_entry_point(level_t level, node_id_t *entry_point, auto &entity = ctx->get_entity(); HnswDistCalculator &dc = ctx->dist_calculator(); while (true) { - const Neighbors neighbors = entity.get_neighbors(level, *entry_point); + const Neighbors neighbors = entity.get_neighbors_new(level, *entry_point); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } @@ -208,7 +208,7 @@ void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, } candidates.pop(); - const Neighbors neighbors = entity.get_neighbors(level, main_node); + const Neighbors neighbors = entity.get_neighbors_new(level, main_node); ailego_prefetch(neighbors.data); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; @@ -333,7 +333,7 @@ void HnswAlgorithm::expand_neighbors_by_group(TopkHeap &topk, node_id_t main_node = top->first; candidates.pop(); - const Neighbors neighbors = entity.get_neighbors(0, main_node); + const Neighbors neighbors = entity.get_neighbors_new(0, main_node); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } @@ -464,7 +464,7 @@ void HnswAlgorithm::reverse_update_neighbors(HnswDistCalculator &dc, uint32_t lock_idx = id & kLockMask; lock_pool_[lock_idx].lock(); - const Neighbors neighbors = entity_.get_neighbors(level, id); + const Neighbors neighbors = entity_.get_neighbors_new(level, id); size_t size = neighbors.size(); ailego_assert_with(size <= max_neighbor_cnt, "invalid neighbor size"); if (size < max_neighbor_cnt) { diff --git a/src/core/algorithm/hnsw/hnsw_streamer.cc b/src/core/algorithm/hnsw/hnsw_streamer.cc index 057a804b..e08953aa 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer.cc @@ -642,7 +642,7 @@ void HnswStreamer::print_debug_info() { if (entity_.get_key(id) == kInvalidKey) { continue; } - Neighbors neighbours = entity_.get_neighbors(0, id); + Neighbors neighbours = entity_.get_neighbors_new(0, id); std::cout << "node: " << id << "; "; if (neighbours.size() == 0) std::cout << std::endl; for (uint32_t i = 0; i < neighbours.size(); ++i) { diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc index d1d5dc60..f97d80b5 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc @@ -195,6 +195,17 @@ const Neighbors HnswStreamerEntityNew::get_neighbors(level_t level, return Neighbors(std::move(neighbor_block)); } +const Neighbors HnswStreamerEntityNew::get_neighbors_new(level_t level, + node_id_t id) const { + if (id) { + return get_neighbors(level, id); + } else { + const void *src = neighbors_value_ptr_->data() + id * neighbor_size_; + const NeighborsHeader *header = reinterpret_cast(src); + return Neighbors(header->neighbor_cnt, header->neighbors); + } +} + //! Get vector data by key const void *HnswStreamerEntityNew::get_vector(node_id_t id) const { auto loc = get_vector_chunk_loc(id); @@ -473,6 +484,14 @@ int HnswStreamerEntityNew::open(IndexStorage::Pointer stg, for (int i = 0; i < doc_cnt(); i++) { vector_value_ptr_->append((const char *)get_vector(i), vector_size()); } + + neighbors_value_ptr_ = std::make_shared(); + neighbors_value_ptr_->reserve(neighbor_size_ * doc_cnt()); + for (int i = 0; i < doc_cnt(); i++) { + Neighbors neighbor = get_neighbors(0, i); + neighbors_value_ptr_->append((const char *)neighbor.neighbor_block.data(), neighbor_size_); + } + stats_.set_loaded_count(doc_cnt()); return 0; @@ -786,7 +805,7 @@ const HnswStreamerEntityNew::Pointer HnswStreamerEntityNew::clone() const { upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, upper_neighbor_index_, keys_map_lock_, keys_map_, use_key_info_map_, std::move(node_chunks), std::move(upper_neighbor_chunks), broker_, - vector_value_ptr_); + vector_value_ptr_, neighbors_value_ptr_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswStreamerEntityNew new failed"); } diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h index 018d7c3e..11707a02 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h @@ -67,6 +67,7 @@ class HnswStreamerEntityNew { //! Note: the neighbors cannot be modified, using the following //! method to get WritableNeighbors if want to const Neighbors get_neighbors(level_t level, node_id_t id) const; + const Neighbors get_neighbors_new(level_t level, node_id_t id) const; //! Add vector and key to hnsw entity, and local id will be saved in id int add_vector(level_t level, key_t key, const void *vec, node_id_t *id); @@ -456,7 +457,8 @@ class HnswStreamerEntityNew { std::vector &&node_chunks, std::vector &&upper_neighbor_chunks, const ChunkBroker::Pointer &broker, - std::shared_ptr vector_value_ptr) + std::shared_ptr vector_value_ptr, + std::shared_ptr neighbors_value_ptr) : stats_(stats), chunk_size_(chunk_size), node_index_mask_bits_(node_index_mask_bits), @@ -473,7 +475,8 @@ class HnswStreamerEntityNew { node_chunks_(std::move(node_chunks)), upper_neighbor_chunks_(std::move(upper_neighbor_chunks)), broker_(broker), - vector_value_ptr_(vector_value_ptr) { + vector_value_ptr_(vector_value_ptr), + neighbors_value_ptr_(neighbors_value_ptr) { *mutable_header() = hd; neighbor_size_ = neighbors_size(); @@ -744,6 +747,7 @@ class HnswStreamerEntityNew { ChunkBroker::Pointer broker_{}; // chunk broker std::shared_ptr vector_value_ptr_{}; + std::shared_ptr neighbors_value_ptr_{}; }; } // namespace core From ce6a033c092228db723999e72e38893f0e3d67eb Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Fri, 13 Mar 2026 11:01:44 +0800 Subject: [PATCH 11/34] upd --- src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc index f97d80b5..e1875fff 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc @@ -279,10 +279,10 @@ int HnswStreamerEntityNew::get_vector_new( int HnswStreamerEntityNew::get_vector_new( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { - vec_blocks.resize(count); + vec_blocks.reserve(count); for (int i = 0; i < count; i++) { const void *data = vector_value_ptr_->data() + vector_size() * ids[i]; - vec_blocks[i].reset((void *)data); + vec_blocks.emplace_back((void *)data); } return 0; // return get_vector(ids, count, vec_blocks); From 075a19caddd9c0679c92f95994eebe3a8e9f6f77 Mon Sep 17 00:00:00 2001 From: ZeFeng Yin Date: Tue, 10 Mar 2026 16:09:37 +0800 Subject: [PATCH 12/34] fix: add validator for cosine metric with int8 (#209) --- src/db/index/common/schema.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/db/index/common/schema.cc b/src/db/index/common/schema.cc index 02789c61..1f743539 100644 --- a/src/db/index/common/schema.cc +++ b/src/db/index/common/schema.cc @@ -160,6 +160,16 @@ Status FieldSchema::validate() const { "types according to the IP metric"); } } + if (vector_index_params->metric_type() == MetricType::COSINE) { + if (data_type_ != DataType::VECTOR_FP16 && + data_type_ != DataType::VECTOR_FP32) { + return Status::InvalidArgument( + "schema validate failed: cosine metric only supports FP32/FP16 " + "data types, but field[", + name_, "]'s data type is ", + DataTypeCodeBook::AsString(data_type_)); + } + } } } else { if (index_params_) { From c9232115ab0f12d3be1ee97b6f8d651a81af57be Mon Sep 17 00:00:00 2001 From: Qinren Zhou Date: Wed, 11 Mar 2026 15:55:13 +0800 Subject: [PATCH 13/34] fix: clean up crash residue (#208) --- src/db/index/common/doc.cc | 63 +++- src/db/index/segment/segment.cc | 28 +- src/include/zvec/db/doc.h | 2 + tests/db/CMakeLists.txt | 1 + tests/db/crash_recovery/CMakeLists.txt | 67 ++++ tests/db/crash_recovery/data_generator.cc | 217 +++++++++++ tests/db/crash_recovery/utility.h | 152 ++++++++ .../db/crash_recovery/write_recovery_test.cc | 357 ++++++++++++++++++ 8 files changed, 865 insertions(+), 22 deletions(-) create mode 100644 tests/db/crash_recovery/CMakeLists.txt create mode 100644 tests/db/crash_recovery/data_generator.cc create mode 100644 tests/db/crash_recovery/utility.h create mode 100644 tests/db/crash_recovery/write_recovery_test.cc diff --git a/src/db/index/common/doc.cc b/src/db/index/common/doc.cc index dad9bbdd..6d411bfb 100644 --- a/src/db/index/common/doc.cc +++ b/src/db/index/common/doc.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -1108,6 +1109,52 @@ std::string Doc::to_detail_string() const { return oss.str(); } +struct Doc::ValueEqual { + template + bool operator()(const T &, const U &) const { + return false; + } + + template + bool operator()(const T &a, const T &b) const { + return a == b; + } + + bool operator()(float a, float b) const { + return std::fabs(a - b) < 1e-6f; + } + + bool operator()(double a, double b) const { + return std::fabs(a - b) < 1e-9; + } + + bool operator()(const std::vector &a, + const std::vector &b) const { + if (a.size() != b.size()) return false; + for (size_t i = 0; i < a.size(); ++i) + if (std::fabs(static_cast(a[i]) - static_cast(b[i])) >= + 1e-3f) + return false; + return true; + } + + bool operator()(const std::vector &a, + const std::vector &b) const { + if (a.size() != b.size()) return false; + for (size_t i = 0; i < a.size(); ++i) + if (std::fabs(a[i] - b[i]) >= 1e-6f) return false; + return true; + } + + bool operator()(const std::vector &a, + const std::vector &b) const { + if (a.size() != b.size()) return false; + for (size_t i = 0; i < a.size(); ++i) + if (std::fabs(a[i] - b[i]) >= 1e-9) return false; + return true; + } +}; + bool Doc::operator==(const Doc &other) const { // Compare basic fields if (pk_ != other.pk_) { @@ -1135,21 +1182,7 @@ bool Doc::operator==(const Doc &other) const { } // Use visitor to compare the actual values - bool values_equal = std::visit( - [](const auto &lhs, const auto &rhs) -> bool { - if constexpr (std::is_same_v, - std::decay_t>) { - return lhs == rhs; - } else { - // This should not happen due to the index check above - return false; - } - }, - field_value, it->second); - - if (!values_equal) { - return false; - } + if (!std::visit(ValueEqual{}, field_value, it->second)) return false; } return true; diff --git a/src/db/index/segment/segment.cc b/src/db/index/segment/segment.cc index 517215a3..2d03cd78 100644 --- a/src/db/index/segment/segment.cc +++ b/src/db/index/segment/segment.cc @@ -3939,6 +3939,14 @@ VectorColumnIndexer::Ptr SegmentImpl::create_vector_indexer( memory_vector_block_ids_[field_name] = block_id; } + if (FileHelper::FileExists(index_file_path)) { + LOG_WARN( + "Index file[%s] already exists (possible crash residue); cleaning and " + "overwriting.", + index_file_path.c_str()); + FileHelper::RemoveFile(index_file_path); + } + auto vector_indexer = std::make_shared(index_file_path, field); vector_column_params::ReadOptions options{true, true}; @@ -3958,6 +3966,13 @@ Status SegmentImpl::init_memory_components() { // create and open memory forward block auto mem_path = FileHelper::MakeForwardBlockPath(seg_path_, mem_block.id_, !options_.enable_mmap_); + if (FileHelper::FileExists(mem_path)) { + LOG_WARN( + "ForwardBlock file[%s] already exists (possible crash residue); " + "cleaning and overwriting.", + mem_path.c_str()); + FileHelper::RemoveFile(mem_path); + } memory_store_ = std::make_shared( collection_schema_, mem_path, options_.enable_mmap_ ? FileFormat::IPC : FileFormat::PARQUET, @@ -4104,18 +4119,17 @@ Status SegmentImpl::recover() { } const auto added_docs = recovered_doc_count[0] + // INSERT - recovered_doc_count[1] + // UPDATE - recovered_doc_count[2]; // UPSERT + recovered_doc_count[1] + // UPSERT + recovered_doc_count[2]; // UPDATE mem_block.max_doc_id_ += added_docs; LOG_INFO( - "Recover from wal finished. total_recovered_doc_count[%zu] " - "insert[%zu] update[%zu] upsert[%zu] " - "delete[%zu] path[%s]", + "Recover from wal finished. total_recovered_doc_count[%zu] insert[%zu] " + "upsert[%zu] update[%zu] delete[%zu] path[%s]", (size_t)total_recovered_doc_count, (size_t)recovered_doc_count[0], // INSERT - (size_t)recovered_doc_count[1], // UPDATE - (size_t)recovered_doc_count[2], // UPSERT + (size_t)recovered_doc_count[1], // UPSERT + (size_t)recovered_doc_count[2], // UPDATE (size_t)recovered_doc_count[3], // DELETE wal_file_path.c_str()); diff --git a/src/include/zvec/db/doc.h b/src/include/zvec/db/doc.h index 5f927fa1..fa056053 100644 --- a/src/include/zvec/db/doc.h +++ b/src/include/zvec/db/doc.h @@ -294,6 +294,8 @@ class Doc { static void read_from_buffer(const uint8_t *&data, void *dest, size_t size); + struct ValueEqual; + private: std::string pk_; float score_{0.0f}; diff --git a/tests/db/CMakeLists.txt b/tests/db/CMakeLists.txt index 8de3089a..612ee150 100644 --- a/tests/db/CMakeLists.txt +++ b/tests/db/CMakeLists.txt @@ -2,6 +2,7 @@ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_directory(common) +cc_directories(crash_recovery) cc_directory(sqlengine) cc_directories(index) diff --git a/tests/db/crash_recovery/CMakeLists.txt b/tests/db/crash_recovery/CMakeLists.txt new file mode 100644 index 00000000..296b8c59 --- /dev/null +++ b/tests/db/crash_recovery/CMakeLists.txt @@ -0,0 +1,67 @@ +include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) +include(${PROJECT_ROOT_DIR}/cmake/option.cmake) + +if(APPLE) + set(APPLE_FRAMEWORK_LIBS + -framework CoreFoundation + -framework CoreGraphics + -framework CoreData + -framework CoreText + -framework Security + -framework Foundation + -Wl,-U,_MallocExtension_ReleaseFreeMemory + -Wl,-U,_ProfilerStart + -Wl,-U,_ProfilerStop + -Wl,-U,_RegisterThriftProtocol + ) +endif() + + +# Build data_generator executable +cc_binary( + NAME data_generator + LIBS zvec_db + zvec_proto + core_knn_flat + core_knn_flat_sparse + core_knn_hnsw + core_knn_hnsw_sparse + core_knn_ivf + core_mix_reducer + core_metric + core_utility + core_quantizer + ${CMAKE_THREAD_LIBS_INIT} + ${CMAKE_DL_LIBS} + SRCS data_generator.cc + INCS .. ../../src + LDFLAGS ${APPLE_FRAMEWORK_LIBS} +) + + +# Build test executables +file(GLOB ALL_TEST_SRCS *_test.cc) +foreach(CC_SRCS ${ALL_TEST_SRCS}) + get_filename_component(CC_TARGET ${CC_SRCS} NAME_WE) + cc_gmock( + NAME ${CC_TARGET} STRICT + LIBS zvec_db + zvec_proto + core_knn_flat + core_knn_flat_sparse + core_knn_hnsw + core_knn_hnsw_sparse + core_knn_ivf + core_mix_reducer + core_metric + core_utility + core_quantizer + ${CMAKE_THREAD_LIBS_INIT} + ${CMAKE_DL_LIBS} + SRCS ${CC_SRCS} + INCS .. ../../src + LDFLAGS ${APPLE_FRAMEWORK_LIBS} + ) + add_dependencies(${CC_TARGET} data_generator) + cc_test_suite(zvec_crash_recovery ${CC_TARGET}) +endforeach() diff --git a/tests/db/crash_recovery/data_generator.cc b/tests/db/crash_recovery/data_generator.cc new file mode 100644 index 00000000..57542471 --- /dev/null +++ b/tests/db/crash_recovery/data_generator.cc @@ -0,0 +1,217 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include "zvec/ailego/logger/logger.h" +#include "utility.h" + + +constexpr int kBatchSize = 20; +constexpr int kBatchDelayMs = 10; + + +struct Config { + std::string path; + int start_id = 0; + int end_id = 0; + std::string operation; // "insert", "upsert", "update", "delete" + int version = 999999; +}; + + +bool ParseArgs(int argc, char **argv, Config &config) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "--path" && i + 1 < argc) { + config.path = argv[++i]; + } else if (arg == "--start" && i + 1 < argc) { + config.start_id = std::stoi(argv[++i]); + } else if (arg == "--end" && i + 1 < argc) { + config.end_id = std::stoi(argv[++i]); + } else if (arg == "--op" && i + 1 < argc) { + config.operation = argv[++i]; + } else if (arg == "--version" && i + 1 < argc) { + config.version = std::stoi(argv[++i]); + } else if (arg == "--help" || arg == "-h") { + return false; + } + } + + // Validate required arguments + if (config.path.empty() || config.operation.empty() || + config.start_id >= config.end_id || config.version == 999999) { + return false; + } + + // Validate operation + if (config.operation != "insert" && config.operation != "upsert" && + config.operation != "update" && config.operation != "delete") { + std::cerr << "Error: Invalid operation '" << config.operation + << "'. Must be 'insert', 'upsert', 'update', or 'delete'." + << std::endl; + return false; + } + + return true; +} + + +void PrintUsage(const char *program) { + std::cout << "Usage: " << program + << " --path --start --end " + "--op " + << std::endl; + std::cout << std::endl; + std::cout << "Arguments:" << std::endl; + std::cout << " --path Path to the collection (required)" << std::endl; + std::cout << " --start Starting document ID (inclusive, required)" + << std::endl; + std::cout << " --end Ending document ID (exclusive, required)" + << std::endl; + std::cout + << " --op Operation: insert, upsert, update, or delete (required)" + << std::endl; + std::cout << " --version Operation: version (required)" << std::endl; + std::cout << std::endl; + std::cout << "Examples:" << std::endl; + std::cout << " # Insert 1000 documents (pk_0 to pk_999)" << std::endl; + std::cout << " " << program + << " --path ./test_db --start 0 --end 1000 --op insert --version 0" + << std::endl; + std::cout << std::endl; + std::cout << " # Update documents 1000-1999" << std::endl; + std::cout + << " " << program + << " --path ./test_db --start 1000 --end 2000 --op update --version 1" + << std::endl; + std::cout << std::endl; + std::cout << " # Upsert documents 0-499" << std::endl; + std::cout << " " << program + << " --path ./test_db --start 0 --end 500 --op upsert --version 2" + << std::endl; +} + + +int main(int argc, char **argv) { + Config config; + + // Parse arguments + if (!ParseArgs(argc, argv, config)) { + PrintUsage(argv[0]); + return 1; + } + + try { + std::filesystem::path cwd = std::filesystem::current_path(); + std::cout << "[data_generator] Current Working Directory: " << cwd.string() + << std::endl; + } catch (const std::filesystem::filesystem_error &e) { + std::cout + << "[data_generator] Failed to get the current working directory: " + << e.what() << std::endl; + } + + std::cout << "Configuration:" << std::endl; + std::cout << " Path: " << config.path << std::endl; + std::cout << " Range: [" << config.start_id << ", " << config.end_id + << ")" << std::endl; + std::cout << " Operation: " << config.operation << std::endl; + std::cout << " BatchSize: " << kBatchSize << std::endl; + std::cout << " BatchDelay: " << kBatchDelayMs << "ms" << std::endl; + std::cout << std::endl; + + auto result = + zvec::Collection::Open(config.path, zvec::CollectionOptions{false, true}); + if (!result) { + LOG_ERROR("Failed to open collection[%s]: %s", config.path.c_str(), + result.error().c_str()); + return -1; + } + + auto collection = result.value(); + LOG_INFO("Collection[%s] opened successfully", config.path.c_str()); + + // Process documents in batches + int total_docs = config.end_id - config.start_id; + int processed = 0; + int batch_num = 0; + int next_progress_threshold = total_docs / 10; // 10% increments + int progress_percent = 0; + + while (config.start_id < config.end_id) { + int batch_end = std::min(config.start_id + kBatchSize, config.end_id); + int batch_count = batch_end - config.start_id; + + std::vector docs; + docs.reserve(batch_count); + for (uint64_t i = config.start_id; i < batch_end; i++) { + docs.push_back(zvec::CreateTestDoc(i, config.version)); + } + + zvec::Result results; + if (config.operation == "insert") { + results = collection->Insert(docs); + } else if (config.operation == "upsert") { + results = collection->Upsert(docs); + } else if (config.operation == "update") { + results = collection->Update(docs); + } else if (config.operation == "delete") { + std::vector pks{}; + for (const auto &doc : docs) { + pks.emplace_back(doc.pk()); + } + results = collection->Delete(pks); + } + if (!results) { + LOG_ERROR("Failed to perform operation[%s], reason: %s", + config.operation.c_str(), results.error().message().c_str()); + return 1; + } + for (auto &s : results.value()) { + if (!s.ok()) { + LOG_ERROR("Failed to perform operation[%s], reason: %s", + config.operation.c_str(), s.message().c_str()); + return 1; + } + } + + processed += batch_count; + config.start_id = batch_end; + batch_num++; + + // Print progress every 10% + if (processed >= next_progress_threshold) { + progress_percent++; + LOG_INFO("Progress: %d (%d/%d documents)", progress_percent * 10, + processed, total_docs); + next_progress_threshold = (progress_percent + 1) * total_docs / 10; + } + + // Sleep between batches + if (config.start_id < config.end_id) { + std::this_thread::sleep_for(std::chrono::milliseconds(kBatchDelayMs)); + } + } + + std::cout << std::endl; + std::cout << "Success! Processed " << processed << " documents in " + << batch_num << " batches." << std::endl; + + return 0; +} diff --git a/tests/db/crash_recovery/utility.h b/tests/db/crash_recovery/utility.h new file mode 100644 index 00000000..36768b24 --- /dev/null +++ b/tests/db/crash_recovery/utility.h @@ -0,0 +1,152 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#pragma once + + +#include +#include + + +namespace zvec { + +/** + * @brief Create a test schema with deterministic field definitions. + * + * @param name The collection name (default: "crash_recovery_test") + * @return CollectionSchema::Ptr The test schema + */ +inline CollectionSchema::Ptr CreateTestSchema( + const std::string &name = "crash_recovery_test") { + auto schema = std::make_shared(name); + schema->set_max_doc_count_per_segment(2000); + + schema->add_field( + std::make_shared("int32_field", DataType::INT32, false)); + schema->add_field( + std::make_shared("int64_field", DataType::INT64, true)); + schema->add_field( + std::make_shared("float_field", DataType::FLOAT, true)); + schema->add_field( + std::make_shared("string_field", DataType::STRING, false)); + schema->add_field( + std::make_shared("bool_field", DataType::BOOL, false)); + schema->add_field(std::make_shared("array_int32_field", + DataType::ARRAY_INT32, true)); + schema->add_field(std::make_shared( + "array_string_field", DataType::ARRAY_STRING, false)); + schema->add_field(std::make_shared( + "dense_fp32_field", DataType::VECTOR_FP32, 128, false, + std::make_shared(MetricType::COSINE))); + schema->add_field(std::make_shared( + "sparse_fp32_field", DataType::SPARSE_VECTOR_FP32, 0, false, + std::make_shared(MetricType::IP))); + + return schema; +} + + +/** + * @brief Create a test document with deterministic values based on doc_id. + * + * Document pattern: + * - pk: "pk_{doc_id}" + * - int32_field: doc_id (cast to int32) + * - int64_field: doc_id, null if doc_id % 60 == 0 + * - float_field: doc_id / 1000.0, null if doc_id % 70 == 0 + * - string_field: "{version}_{doc_id}" + * - bool_field: doc_id % 2 == 0 or flipped if version % 2 !=0 + * - array_int32_field: [doc_id, doc_id+1, doc_id+2], null if doc_id % 100 == 0 + * - array_string_field: ["str_{version}_0", ...] + * - dense_fp32_field: vector where dense[i] = (doc_id + i) / 1000.0f + * - sparse_fp32_field: sparse vector with indices [0, 10, ...] + * + * @param doc_id The document ID (determines all field values) + * @param version The version of the document + * @return Doc The created document + */ +inline Doc CreateTestDoc(uint64_t doc_id, int version) { + Doc doc; + + // Set primary key + std::string pk = "pk_" + std::to_string(doc_id); + doc.set_pk(pk); + + // Set scalar fields + doc.set("int32_field", static_cast(doc_id)); + + // int64_field: nullable, null if doc_id % 60 == 0 + if (doc_id % 60 != 0) { + doc.set("int64_field", static_cast(doc_id)); + } + + // float_field: nullable, null if doc_id % 70 == 0 + if (doc_id % 70 != 0) { + doc.set("float_field", static_cast(doc_id) / 1000.0f); + } + + // string_field: "value_{id}" or "updated_value_{id}" + std::string string_value = + std::to_string(version) + "_" + std::to_string(doc_id); + doc.set("string_field", string_value); + + // bool_field: alternating based on doc_id, flipped if updated + bool bool_value = (doc_id % 2 == 0); + if (version % 2 != 0) { + bool_value = !bool_value; + } + doc.set("bool_field", bool_value); + + // array_int32_field: nullable, null if doc_id % 100 == 0 + if (doc_id % 100 != 0) { + std::vector array_int32; + for (int i = 0; i < 3; i++) { + array_int32.push_back(static_cast(doc_id + i)); + } + doc.set>("array_int32_field", array_int32); + } + + // array_string_field: ["str_0", "str_1", ...] or ["updated_str_0", ...] + std::vector array_string; + size_t array_size = doc_id % 5 + 1; // 1 to 5 elements + for (size_t i = 0; i < array_size; i++) { + array_string.push_back("str_" + std::to_string(version) + "_" + + std::to_string(i)); + } + doc.set>("array_string_field", array_string); + + // dense_fp32_field: deterministic pattern + std::vector dense(128); + for (int i = 0; i < 128; i++) { + dense[i] = static_cast(doc_id + i) / 1000.0f; + } + doc.set>("dense_fp32_field", dense); + + // sparse_fp32_field: sparse vector with indices [0, 10, 20, ..., 100] + // Values based on doc_id: value = (doc_id + index) / 1000.0 + std::vector sparse_indices; + std::vector sparse_values; + for (uint32_t idx = 0; idx <= 100; idx += 10) { + sparse_indices.push_back(idx); + sparse_values.push_back(static_cast(doc_id + idx) / 1000.0f); + } + doc.set, std::vector>>( + "sparse_fp32_field", std::make_pair(sparse_indices, sparse_values)); + + return doc; +} + + +} // namespace zvec diff --git a/tests/db/crash_recovery/write_recovery_test.cc b/tests/db/crash_recovery/write_recovery_test.cc new file mode 100644 index 00000000..1f53a5f4 --- /dev/null +++ b/tests/db/crash_recovery/write_recovery_test.cc @@ -0,0 +1,357 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include +#include +#include +#include "utility.h" + + +namespace zvec { + + +static std::string data_generator_bin_; +const std::string collection_name_{"crash_test"}; +const std::string dir_path_{"crash_test_db"}; +const zvec::CollectionOptions options_{false, true}; + + +static std::string LocateDataGenerator() { + namespace fs = std::filesystem; + const std::vector candidates{"./data_generator", + "./bin/data_generator"}; + for (const auto &p : candidates) { + if (fs::exists(p)) { + return fs::canonical(p).string(); + } + } + throw std::runtime_error("data_generator binary not found"); +} + + +void RunGenerator(const std::string &start, const std::string &end, + const std::string &op, const std::string &version) { + pid_t pid = fork(); + ASSERT_GE(pid, 0); + + if (pid == 0) { // Child process + char arg_path[] = "--path"; + char arg_start[] = "--start"; + char arg_end[] = "--end"; + char arg_op[] = "--op"; + char arg_version[] = "--version"; + char *args[] = {const_cast(data_generator_bin_.c_str()), + arg_path, + const_cast(dir_path_.c_str()), + arg_start, + const_cast(start.c_str()), + arg_end, + const_cast(end.c_str()), + arg_op, + const_cast(op.c_str()), + arg_version, + const_cast(version.c_str()), + nullptr}; + execvp(args[0], args); + perror("execvp failed"); + _exit(1); + } + + int status; + waitpid(pid, &status, 0); + ASSERT_TRUE(WIFEXITED(status)) + << "Child process did not exit normally. Terminated by signal?"; + int exit_code = WEXITSTATUS(status); + ASSERT_EQ(exit_code, 0) << "data_generator failed with exit code: " + << exit_code; +} + + +void RunGeneratorAndCrash(const std::string &start, const std::string &end, + const std::string &op, const std::string &version, + int seconds) { + pid_t pid = fork(); + ASSERT_GE(pid, 0); + + if (pid == 0) { // Child process + char arg_path[] = "--path"; + char arg_start[] = "--start"; + char arg_end[] = "--end"; + char arg_op[] = "--op"; + char arg_version[] = "--version"; + char *args[] = {const_cast(data_generator_bin_.c_str()), + arg_path, + const_cast(dir_path_.c_str()), + arg_start, + const_cast(start.c_str()), + arg_end, + const_cast(end.c_str()), + arg_op, + const_cast(op.c_str()), + arg_version, + const_cast(version.c_str()), + nullptr}; + execvp(args[0], args); + perror("execvp failed"); + _exit(1); + } + + std::this_thread::sleep_for(std::chrono::seconds(seconds)); + if (kill(pid, 0) == 0) { + kill(pid, SIGKILL); + } + int status; + waitpid(pid, &status, 0); + ASSERT_TRUE(WIFSIGNALED(status)) + << "Child process was not killed by a signal. It exited normally?"; +} + + +class CrashRecoveryTest : public ::testing::Test { + protected: + void SetUp() override { + system("rm -rf ./crash_test_db"); + ASSERT_NO_THROW(data_generator_bin_ = LocateDataGenerator()); + } + + void TearDown() override { + system("rm -rf ./crash_test_db"); + } +}; + + +TEST_F(CrashRecoveryTest, BasicInsertAndReopen) { + { + auto schema = CreateTestSchema(collection_name_); + auto result = Collection::CreateAndOpen(dir_path_, *schema, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + collection.reset(); + } + + RunGenerator("0", "5000", "insert", "0"); + auto result = Collection::Open(dir_path_, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + ASSERT_EQ(collection->Stats().value().doc_count, 5000) + << "Document count mismatch"; +} + + +TEST_F(CrashRecoveryTest, CrashRecoveryDuringInsertion) { + { + auto schema = CreateTestSchema(collection_name_); + auto result = Collection::CreateAndOpen(dir_path_, *schema, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + collection.reset(); + } + + RunGeneratorAndCrash("0", "10000", "insert", "0", 3); + + auto result = Collection::Open(dir_path_, options_); + ASSERT_TRUE(result.has_value()) << "Failed to reopen collection after crash. " + "Recovery mechanism may be broken."; + auto collection = result.value(); + uint64_t doc_count{collection->Stats().value().doc_count}; + ASSERT_GT(doc_count, 800) + << "Document count is too low after 3s of insertion and recovery"; + + for (uint64_t doc_id = 0; doc_id < doc_count; doc_id++) { + const auto expected_doc = CreateTestDoc(doc_id, 0); + std::vector pks{}; + pks.emplace_back(expected_doc.pk()); + if (auto res = collection->Fetch(pks); res) { + auto map = res.value(); + if (map.find(expected_doc.pk()) == map.end()) { + FAIL() << "Returned map does not contain doc[" << expected_doc.pk() + << "]"; + } + const auto actual_doc = map.at(expected_doc.pk()); + ASSERT_EQ(*actual_doc, expected_doc) + << "Data mismatch for doc[" << expected_doc.pk() << "]"; + } else { + FAIL() << "Failed to fetch doc[" << expected_doc.pk() << "]"; + } + } +} + + +TEST_F(CrashRecoveryTest, CrashRecoveryDuringUpsert) { + { + auto schema = CreateTestSchema(collection_name_); + auto result = Collection::CreateAndOpen(dir_path_, *schema, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + collection.reset(); + } + + RunGenerator("0", "5000", "insert", "0"); + { + auto result = Collection::Open(dir_path_, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + ASSERT_EQ(collection->Stats().value().doc_count, 5000) + << "Document count mismatch"; + } + + RunGeneratorAndCrash("4500", "20000", "upsert", "1", 5); + + auto result = Collection::Open(dir_path_, options_); + ASSERT_TRUE(result.has_value()) << "Failed to reopen collection after crash. " + "Recovery mechanism may be broken."; + auto collection = result.value(); + uint64_t doc_count{collection->Stats().value().doc_count}; + ASSERT_GT(doc_count, 6000) + << "Document count is too low after 5s of insertion and recovery"; + + for (uint64_t doc_id = 0; doc_id < doc_count; doc_id++) { + Doc expected_doc; + if (doc_id < 4500) { + expected_doc = CreateTestDoc(doc_id, 0); + } else { + expected_doc = CreateTestDoc(doc_id, 1); + } + std::vector pks{}; + pks.emplace_back(expected_doc.pk()); + if (auto res = collection->Fetch(pks); res) { + auto map = res.value(); + if (map.find(expected_doc.pk()) == map.end()) { + FAIL() << "Returned map does not contain doc[" << expected_doc.pk() + << "]"; + } + const auto actual_doc = map.at(expected_doc.pk()); + ASSERT_EQ(*actual_doc, expected_doc) + << "Data mismatch for doc[" << expected_doc.pk() << "]"; + } else { + FAIL() << "Failed to fetch doc[" << expected_doc.pk() << "]"; + } + } +} + + +TEST_F(CrashRecoveryTest, CrashRecoveryDuringUpdate) { + { + auto schema = CreateTestSchema(collection_name_); + auto result = Collection::CreateAndOpen(dir_path_, *schema, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + collection.reset(); + } + + RunGenerator("0", "18000", "upsert", "0"); + { + auto result = Collection::Open(dir_path_, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + ASSERT_EQ(collection->Stats().value().doc_count, 18000) + << "Document count mismatch"; + } + + RunGeneratorAndCrash("3000", "15000", "update", "3", 4); + + auto result = Collection::Open(dir_path_, options_); + ASSERT_TRUE(result.has_value()) << "Failed to reopen collection after crash. " + "Recovery mechanism may be broken."; + auto collection = result.value(); + uint64_t doc_count{collection->Stats().value().doc_count}; + ASSERT_EQ(doc_count, 18000) << "Document count mismatch after crash recovery"; + + for (int doc_id = 0; doc_id < 3500; doc_id++) { + Doc expected_doc; + if (doc_id < 3000) { + expected_doc = CreateTestDoc(doc_id, 0); + } else { + expected_doc = CreateTestDoc(doc_id, 3); + } + std::vector pks{}; + pks.emplace_back(expected_doc.pk()); + if (auto res = collection->Fetch(pks); res) { + auto map = res.value(); + if (map.find(expected_doc.pk()) == map.end()) { + FAIL() << "Returned map does not contain doc[" << expected_doc.pk() + << "]"; + } + const auto actual_doc = map.at(expected_doc.pk()); + ASSERT_EQ(*actual_doc, expected_doc) + << "Data mismatch for doc[" << expected_doc.pk() << "]"; + } else { + FAIL() << "Failed to fetch doc[" << expected_doc.pk() << "]"; + } + } +} + + +TEST_F(CrashRecoveryTest, CrashRecoveryDuringDelete) { + { + auto schema = CreateTestSchema(collection_name_); + auto result = Collection::CreateAndOpen(dir_path_, *schema, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + collection.reset(); + } + + RunGenerator("0", "18000", "insert", "0"); + { + auto result = Collection::Open(dir_path_, options_); + ASSERT_TRUE(result.has_value()); + auto collection = result.value(); + ASSERT_EQ(collection->Stats().value().doc_count, 18000) + << "Document count mismatch"; + } + + RunGeneratorAndCrash("3000", "15000", "delete", "0", 4); + + auto result = Collection::Open(dir_path_, options_); + ASSERT_TRUE(result.has_value()) << "Failed to reopen collection after crash. " + "Recovery mechanism may be broken."; + auto collection = result.value(); + uint64_t doc_count{collection->Stats().value().doc_count}; + ASSERT_LT(doc_count, 18000) + << "No deletes appear to have been applied before the crash"; + ASSERT_GT(doc_count, 6000) + << "Too many documents deleted, recovery likely lost data"; + + for (int doc_id = 0; doc_id < 3500; doc_id++) { + auto expected_doc = CreateTestDoc(doc_id, 0); + std::vector pks{}; + pks.emplace_back(expected_doc.pk()); + if (auto res = collection->Fetch(pks); res) { + auto map = res.value(); + auto it = map.find(expected_doc.pk()); + ASSERT_NE(it, map.end()) + << "Fetch result missing requested pk[" << expected_doc.pk() << "]"; + if (doc_id < 3000) { + ASSERT_NE(it->second, nullptr) + << "Existing doc returned as nullptr [" << expected_doc.pk() << "]"; + const auto actual_doc = map.at(expected_doc.pk()); + ASSERT_EQ(*actual_doc, expected_doc) + << "Data mismatch for doc[" << expected_doc.pk() << "]"; + } else { + ASSERT_EQ(it->second, nullptr) + << "Returned doc for deleted pk[" << expected_doc.pk() << "]"; + } + } else { + FAIL() << "Failed to fetch doc[" << expected_doc.pk() << "]"; + } + } +} + + +} // namespace zvec From bed78600d929f9e3e2c7c6e8795b868363ad1e2e Mon Sep 17 00:00:00 2001 From: egolearner Date: Wed, 11 Mar 2026 16:22:36 +0800 Subject: [PATCH 14/34] build: support config clang stdlib and fix default (#210) --- cmake/bazel.cmake | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/cmake/bazel.cmake b/cmake/bazel.cmake index f1effc6d..2e4f1ccf 100644 --- a/cmake/bazel.cmake +++ b/cmake/bazel.cmake @@ -365,11 +365,24 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) +if(APPLE OR ANDROID) + option(CLANG_USE_LIBCXX "Use libc++ instead of libstdc++" ON) +else() + option(CLANG_USE_LIBCXX "Use libc++ instead of libstdc++" OFF) +endif() + +set(CLANG_STDLIB_OPTION "") +if(CLANG_USE_LIBCXX) + set(CLANG_STDLIB_OPTION "-stdlib=libc++") +else() + set(CLANG_STDLIB_OPTION "-stdlib=libstdc++") +endif() + if(NOT MSVC) # Use color in diagnostics set( _COMPILER_FLAGS - "$<$:-fcolor-diagnostics;-stdlib=libc++>" + "$<$:-fcolor-diagnostics;${CLANG_STDLIB_OPTION}>" "$<$:-fcolor-diagnostics>" "$<$:-fdiagnostics-color=always>" ) @@ -460,7 +473,7 @@ endif() # C/C++ strict link flags set( BAZEL_CC_STRICT_LINK_FLAGS - "$<$:-stdlib=libc++>" + "$<$:${CLANG_STDLIB_OPTION}>" ${BAZEL_CC_ASAN_COMPILE_FLAGS} ${BAZEL_CC_COVERAGE_COMPILE_FLAGS} ) @@ -479,7 +492,7 @@ set( # C/C++ unstrict link flags set( BAZEL_CC_UNSTRICT_LINK_FLAGS - "$<$:-stdlib=libc++>" + "$<$:${CLANG_STDLIB_OPTION}>" ${BAZEL_CC_ASAN_COMPILE_FLAGS} ${BAZEL_CC_COVERAGE_COMPILE_FLAGS} ) @@ -572,7 +585,7 @@ function(_targets_link_dependencies _NAME) if(TARGET ${LIB}) list(APPEND LIBS_DEPS ${LIB}) list( - APPEND LIBS_INCS + APPEND LIBS_INCS "$" ) endif() @@ -590,45 +603,45 @@ function(_target_link_libraries _NAME) if(NOT _COLLECT_ALWAYS_LINK_VISITED) set(_COLLECT_ALWAYS_LINK_VISITED "" PARENT_SCOPE) endif() - + set(LOCAL_RESULT "") foreach(LIB ${LIB_LIST}) if(NOT TARGET ${LIB}) continue() endif() - + list(FIND _COLLECT_ALWAYS_LINK_VISITED ${LIB} ALREADY_VISITED) if(NOT ALREADY_VISITED EQUAL -1) continue() endif() - + list(APPEND _COLLECT_ALWAYS_LINK_VISITED ${LIB}) set(_COLLECT_ALWAYS_LINK_VISITED "${_COLLECT_ALWAYS_LINK_VISITED}" PARENT_SCOPE) - + get_target_property(ALWAYS_LINK ${LIB} ALWAYS_LINK) if(ALWAYS_LINK) list(APPEND LOCAL_RESULT ${LIB}) endif() - + get_target_property(DEP_LIBS ${LIB} INTERFACE_LINK_LIBRARIES) if(DEP_LIBS) _collect_always_link_libs("${DEP_LIBS}" DEP_ALWAYS_LINK_LIBS) list(APPEND LOCAL_RESULT ${DEP_ALWAYS_LINK_LIBS}) endif() - + get_target_property(LINK_LIBS ${LIB} LINK_LIBRARIES) if(LINK_LIBS) _collect_always_link_libs("${LINK_LIBS}" LINK_ALWAYS_LINK_LIBS) list(APPEND LOCAL_RESULT ${LINK_ALWAYS_LINK_LIBS}) endif() endforeach() - + list(REMOVE_DUPLICATES LOCAL_RESULT) set(${RESULT_VAR} "${LOCAL_RESULT}" PARENT_SCOPE) endfunction() - + _collect_always_link_libs("${ARGN}" ALL_ALWAYS_LINK_LIBS) - + set(ALL_LIBS_TO_PROCESS ${ARGN}) foreach(ALWAYS_LIB ${ALL_ALWAYS_LINK_LIBS}) list(FIND ARGN ${ALWAYS_LIB} FOUND_INDEX) @@ -636,9 +649,9 @@ function(_target_link_libraries _NAME) list(APPEND ALL_LIBS_TO_PROCESS ${ALWAYS_LIB}) endif() endforeach() - + list(REMOVE_DUPLICATES ALL_LIBS_TO_PROCESS) - + foreach(LIB ${ALL_LIBS_TO_PROCESS}) if(NOT TARGET ${LIB}) list(APPEND LINK_LIBS ${LIB}) From 9c04298c971fe83c5fae0ea3ef97011fce8679e0 Mon Sep 17 00:00:00 2001 From: rayx Date: Wed, 11 Mar 2026 22:44:48 +0800 Subject: [PATCH 15/34] refactor/march based reorganization (#193) * add for march compatible test * fix: remove MxN * fix: add macro scope * fix for android ci * refactor: call euclidean via squared euclidean * refactor: add match utility * refactor: comment config out * fix: fix dimension remainder * fix: fix condition & config flags * fix: remove redundant macros * fix: fix condition error * fix: fix macros * fix: remove unnecessary macros * fix: fix remaining code with vector * fix: remove unnecessary codes --------- Co-authored-by: Zefeng Yin Co-authored-by: xufeihong.xfh --- .github/workflows/android_build.yml | 11 + cmake/option.cmake | 72 +- pyproject.toml | 1 + src/ailego/CMakeLists.txt | 85 +- .../math/distance_matrix_euclidean_utility.i | 253 +++ .../distance_matrix_inner_product_utility.i | 208 ++ .../math/distance_matrix_mips_utility.i | 160 ++ src/ailego/math/euclidean_distance_matrix.h | 1964 +---------------- .../math/euclidean_distance_matrix_fp16.cc | 615 ------ .../euclidean_distance_matrix_fp16_avx.cc | 38 + .../euclidean_distance_matrix_fp16_avx512.cc | 96 + ...euclidean_distance_matrix_fp16_dispatch.cc | 87 + .../euclidean_distance_matrix_fp16_neon.cc | 35 + .../euclidean_distance_matrix_fp16_sse.cc | 54 + .../math/euclidean_distance_matrix_fp32.cc | 930 -------- .../euclidean_distance_matrix_fp32_avx.cc | 94 + .../euclidean_distance_matrix_fp32_avx512.cc | 81 + ...euclidean_distance_matrix_fp32_dispatch.cc | 92 + .../euclidean_distance_matrix_fp32_neon.cc | 62 + .../euclidean_distance_matrix_fp32_sse.cc | 78 + .../math/euclidean_distance_matrix_int4.cc | 801 ------- .../euclidean_distance_matrix_int4_avx2.cc | 118 + ...euclidean_distance_matrix_int4_dispatch.cc | 60 + .../euclidean_distance_matrix_int4_sse.cc | 98 + .../math/euclidean_distance_matrix_int8.cc | 884 -------- .../euclidean_distance_matrix_int8_avx2.cc | 182 ++ ...euclidean_distance_matrix_int8_dispatch.cc | 59 + .../euclidean_distance_matrix_int8_sse.cc | 164 ++ src/ailego/math/hamming_distance_matrix.cc | 765 ------- src/ailego/math/hamming_distance_matrix.h | 1005 +-------- src/ailego/math/inner_product_matrix.h | 1964 +---------------- src/ailego/math/inner_product_matrix_fp16.cc | 1948 ---------------- .../math/inner_product_matrix_fp16_avx.cc | 706 ++++++ .../math/inner_product_matrix_fp16_avx512.cc | 766 +++++++ .../inner_product_matrix_fp16_dispatch.cc | 162 ++ .../math/inner_product_matrix_fp16_neon.cc | 42 + src/ailego/math/inner_product_matrix_fp32.cc | 1180 ---------- .../math/inner_product_matrix_fp32_avx.cc | 94 + .../math/inner_product_matrix_fp32_avx512.cc | 75 + .../inner_product_matrix_fp32_dispatch.cc | 97 + .../math/inner_product_matrix_fp32_neon.cc | 57 + .../math/inner_product_matrix_fp32_sse.cc | 351 +++ src/ailego/math/inner_product_matrix_int4.cc | 803 ------- .../math/inner_product_matrix_int4_avx2.cc | 123 ++ .../inner_product_matrix_int4_dispatch.cc | 62 + .../math/inner_product_matrix_int4_sse.cc | 101 + src/ailego/math/inner_product_matrix_int8.cc | 841 ------- .../math/inner_product_matrix_int8_avx2.cc | 189 ++ .../inner_product_matrix_int8_dispatch.cc | 60 + .../math/inner_product_matrix_int8_sse.cc | 157 ++ .../mips_euclidean_distance_matrix_fp16.cc | 409 ---- ...mips_euclidean_distance_matrix_fp16_avx.cc | 116 + ...s_euclidean_distance_matrix_fp16_avx512.cc | 134 ++ ...euclidean_distance_matrix_fp16_dispatch.cc | 96 + ...ips_euclidean_distance_matrix_fp16_neon.cc | 126 ++ .../mips_euclidean_distance_matrix_fp32.cc | 684 ------ ...mips_euclidean_distance_matrix_fp32_avx.cc | 114 + ...s_euclidean_distance_matrix_fp32_avx512.cc | 100 + ...euclidean_distance_matrix_fp32_dispatch.cc | 136 ++ ...ips_euclidean_distance_matrix_fp32_neon.cc | 105 + ...mips_euclidean_distance_matrix_fp32_sse.cc | 336 +++ .../mips_euclidean_distance_matrix_int4.cc | 358 --- ...ips_euclidean_distance_matrix_int4_avx2.cc | 140 ++ ...euclidean_distance_matrix_int4_dispatch.cc | 83 + ...mips_euclidean_distance_matrix_int4_sse.cc | 104 + .../mips_euclidean_distance_matrix_int8.cc | 358 --- ...ips_euclidean_distance_matrix_int8_avx2.cc | 159 ++ ...euclidean_distance_matrix_int8_dispatch.cc | 81 + ...mips_euclidean_distance_matrix_int8_sse.cc | 137 ++ src/ailego/math/norm1_matrix_fp16.cc | 10 +- src/ailego/math/norm1_matrix_fp32.cc | 27 +- src/ailego/math/norm2_matrix_fp16.cc | 16 +- src/ailego/math/norm2_matrix_fp32.cc | 35 +- src/ailego/math_batch/distance_batch.h | 2 - .../math_batch/inner_product_distance_batch.h | 156 +- .../inner_product_distance_batch_dispatch.cc | 228 ++ ...r_product_distance_batch_impl_fp16_avx2.cc | 110 + ...roduct_distance_batch_impl_fp16_avx512.cc} | 106 +- ..._product_distance_batch_impl_fp32_avx2.cc} | 32 +- ...r_product_distance_batch_impl_int8_avx2.cc | 102 + ...roduct_distance_batch_impl_int8_avx512.cc} | 82 +- 81 files changed, 7862 insertions(+), 15750 deletions(-) create mode 100644 src/ailego/math/distance_matrix_euclidean_utility.i create mode 100644 src/ailego/math/distance_matrix_inner_product_utility.i create mode 100644 src/ailego/math/distance_matrix_mips_utility.i delete mode 100644 src/ailego/math/euclidean_distance_matrix_fp16.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp16_avx.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp16_avx512.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp16_dispatch.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp16_neon.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp16_sse.cc delete mode 100644 src/ailego/math/euclidean_distance_matrix_fp32.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp32_avx.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp32_avx512.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp32_dispatch.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp32_neon.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_fp32_sse.cc delete mode 100644 src/ailego/math/euclidean_distance_matrix_int4.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_int4_avx2.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_int4_dispatch.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_int4_sse.cc delete mode 100644 src/ailego/math/euclidean_distance_matrix_int8.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_int8_avx2.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_int8_dispatch.cc create mode 100644 src/ailego/math/euclidean_distance_matrix_int8_sse.cc delete mode 100644 src/ailego/math/inner_product_matrix_fp16.cc create mode 100644 src/ailego/math/inner_product_matrix_fp16_avx.cc create mode 100644 src/ailego/math/inner_product_matrix_fp16_avx512.cc create mode 100644 src/ailego/math/inner_product_matrix_fp16_dispatch.cc create mode 100644 src/ailego/math/inner_product_matrix_fp16_neon.cc delete mode 100644 src/ailego/math/inner_product_matrix_fp32.cc create mode 100644 src/ailego/math/inner_product_matrix_fp32_avx.cc create mode 100644 src/ailego/math/inner_product_matrix_fp32_avx512.cc create mode 100644 src/ailego/math/inner_product_matrix_fp32_dispatch.cc create mode 100644 src/ailego/math/inner_product_matrix_fp32_neon.cc create mode 100644 src/ailego/math/inner_product_matrix_fp32_sse.cc delete mode 100644 src/ailego/math/inner_product_matrix_int4.cc create mode 100644 src/ailego/math/inner_product_matrix_int4_avx2.cc create mode 100644 src/ailego/math/inner_product_matrix_int4_dispatch.cc create mode 100644 src/ailego/math/inner_product_matrix_int4_sse.cc delete mode 100644 src/ailego/math/inner_product_matrix_int8.cc create mode 100644 src/ailego/math/inner_product_matrix_int8_avx2.cc create mode 100644 src/ailego/math/inner_product_matrix_int8_dispatch.cc create mode 100644 src/ailego/math/inner_product_matrix_int8_sse.cc delete mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp16.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc delete mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp32.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc delete mode 100644 src/ailego/math/mips_euclidean_distance_matrix_int4.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc delete mode 100644 src/ailego/math/mips_euclidean_distance_matrix_int8.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc create mode 100644 src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc create mode 100644 src/ailego/math_batch/inner_product_distance_batch_dispatch.cc create mode 100644 src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc rename src/ailego/math_batch/{inner_product_distance_batch_impl_fp16.h => inner_product_distance_batch_impl_fp16_avx512.cc} (70%) rename src/ailego/math_batch/{inner_product_distance_batch_impl.h => inner_product_distance_batch_impl_fp32_avx2.cc} (85%) create mode 100644 src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc rename src/ailego/math_batch/{inner_product_distance_batch_impl_int8.h => inner_product_distance_batch_impl_int8_avx512.cc} (68%) diff --git a/.github/workflows/android_build.yml b/.github/workflows/android_build.yml index 19706d85..099d4868 100644 --- a/.github/workflows/android_build.yml +++ b/.github/workflows/android_build.yml @@ -12,6 +12,13 @@ on: - '**.md' workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} + cancel-in-progress: true + +permissions: + contents: read + jobs: build-android: # sdkmanager and other Android tools are x86‑only; ARM runners fail with exit code 1 @@ -161,6 +168,10 @@ jobs: adb shell getprop ro.product.cpu.abi adb shell getprop ro.product.cpu.abilist + echo "=== CPU ISA / Instruction Set Support ===" + echo "--- /proc/cpuinfo flags ---" + adb shell 'cat /proc/cpuinfo | grep -E "^(Features|flags)"' + echo "Checking binary sizes:" ls -lah examples/c++/build-android-examples-${{ matrix.abi }}/ diff --git a/cmake/option.cmake b/cmake/option.cmake index 01388564..71e45784 100644 --- a/cmake/option.cmake +++ b/cmake/option.cmake @@ -13,7 +13,7 @@ option(ENABLE_SAPPHIRERAPIDS "Enable Intel Sapphire Rapids Server CPU microarchi option(ENABLE_EMERALDRAPIDS "Enable Intel Emerald Rapids Server CPU microarchitecture" OFF) option(ENABLE_GRANITERAPIDS "Enable Intel Granite Rapids Server CPU microarchitecture" OFF) -option(ENABLE_NATIVE "Enable native CPU microarchitecture" ON) +option(ENABLE_NATIVE "Enable native CPU microarchitecture" OFF) ## AMD Microarchitectures option(ENABLE_ZEN1 "Enable AMD Zen+ Family 17h CPU microarchitecture" OFF) @@ -74,41 +74,57 @@ macro(add_arch_flag FLAG VAR_NAME OPTION_NAME) endif() endmacro() -function(_detect_armv8_best) - set(_arm_flags - "armv8.6-a" "armv8.5-a" "armv8.4-a" "armv8.3-a" "armv8.2-a" "armv8.1-a" "armv8-a" "armv8" - ) - foreach(_ver IN LISTS _arm_flags) - check_c_compiler_flag("-march=${_ver}" _COMP_SUPP_${_ver}) - if(_COMP_SUPP_${_ver}) - _AppendFlags(CMAKE_C_FLAGS "-march=${_ver}") - _AppendFlags(CMAKE_CXX_FLAGS "-march=${_ver}") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" PARENT_SCOPE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" PARENT_SCOPE) - return() - endif() - endforeach() - message(WARNING "No ARMv8 architecture flag supported by compiler.") +function(_setup_armv8_march) + set(_arch "armv8") + check_c_compiler_flag("-march=${_arch}" _COMP_SUPP_${_arch}) + if(_COMP_SUPP_${_arch}) + _AppendFlags(CMAKE_C_FLAGS "-march=${_arch}") + _AppendFlags(CMAKE_CXX_FLAGS "-march=${_arch}") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" PARENT_SCOPE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" PARENT_SCOPE) + return() + else() + message(WARNING "No ARMv8 march flag supported by compiler.") + endif() endfunction() -function(_detect_x86_best) +function(_setup_x86_march) + set(_arch "x86-64") + check_c_compiler_flag("-march=${_arch}" _COMP_SUPP_${_arch}) + if(_COMP_SUPP_${_arch}) + _AppendFlags(CMAKE_C_FLAGS "-march=${_arch}") + _AppendFlags(CMAKE_CXX_FLAGS "-march=${_arch}") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" PARENT_SCOPE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" PARENT_SCOPE) + return() + else() + message(WARNING "No known x86 march flag supported; falling back to generic.") + endif() +endfunction() + +function(setup_compiler_march_for_x86 VAR_NAME_SSE VAR_NAME_AVX2 VAR_NAME_AVX512) + #sse + set(${VAR_NAME_SSE} "-march=corei7" PARENT_SCOPE) + + #avx 2 + set(${VAR_NAME_AVX2} "-march=core-avx2" PARENT_SCOPE) + + #avx512 set(_x86_flags - "graniterapids" "emeraldrapids" "sapphirerapids" - "skylake-avx512" "skylake" - "broadwell" "haswell" "sandybridge" "nehalem" - "znver3" "znver2" "znver1" + "graniterapids" "emeraldrapids" "sapphirerapids" "skylake-avx512" ) foreach(_arch IN LISTS _x86_flags) check_c_compiler_flag("-march=${_arch}" _COMP_SUPP_${_arch}) if(_COMP_SUPP_${_arch}) - _AppendFlags(CMAKE_C_FLAGS "-march=${_arch}") - _AppendFlags(CMAKE_CXX_FLAGS "-march=${_arch}") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" PARENT_SCOPE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" PARENT_SCOPE) + set(${VAR_NAME_AVX512} "-march=${_arch}" PARENT_SCOPE) return() endif() endforeach() - message(WARNING "No known x86 microarchitecture flag supported; falling back to generic.") + + + set(${VAR_NAME_AVX512} "-march=core-avx2" PARENT_SCOPE) + message(WARNING "No known avx512 microarchitecture flag found. Set up as core-avx2") + endfunction() if(MSVC) @@ -206,9 +222,9 @@ else() # AUTO DETECT # Heuristic: detect host architecture and probe appropriate flags if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|ARM64") - _detect_armv8_best() + _setup_armv8_march() elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386|x64") - _detect_x86_best() + _setup_x86_march() else() message(WARNING "Unknown host architecture: ${CMAKE_SYSTEM_PROCESSOR}; no -march= set.") endif() diff --git a/pyproject.toml b/pyproject.toml index 5e99edfa..349a6c2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,7 @@ sdist.include = [ [tool.scikit-build.cmake.define] BUILD_TOOLS = "OFF" BUILD_PYTHON_BINDINGS = "ON" +#CMAKE_VERBOSE_MAKEFILE = "ON" # Setuptools config for test pypi [tool.setuptools_scm] diff --git a/src/ailego/CMakeLists.txt b/src/ailego/CMakeLists.txt index 5fcaacac..c9867c36 100644 --- a/src/ailego/CMakeLists.txt +++ b/src/ailego/CMakeLists.txt @@ -18,6 +18,89 @@ if(UNIX AND NOT APPLE) list(APPEND EXTRA_LIBS ${LIB_RT}) endif() +if(NOT ANDROID) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386|x64") + setup_compiler_march_for_x86(MATH_MARCH_FLAG_SSE MATH_MARCH_FLAG_AVX2 MATH_MARCH_FLAG_AVX512) + message(STATUS "best compiler march, sse: " ${MATH_MARCH_FLAG_SSE} ", avx2: " ${MATH_MARCH_FLAG_AVX2} ", avx512: " ${MATH_MARCH_FLAG_AVX512}) + + file(GLOB_RECURSE MATH_FILES_SSE + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_sse.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_sse.c + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_sse.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_sse.c + ) + + file(GLOB_RECURSE MATH_FILES_AVX2 + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx2.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx2.c + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx2.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx2.c + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx.c + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx.c + ) + + file(GLOB_RECURSE MATH_FILES_AVX512 + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_dispatch.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_dispatch.c + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx512.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx512.c + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_dispatch.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_dispatch.c + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx512.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx512.c + ) + + foreach(MATH_FILE ${MATH_FILES_SSE}) + set_source_files_properties( + ${MATH_FILE} + PROPERTIES + COMPILE_FLAGS "${MATH_MARCH_FLAG_SSE}" + ) + endforeach() + + foreach(MATH_FILE ${MATH_FILES_AVX2}) + set_source_files_properties( + ${MATH_FILE} + PROPERTIES + COMPILE_FLAGS "${MATH_MARCH_FLAG_AVX2}" + ) + endforeach() + + foreach(MATH_FILE ${MATH_FILES_AVX512}) + set_source_files_properties( + ${MATH_FILE} + PROPERTIES + COMPILE_FLAGS "${MATH_MARCH_FLAG_AVX512}" + ) + endforeach() + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|ARM64") + # set(CMAKE_CXX_FLAGS "-march=armv8-a") + # set(CMAKE_C_FLAGS "-march=armv8-a") + set(MATH_MARCH_FLAG_NEON "-march=armv8-a") + + file(GLOB_RECURSE MATH_FILES_NEON + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_dispatch.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_dispatch.c + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_dispatch.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_dispatch.c + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_neon.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math/*_neon.c + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_neon.cc + ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_neon.c + ) + + foreach(MATH_FILE ${MATH_FILES_NEON}) + set_source_files_properties( + ${MATH_FILE} + PROPERTIES + COMPILE_FLAGS "${MATH_MARCH_FLAG_NEON}" + ) + endforeach() + endif() +endif() + cc_library( NAME zvec_ailego STATIC STRICT PACKED SRCS ${ALL_SRCS} @@ -25,4 +108,4 @@ cc_library( Arrow::arrow_static Arrow::parquet_static VERSION "${GIT_SRCS_VER}" -) \ No newline at end of file +) diff --git a/src/ailego/math/distance_matrix_euclidean_utility.i b/src/ailego/math/distance_matrix_euclidean_utility.i new file mode 100644 index 00000000..b0b89372 --- /dev/null +++ b/src/ailego/math/distance_matrix_euclidean_utility.i @@ -0,0 +1,253 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Calculate sum of squared difference (GENERAL) +#define SSD_FP32_GENERAL(m, q, sum) \ + { \ + float x = m - q; \ + sum += (x * x); \ + } + +//! Calculate sum of squared difference (SSE) +#define SSD_FP32_SSE(xmm_m, xmm_q, xmm_sum) \ + { \ + __m128 xmm_d = _mm_sub_ps(xmm_m, xmm_q); \ + xmm_sum = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum); \ + } + +//! Calculate sum of squared difference (AVX) +#define SSD_FP32_AVX(ymm_m, ymm_q, ymm_sum) \ + { \ + __m256 ymm_d = _mm256_sub_ps(ymm_m, ymm_q); \ + ymm_sum = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum); \ + } + +//! Calculate sum of squared difference (NEON) +#define SSD_FP32_NEON(v_m, v_q, v_sum) \ + { \ + float32x4_t v_d = vsubq_f32(v_m, v_q); \ + v_sum = vfmaq_f32(v_sum, v_d, v_d); \ + } + +//! Calculate sum of squared difference (GENERAL) +#define SSD_FP16_GENERAL(m, q, sum) \ + { \ + float x = m - q; \ + sum += (x * x); \ + } + +//! Calculate sum of squared difference (NEON) +#define SSD_FP16_NEON(v_m, v_q, v_sum) \ + { \ + float16x8_t v_d = vsubq_f16(v_m, v_q); \ + v_sum = vfmaq_f16(v_sum, v_d, v_d); \ + } + +//! Calculate sum of squared difference (AVX512) +#define SSD_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ + { \ + __m512 zmm_d = _mm512_sub_ps(zmm_m, zmm_q); \ + zmm_sum = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum); \ + } + +//! Calculate sum of squared difference (GENERAL) +#define SSD_INT4_GENERAL(m, q, sum) \ + sum += Int4SquaredDiffTable[(((m) << 4) & 0xf0) | (((q) >> 0) & 0xf)] + \ + Int4SquaredDiffTable[(((m) >> 0) & 0xf0) | (((q) >> 4) & 0xf)]; + + +#if defined(__SSE4_1__) +static const __m128i MASK_INT4_SSE = _mm_set1_epi32(0xf0f0f0f0); +static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); +#endif // __SSE4_1__ + +//! Compute the square root of value (SSE) +#define SQRT_FP32_SSE(v, ...) _mm_sqrt_ps(_mm_cvtepi32_ps(v)) + +#if defined(__AVX2__) +static const __m256i MASK_INT4_AVX = _mm256_set1_epi32(0xf0f0f0f0); +static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); +#endif // __AVX2__ + +//! Calculate sum of squared difference (SSE) +#define SSD_INT4_SSE(xmm_m, xmm_q, xmm_sum) \ + { \ + __m128i xmm_lhs = \ + _mm_and_si128(_mm_slli_epi32((xmm_m), 4), MASK_INT4_SSE); \ + __m128i xmm_rhs = \ + _mm_and_si128(_mm_slli_epi32((xmm_q), 4), MASK_INT4_SSE); \ + xmm_lhs = _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), \ + _mm_min_epi8(xmm_lhs, xmm_rhs)), \ + 4); \ + xmm_sum = _mm_add_epi32( \ + _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs, xmm_lhs), ONES_INT16_SSE), \ + xmm_sum); \ + xmm_lhs = _mm_and_si128((xmm_m), MASK_INT4_SSE); \ + xmm_rhs = _mm_and_si128((xmm_q), MASK_INT4_SSE); \ + xmm_lhs = _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), \ + _mm_min_epi8(xmm_lhs, xmm_rhs)), \ + 4); \ + xmm_sum = _mm_add_epi32( \ + _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs, xmm_lhs), ONES_INT16_SSE), \ + xmm_sum); \ + } + +//! Compute the distance between matrix and query +#define SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) \ + { \ + __m128i xmm_lhs_0 = \ + _mm_and_si128(_mm_slli_epi32((xmm_lhs), 4), MASK_INT4_SSE); \ + __m128i xmm_rhs_0 = \ + _mm_and_si128(_mm_slli_epi32((xmm_rhs), 4), MASK_INT4_SSE); \ + __m128i xmm_lhs_1 = _mm_and_si128((xmm_lhs), MASK_INT4_SSE); \ + __m128i xmm_rhs_1 = _mm_and_si128((xmm_rhs), MASK_INT4_SSE); \ + xmm_lhs_0 = \ + _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), \ + _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)), \ + 4); \ + xmm_rhs_0 = \ + _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), \ + _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)), \ + 4); \ + xmm_lhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs_0, xmm_lhs_0), \ + ONES_INT16_SSE); \ + xmm_rhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_rhs_0), \ + ONES_INT16_SSE); \ + xmm_sum = _mm_add_epi32(_mm_add_epi32(xmm_lhs_0, xmm_rhs_0), xmm_sum); \ + } + +//! Calculate sum of squared difference (AVX) +#define SSD_INT4_AVX(ymm_m, ymm_q, ymm_sum) \ + { \ + __m256i ymm_lhs = \ + _mm256_and_si256(_mm256_slli_epi32((ymm_m), 4), MASK_INT4_AVX); \ + __m256i ymm_rhs = \ + _mm256_and_si256(_mm256_slli_epi32((ymm_q), 4), MASK_INT4_AVX); \ + ymm_lhs = \ + _mm256_srli_epi32(_mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), \ + _mm256_min_epi8(ymm_lhs, ymm_rhs)), \ + 4); \ + ymm_sum = _mm256_add_epi32( \ + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs, ymm_lhs), \ + ONES_INT16_AVX), \ + ymm_sum); \ + ymm_lhs = _mm256_and_si256((ymm_m), MASK_INT4_AVX); \ + ymm_rhs = _mm256_and_si256((ymm_q), MASK_INT4_AVX); \ + ymm_lhs = \ + _mm256_srli_epi32(_mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), \ + _mm256_min_epi8(ymm_lhs, ymm_rhs)), \ + 4); \ + ymm_sum = _mm256_add_epi32( \ + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs, ymm_lhs), \ + ONES_INT16_AVX), \ + ymm_sum); \ + } + +//! Compute the distance between matrix and query +#define SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) \ + { \ + __m256i ymm_lhs_0 = \ + _mm256_and_si256(_mm256_slli_epi32((ymm_lhs), 4), MASK_INT4_AVX); \ + __m256i ymm_rhs_0 = \ + _mm256_and_si256(_mm256_slli_epi32((ymm_rhs), 4), MASK_INT4_AVX); \ + __m256i ymm_lhs_1 = _mm256_and_si256((ymm_lhs), MASK_INT4_AVX); \ + __m256i ymm_rhs_1 = _mm256_and_si256((ymm_rhs), MASK_INT4_AVX); \ + ymm_lhs_0 = _mm256_srli_epi32( \ + _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0), \ + _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0)), \ + 4); \ + ymm_rhs_0 = _mm256_srli_epi32( \ + _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1), \ + _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1)), \ + 4); \ + ymm_lhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs_0, ymm_lhs_0), \ + ONES_INT16_AVX); \ + ymm_rhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_rhs_0), \ + ONES_INT16_AVX); \ + ymm_sum = \ + _mm256_add_epi32(_mm256_add_epi32(ymm_lhs_0, ymm_rhs_0), ymm_sum); \ + } + +//! Calculate sum of squared difference (GENERAL) +#define SSD_INT8_GENERAL(m, q, sum) \ + { \ + int32_t x = m - q; \ + sum += static_cast(x * x); \ + } + +//! Calculate sum of squared difference (SSE) +#define SSD_INT8_SSE(xmm_m, xmm_q, xmm_sum) \ + { \ + xmm_sum = _mm_add_epi32( \ + _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_m), \ + _mm_sign_epi8(xmm_m, xmm_m)), \ + ONES_INT16_SSE), \ + xmm_sum); \ + xmm_sum = _mm_add_epi32( \ + _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_q), \ + _mm_sign_epi8(xmm_q, xmm_q)), \ + ONES_INT16_SSE), \ + xmm_sum); \ + xmm_sum = _mm_sub_epi32( \ + xmm_sum, \ + _mm_slli_epi32( \ + _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_q), \ + _mm_sign_epi8(xmm_m, xmm_q)), \ + ONES_INT16_SSE), \ + 1)); \ + } + +//! Calculate sum of squared difference (AVX) +#define SSD_INT8_AVX(ymm_m, ymm_q, ymm_sum) \ + { \ + ymm_sum = _mm256_add_epi32( \ + _mm256_madd_epi16( \ + _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_m), \ + _mm256_sign_epi8(ymm_m, ymm_m)), \ + ONES_INT16_AVX), \ + ymm_sum); \ + ymm_sum = _mm256_add_epi32( \ + _mm256_madd_epi16( \ + _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q), \ + _mm256_sign_epi8(ymm_q, ymm_q)), \ + ONES_INT16_AVX), \ + ymm_sum); \ + ymm_sum = _mm256_sub_epi32( \ + ymm_sum, _mm256_slli_epi32( \ + _mm256_madd_epi16( \ + _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q), \ + _mm256_sign_epi8(ymm_m, ymm_q)), \ + ONES_INT16_AVX), \ + 1)); \ + } + +//! Compute the square root of value (AVX) +#define SQRT_FP32_AVX(v, ...) _mm256_sqrt_ps(_mm256_cvtepi32_ps(v)) + +//! Compute the square root of value (AVX512) +#define SQRT_FP32_AVX512(v, ...) _mm512_sqrt_ps(_mm512_cvtepi32_ps(v)) + +#define ACCUM_FP32_STEP_SSE SSD_FP32_SSE +#define ACCUM_FP32_STEP_AVX SSD_FP32_AVX + +#define ACCUM_FP32_STEP_AVX512 SSD_FP32_AVX512 +#define ACCUM_FP16_STEP_GENERAL SSD_FP16_GENERAL + +#define ACCUM_FP16_STEP_NEON SSD_FP16_NEON +#define ACCUM_FP32_STEP_NEON SSD_FP32_NEON + +#define ACCUM_INT4_STEP_SSE SSD_INT4_SSE +#define ACCUM_INT4_STEP_AVX SSD_INT4_AVX +#define ACCUM_INT8_STEP_SSE SSD_INT8_SSE +#define ACCUM_INT8_STEP_AVX SSD_INT8_AVX \ No newline at end of file diff --git a/src/ailego/math/distance_matrix_inner_product_utility.i b/src/ailego/math/distance_matrix_inner_product_utility.i new file mode 100644 index 00000000..3f28b15b --- /dev/null +++ b/src/ailego/math/distance_matrix_inner_product_utility.i @@ -0,0 +1,208 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(__SSE4_1__) +//! Four-bits Convert Table +static const AILEGO_ALIGNED(32) int8_t Int4ConvertTable[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1, + 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1}; + +#define NEGZEROS_FP32_SSE _mm_set1_ps(-0.0f) +#define MASK_INT4_SSE _mm_set1_epi32(0x0f0f0f0f) +#define ONES_INT16_SSE _mm_set1_epi32(0x00010001) +#define INT4_LOOKUP_SSE _mm_load_si128((const __m128i *)Int4ConvertTable) +#endif // __SSE4_1__ + +#if defined(__AVX__) +// #define NEGZEROS_FP32_AVX _mm256_set1_ps(-0.0f) +#define MASK_INT4_AVX _mm256_set1_epi32(0x0f0f0f0f) +#define ONES_INT16_AVX _mm256_set1_epi32(0x00010001) +#define INT4_LOOKUP_AVX _mm256_load_si256((const __m256i *)Int4ConvertTable) +#endif // __AVX__ + +#if defined(__AVX512F__) && !defined(__AVX512DQ__) +#define _mm512_xor_ps(a, b) \ + _mm512_castsi512_ps( \ + _mm512_xor_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b))) +#endif // __AVX512DQ__ + +//! Reverse sign of value (GENERAL) +#define NEGATE_FP32_GENERAL(v) -(v) + +//! Calculate Fused-Multiply-Add (SSE) +#define FMA_FP32_SSE(xmm_m, xmm_q, xmm_sum) \ + xmm_sum = _mm_fmadd_ps(xmm_m, xmm_q, xmm_sum); + +//! Calculate Fused-Multiply-Add (AVX) +#define FMA_FP32_AVX(ymm_m, ymm_q, ymm_sum) \ + ymm_sum = _mm256_fmadd_ps(ymm_m, ymm_q, ymm_sum); + +//! Calculate Fused-Multiply-Add (AVX512) +#define FMA_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ + zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_q, zmm_sum); + +//! Calculate Fused-Multiply-Add (AVX512FP16) +#define FMA_FP16_AVX512FP16(zmm_m, zmm_q, zmm_sum) \ + zmm_sum = _mm512_fmadd_ph(zmm_m, zmm_q, zmm_sum); + +//! Calculate Fused-Multiply-Add (GENERAL) +#define FMA_FP16_GENERAL(m, q, sum) sum += (m * q); + +//! Calculate Fused-Multiply-Add (GENERAL) +#define FMA_FP32_GENERAL(m, q, sum) sum += (m * q); + +//! Calculate Fused-Multiply-Add (NEON) +#define FMA_FP16_NEON(v_m, v_q, v_sum) v_sum = vfmaq_f16(v_sum, v_m, v_q); + +//! Calculate Fused-Multiply-Add (NEON) +#define FMA_FP32_NEON(v_m, v_q, v_sum) v_sum = vfmaq_f32(v_sum, v_m, v_q); + +//! Calculate Fused-Multiply-Add (GENERAL) +#define FMA_INT4_GENERAL(m, q, sum) \ + sum += Int4MulTable[(((m) << 4) & 0xf0) | (((q) >> 0) & 0xf)] + \ + Int4MulTable[(((m) >> 0) & 0xf0) | (((q) >> 4) & 0xf)]; + +//! Calculate Fused-Multiply-Add (GENERAL) +#define FMA_INT8_GENERAL(m, q, sum) sum += static_cast(m * q); + +//! Calculate Fused-Multiply-Add (SSE) +#define FMA_INT8_SSE(xmm_m, xmm_q, xmm_sum) \ + xmm_sum = _mm_add_epi32( \ + _mm_madd_epi16( \ + _mm_maddubs_epi16(_mm_abs_epi8(xmm_q), _mm_sign_epi8(xmm_m, xmm_q)), \ + ONES_INT16_SSE), \ + xmm_sum); + +//! Calculate Fused-Multiply-Add (AVX) +#define FMA_INT8_AVX(ymm_m, ymm_q, ymm_sum) \ + ymm_sum = _mm256_add_epi32( \ + _mm256_madd_epi16(_mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q), \ + _mm256_sign_epi8(ymm_m, ymm_q)), \ + ONES_INT16_AVX), \ + ymm_sum); + +//! Calculate Fused-Multiply-Add (SSE) +#define FMA_INT4_SSE(xmm_m, xmm_q, xmm_sum) \ + { \ + __m128i xmm_lhs = _mm_shuffle_epi8(INT4_LOOKUP_SSE, \ + _mm_and_si128((xmm_m), MASK_INT4_SSE)); \ + __m128i xmm_rhs = _mm_shuffle_epi8(INT4_LOOKUP_SSE, \ + _mm_and_si128((xmm_q), MASK_INT4_SSE)); \ + xmm_sum = _mm_add_epi32( \ + _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ + _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ + ONES_INT16_SSE), \ + xmm_sum); \ + xmm_lhs = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, \ + _mm_and_si128(_mm_srli_epi32((xmm_m), 4), MASK_INT4_SSE)); \ + xmm_rhs = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, \ + _mm_and_si128(_mm_srli_epi32((xmm_q), 4), MASK_INT4_SSE)); \ + xmm_sum = _mm_add_epi32( \ + _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ + _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ + ONES_INT16_SSE), \ + xmm_sum); \ + } + +//! Calculate Fused-Multiply-Add (AVX) +#define FMA_INT4_AVX(ymm_m, ymm_q, ymm_sum) \ + { \ + __m256i ymm_lhs = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, _mm256_and_si256((ymm_m), MASK_INT4_AVX)); \ + __m256i ymm_rhs = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, _mm256_and_si256((ymm_q), MASK_INT4_AVX)); \ + ymm_sum = _mm256_add_epi32( \ + _mm256_madd_epi16( \ + _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_rhs), \ + _mm256_sign_epi8(ymm_lhs, ymm_rhs)), \ + ONES_INT16_AVX), \ + ymm_sum); \ + ymm_lhs = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, \ + _mm256_and_si256(_mm256_srli_epi32((ymm_m), 4), MASK_INT4_AVX)); \ + ymm_rhs = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, \ + _mm256_and_si256(_mm256_srli_epi32((ymm_q), 4), MASK_INT4_AVX)); \ + ymm_sum = _mm256_add_epi32( \ + _mm256_madd_epi16( \ + _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_rhs), \ + _mm256_sign_epi8(ymm_lhs, ymm_rhs)), \ + ONES_INT16_AVX), \ + ymm_sum); \ + } + +//! Compute the distance between matrix and query +#define FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) \ + { \ + __m128i xmm_lhs_0 = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, _mm_and_si128((xmm_lhs), MASK_INT4_SSE)); \ + __m128i xmm_rhs_0 = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, _mm_and_si128((xmm_rhs), MASK_INT4_SSE)); \ + __m128i xmm_lhs_1 = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, \ + _mm_and_si128(_mm_srli_epi32((xmm_lhs), 4), MASK_INT4_SSE)); \ + __m128i xmm_rhs_1 = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, \ + _mm_and_si128(_mm_srli_epi32((xmm_rhs), 4), MASK_INT4_SSE)); \ + xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0); \ + xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1); \ + xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0); \ + xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1); \ + xmm_lhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0), \ + ONES_INT16_SSE); \ + xmm_lhs_1 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1), \ + ONES_INT16_SSE); \ + xmm_sum = _mm_add_epi32(_mm_add_epi32(xmm_lhs_0, xmm_lhs_1), xmm_sum); \ + } + +//! Compute the distance between matrix and query +#define FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) \ + { \ + __m256i ymm_lhs_0 = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, _mm256_and_si256((ymm_lhs), MASK_INT4_AVX)); \ + __m256i ymm_rhs_0 = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, _mm256_and_si256((ymm_rhs), MASK_INT4_AVX)); \ + __m256i ymm_lhs_1 = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, \ + _mm256_and_si256(_mm256_srli_epi32((ymm_lhs), 4), MASK_INT4_AVX)); \ + __m256i ymm_rhs_1 = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, \ + _mm256_and_si256(_mm256_srli_epi32((ymm_rhs), 4), MASK_INT4_AVX)); \ + ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); \ + ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); \ + ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); \ + ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); \ + ymm_lhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), \ + ONES_INT16_AVX); \ + ymm_lhs_1 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), \ + ONES_INT16_AVX); \ + ymm_sum = \ + _mm256_add_epi32(_mm256_add_epi32(ymm_lhs_0, ymm_lhs_1), ymm_sum); \ + } + +#define ACCUM_FP16_STEP_GENERAL FMA_FP16_GENERAL +#define ACCUM_FP16_STEP_NEON FMA_FP16_NEON + +#define ACCUM_FP32_STEP_SSE FMA_FP32_SSE +#define ACCUM_FP32_STEP_AVX FMA_FP32_AVX +#define ACCUM_FP32_STEP_AVX512 FMA_FP32_AVX512 +#define ACCUM_FP32_STEP_NEON FMA_FP32_NEON + +#define ACCUM_INT4_STEP_SSE FMA_INT4_SSE +#define ACCUM_INT4_STEP_AVX FMA_INT4_AVX + +#define ACCUM_INT8_STEP_SSE FMA_INT8_SSE +#define ACCUM_INT8_STEP_AVX FMA_INT8_AVX diff --git a/src/ailego/math/distance_matrix_mips_utility.i b/src/ailego/math/distance_matrix_mips_utility.i new file mode 100644 index 00000000..871fdaa5 --- /dev/null +++ b/src/ailego/math/distance_matrix_mips_utility.i @@ -0,0 +1,160 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Calculate Fused-Multiply-Add (AVX512) +#define FMA_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ + zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_q, zmm_sum); + +#define FMA_MASK_FP32_AVX512(zmm_m, zmm_q, zmm_sum, mask) \ + zmm_sum = _mm512_mask3_fmadd_ps(zmm_m, zmm_q, zmm_sum, mask); + +#define HorizontalAdd_FP16_NEON(v) \ + vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(v)), vcvt_high_f32_f16(v))) + +#define HorizontalAdd_FP32_V512_TO_V256(zmm) \ + _mm256_add_ps( \ + _mm512_castps512_ps256(zmm), \ + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(zmm), 1))) + +//! Calculate Fused-Multiply-Add (AVX, FP16) +#define FMA_FP16_GENERAL(lhs, rhs, sum, norm1, norm2) \ + { \ + float v1 = lhs; \ + float v2 = rhs; \ + sum += v1 * v2; \ + norm1 += v1 * v1; \ + norm2 += v2 * v2; \ + } + +//! Calculate Fused-Multiply-Add (GENERAL) +#define FMA_FP32_GENERAL(lhs, rhs, sum, norm1, norm2) \ + { \ + sum += (lhs) * (rhs); \ + norm1 += (lhs) * (lhs); \ + norm2 += (rhs) * (rhs); \ + } + +#if defined(__SSE4_1__) +//! Four-bits Convert Table +static const AILEGO_ALIGNED(32) int8_t Int4ConvertTable[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1, + 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1}; +#endif // __SSE4_1__ + +#if defined(__SSE4_1__) +static const __m128i MASK_INT4_SSE = _mm_set1_epi32(0x0f0f0f0f); +static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); +static const __m128i INT4_LOOKUP_SSE = + _mm_load_si128((const __m128i *)Int4ConvertTable); +#endif // __SSE4_1__ + +#if defined(__AVX2__) +static const __m256i MASK_INT4_AVX = _mm256_set1_epi32(0x0f0f0f0f); +static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); +static const __m256i INT4_LOOKUP_AVX = + _mm256_load_si256((const __m256i *)Int4ConvertTable); +#endif // __AVX2__ + +//! Calculate Fused-Multiply-Add (GENERAL) +#define FMA_INT4_GENERAL(lhs, rhs, sum, norm1, norm2) \ + { \ + sum += Int4MulTable[(((lhs) << 4) & 0xf0) | (((rhs) >> 0) & 0xf)] + \ + Int4MulTable[(((lhs) >> 0) & 0xf0) | (((rhs) >> 4) & 0xf)]; \ + norm1 += static_cast( \ + ((int8_t)((lhs) << 4) >> 4) * ((int8_t)((lhs) << 4) >> 4) + \ + ((int8_t)((lhs) & 0xf0) >> 4) * ((int8_t)((lhs) & 0xf0) >> 4)); \ + norm2 += static_cast( \ + ((int8_t)((rhs) << 4) >> 4) * ((int8_t)((rhs) << 4) >> 4) + \ + ((int8_t)((rhs) & 0xf0) >> 4) * ((int8_t)((rhs) & 0xf0) >> 4)); \ + } + + +//! Compute the distance between matrix and query (SSE) +#define FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum_0, xmm_sum_norm1, \ + xmm_sum_norm2) \ + { \ + __m128i xmm_lhs_0 = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, _mm_and_si128((xmm_lhs), MASK_INT4_SSE)); \ + __m128i xmm_rhs_0 = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, _mm_and_si128((xmm_rhs), MASK_INT4_SSE)); \ + __m128i xmm_lhs_1 = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, \ + _mm_and_si128(_mm_srli_epi32((xmm_lhs), 4), MASK_INT4_SSE)); \ + __m128i xmm_rhs_1 = _mm_shuffle_epi8( \ + INT4_LOOKUP_SSE, \ + _mm_and_si128(_mm_srli_epi32((xmm_rhs), 4), MASK_INT4_SSE)); \ + FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); \ + FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); \ + FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); \ + FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum_0); \ + FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); \ + FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); \ + } + +//! Calculate Fused-Multiply-Add (GENERAL) +#define FMA_INT8_GENERAL(lhs, rhs, sum, norm1, norm2) \ + { \ + sum += static_cast(lhs * rhs); \ + norm1 += static_cast(lhs * lhs); \ + norm2 += static_cast(rhs * rhs); \ + } + +//! Calculate Fused-Multiply-Add (SSE) +#define FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum) \ + xmm_sum = _mm_add_epi32( \ + _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ + _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ + ONES_INT16_SSE), \ + xmm_sum) + +//! Calculate Fused-Multiply-Add (AVX) +#define FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum) \ + ymm_sum = _mm256_add_epi32( \ + _mm256_madd_epi16( \ + _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_rhs), \ + _mm256_sign_epi8(ymm_lhs, ymm_rhs)), \ + ONES_INT16_AVX), \ + ymm_sum) + +#define FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum) \ + ymm_sum = _mm256_add_epi32( \ + _mm256_set_m128i( \ + _mm_setzero_si128(), \ + _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ + _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ + ONES_INT16_SSE)), \ + ymm_sum) + +//! Compute the distance between matrix and query (AVX) +#define FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum1, \ + ymm_sum_norm1, ymm_sum_norm2) \ + { \ + __m256i ymm_lhs_0 = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, _mm256_and_si256((ymm_lhs), MASK_INT4_AVX)); \ + __m256i ymm_rhs_0 = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, _mm256_and_si256((ymm_rhs), MASK_INT4_AVX)); \ + __m256i ymm_lhs_1 = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, \ + _mm256_and_si256(_mm256_srli_epi32((ymm_lhs), 4), MASK_INT4_AVX)); \ + __m256i ymm_rhs_1 = _mm256_shuffle_epi8( \ + INT4_LOOKUP_AVX, \ + _mm256_and_si256(_mm256_srli_epi32((ymm_rhs), 4), MASK_INT4_AVX)); \ + FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); \ + FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); \ + FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); \ + FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); \ + FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); \ + FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); \ + } + diff --git a/src/ailego/math/euclidean_distance_matrix.h b/src/ailego/math/euclidean_distance_matrix.h index 91555070..e8d5b4c8 100644 --- a/src/ailego/math/euclidean_distance_matrix.h +++ b/src/ailego/math/euclidean_distance_matrix.h @@ -462,246 +462,6 @@ struct SquaredEuclideanDistanceMatrix { static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; - -/*! Squared Euclidean Distance Matrix (FP32, M=2, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=2, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=4, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=4, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=4, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=8, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=8, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=8, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=8, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=16, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=16, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=16, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=16, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=16, N=16) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=32, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=32, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=32, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=32, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=32, N=16) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP32, M=32, N=32) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; #endif // __SSE__ || __ARM_NEON #if defined(__SSE__) || (defined(__ARM_NEON) && (defined(__aarch64__))) @@ -716,1515 +476,73 @@ struct EuclideanDistanceMatrix { static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; +#endif // __SSE__ || __ARM_NEON && __aarch64__ -/*! Euclidean Distance Matrix (FP32, M=2, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=2, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=4, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=4, N=2) +#if (defined(__F16C__) && defined(__AVX__)) || \ + (defined(__ARM_NEON) && defined(__aarch64__)) +/*! Squared Euclidean Distance Matrix (FP16, M=1, N=1) */ template <> -struct EuclideanDistanceMatrix { +struct SquaredEuclideanDistanceMatrix { //! Type of value - using ValueType = float; + using ValueType = Float16; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Euclidean Distance Matrix (FP32, M=4, N=4) +/*! Euclidean Distance Matrix (FP16, M=1, N=1) */ template <> -struct EuclideanDistanceMatrix { +struct EuclideanDistanceMatrix { //! Type of value - using ValueType = float; + using ValueType = Float16; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; +#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) -/*! Euclidean Distance Matrix (FP32, M=8, N=1) +#if defined(__SSE4_1__) +/*! Squared Euclidean Distance Matrix (INT8, M=1, N=1) */ template <> -struct EuclideanDistanceMatrix { +struct SquaredEuclideanDistanceMatrix { //! Type of value - using ValueType = float; + using ValueType = int8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Euclidean Distance Matrix (FP32, M=8, N=2) +/*! Euclidean Distance Matrix (INT8, M=1, N=1) */ template <> -struct EuclideanDistanceMatrix { +struct EuclideanDistanceMatrix { //! Type of value - using ValueType = float; + using ValueType = int8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Euclidean Distance Matrix (FP32, M=8, N=4) +/*! Squared Euclidean Distance Matrix (INT4, M=1, N=1) */ template <> -struct EuclideanDistanceMatrix { +struct SquaredEuclideanDistanceMatrix { //! Type of value - using ValueType = float; + using ValueType = uint8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Euclidean Distance Matrix (FP32, M=8, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=16, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=16, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=16, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=16, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=16, N=16) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=32, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=32, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=32, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=32, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=32, N=16) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP32, M=32, N=32) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -#endif // __SSE__ || __ARM_NEON && __aarch64__ - -#if (defined(__F16C__) && defined(__AVX__)) || \ - (defined(__ARM_NEON) && defined(__aarch64__)) -/*! Squared Euclidean Distance Matrix (FP16, M=1, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=1, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -#if !defined(__ARM_NEON) -/*! Squared Euclidean Distance Matrix (FP16, M=2, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=2, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=4, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=4, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=4, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=8, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=8, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=8, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=8, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=16, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=16, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=16, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=16, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=16, N=16) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=32, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=32, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=32, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=32, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=32, N=16) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (FP16, M=32, N=32) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=2, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=2, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=4, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=4, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=4, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=8, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=8, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=8, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=8, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=16, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=16, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=16, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=16, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=16, N=16) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=32, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=32, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=32, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=32, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=32, N=16) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (FP16, M=32, N=32) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -#endif // !__ARM_NEON -#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) - -#if defined(__SSE4_1__) -/*! Squared Euclidean Distance Matrix (INT8, M=1, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=2, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=2, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=4, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=4, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=4, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=8, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=8, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=8, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=8, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=16, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=16, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=16, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=16, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=16, N=16) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=32, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=32, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=32, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=32, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=32, N=16) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT8, M=32, N=32) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=1, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=2, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=2, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=4, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=4, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=4, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=8, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=8, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=8, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=8, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=16, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=16, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=16, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=16, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=16, N=16) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=32, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=32, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=32, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=32, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=32, N=16) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT8, M=32, N=32) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=1, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=2, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=2, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=4, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=4, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=4, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=8, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=8, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=8, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=8, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=16, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=16, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=16, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=16, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=16, N=16) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=32, N=1) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=32, N=2) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=32, N=4) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=32, N=8) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=32, N=16) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Squared Euclidean Distance Matrix (INT4, M=32, N=32) - */ -template <> -struct SquaredEuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=1, N=1) +/*! Euclidean Distance Matrix (INT4, M=1, N=1) */ template <> struct EuclideanDistanceMatrix { @@ -2235,246 +553,6 @@ struct EuclideanDistanceMatrix { static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; - -/*! Euclidean Distance Matrix (INT4, M=2, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=2, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=4, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=4, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=4, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=8, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=8, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=8, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=8, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=16, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=16, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=16, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=16, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=16, N=16) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=32, N=1) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=32, N=2) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=32, N=4) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=32, N=8) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=32, N=16) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Euclidean Distance Matrix (INT4, M=32, N=32) - */ -template <> -struct EuclideanDistanceMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; #endif // __SSE4_1__ /*! Squared Euclidean Distance Sparse Matrix diff --git a/src/ailego/math/euclidean_distance_matrix_fp16.cc b/src/ailego/math/euclidean_distance_matrix_fp16.cc deleted file mode 100644 index ca24561a..00000000 --- a/src/ailego/math/euclidean_distance_matrix_fp16.cc +++ /dev/null @@ -1,615 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "distance_matrix_accum_fp16.i" -#include "euclidean_distance_matrix.h" - -namespace zvec { -namespace ailego { - -#define ACCUM_FP32_STEP_SSE SSD_FP32_SSE -#define ACCUM_FP32_STEP_AVX SSD_FP32_AVX -#define ACCUM_FP32_STEP_AVX512 SSD_FP32_AVX512 -#define ACCUM_FP32_STEP_NEON SSD_FP32_NEON -#define ACCUM_FP16_STEP_GENERAL SSD_FP16_GENERAL -#define ACCUM_FP16_STEP_NEON SSD_FP16_NEON - -//! Calculate sum of squared difference (SSE) -#define SSD_FP32_SSE(xmm_m, xmm_q, xmm_sum) \ - { \ - __m128 xmm_d = _mm_sub_ps(xmm_m, xmm_q); \ - xmm_sum = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum); \ - } - -//! Calculate sum of squared difference (AVX) -#define SSD_FP32_AVX(ymm_m, ymm_q, ymm_sum) \ - { \ - __m256 ymm_d = _mm256_sub_ps(ymm_m, ymm_q); \ - ymm_sum = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum); \ - } - -//! Calculate sum of squared difference (AVX512) -#define SSD_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ - { \ - __m512 zmm_d = _mm512_sub_ps(zmm_m, zmm_q); \ - zmm_sum = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum); \ - } - -//! Calculate sum of squared difference (GENERAL) -#define SSD_FP16_GENERAL(m, q, sum) \ - { \ - float x = m - q; \ - sum += (x * x); \ - } - -//! Calculate sum of squared difference (NEON) -#define SSD_FP16_NEON(v_m, v_q, v_sum) \ - { \ - float16x8_t v_d = vsubq_f16(v_m, v_q); \ - v_sum = vfmaq_f16(v_sum, v_d, v_d); \ - } - -//! Calculate sum of squared difference (NEON) -#define SSD_FP32_NEON(v_m, v_q, v_sum) \ - { \ - float32x4_t v_d = vsubq_f32(v_m, v_q); \ - v_sum = vfmaq_f32(v_sum, v_d, v_d); \ - } - -#if (defined(__F16C__) && defined(__AVX__)) || \ - (defined(__ARM_NEON) && defined(__aarch64__)) - -#if defined(__AVX512FP16__) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceAVX512FP16(const Float16 *lhs, - const Float16 *rhs, - size_t size) { - const Float16 *last = lhs + size; - const Float16 *last_aligned = lhs + ((size >> 6) << 6); - - __m512h zmm_sum_0 = _mm512_setzero_ph(); - __m512h zmm_sum_1 = _mm512_setzero_ph(); - - if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - __m512h zmm_d_0 = - _mm512_sub_ph(_mm512_load_ph(lhs + 0), _mm512_load_ph(rhs + 0)); - __m512h zmm_d_1 = - _mm512_sub_ph(_mm512_load_ph(lhs + 32), _mm512_load_ph(rhs + 32)); - zmm_sum_0 = _mm512_fmadd_ph(zmm_d_0, zmm_d_0, zmm_sum_0); - zmm_sum_1 = _mm512_fmadd_ph(zmm_d_1, zmm_d_1, zmm_sum_1); - } - - if (last >= last_aligned + 32) { - __m512h zmm_d = _mm512_sub_ph(_mm512_load_ph(lhs), _mm512_load_ph(rhs)); - zmm_sum_0 = _mm512_fmadd_ph(zmm_d, zmm_d, zmm_sum_0); - lhs += 32; - rhs += 32; - } - } else { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - __m512h zmm_d_0 = - _mm512_sub_ph(_mm512_loadu_ph(lhs + 0), _mm512_loadu_ph(rhs + 0)); - __m512h zmm_d_1 = - _mm512_sub_ph(_mm512_loadu_ph(lhs + 32), _mm512_loadu_ph(rhs + 32)); - zmm_sum_0 = _mm512_fmadd_ph(zmm_d_0, zmm_d_0, zmm_sum_0); - zmm_sum_1 = _mm512_fmadd_ph(zmm_d_1, zmm_d_1, zmm_sum_1); - } - - if (last >= last_aligned + 32) { - __m512h zmm_d = _mm512_sub_ph(_mm512_loadu_ph(lhs), _mm512_loadu_ph(rhs)); - zmm_sum_0 = _mm512_fmadd_ph(zmm_d, zmm_d, zmm_sum_0); - lhs += 32; - rhs += 32; - } - } - - zmm_sum_0 = _mm512_add_ph(zmm_sum_0, zmm_sum_1); - if (lhs != last) { - __mmask32 mask = (__mmask32)((1 << (last - lhs)) - 1); - __m512i zmm_undefined = _mm512_undefined_epi32(); - __m512h zmm_undefined_ph = _mm512_undefined_ph(); - __m512h zmm_d = _mm512_mask_sub_ph( - zmm_undefined_ph, mask, - _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, lhs)), - _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, rhs))); - zmm_sum_0 = _mm512_mask3_fmadd_ph(zmm_d, zmm_d, zmm_sum_0, mask); - } - - return HorizontalAdd_FP16_V512(zmm_sum_0); -} -#endif - - -//! Compute the distance between matrix and query (FP16, M=1, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP16_1X1_NEON(m, q, dim, out, 0ull, ) -#else -#if defined(__AVX512FP16__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { - *out = SquaredEuclideanDistanceAVX512FP16(m, q, dim); - return; - } -#endif -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_1X1_AVX512(m, q, dim, out, 0ull, ) - return; - } -#endif - ACCUM_FP16_1X1_AVX(m, q, dim, out, 0ull, ) -#endif //__ARM_NEON -} - -//! Compute the distance between matrix and query (FP16, M=1, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP16_1X1_NEON(m, q, dim, out, 0ull, std::sqrt) -#else -#if defined(__AVX512FP16__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { - *out = std::sqrt(SquaredEuclideanDistanceAVX512FP16(m, q, dim)); - return; - } -#endif -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_1X1_AVX512(m, q, dim, out, 0ull, std::sqrt) - return; - } -#endif - ACCUM_FP16_1X1_AVX(m, q, dim, out, 0ull, std::sqrt) -#endif //__ARM_NEON -} - -#if !defined(__ARM_NEON) -//! Compute the distance between matrix and query (FP16, M=2, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_2X1_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=2, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_2X2_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_4X1_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_4X2_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_4X4_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_8X1_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_8X2_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_8X4_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { - ACCUM_FP16_8X8_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X1_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - - ACCUM_FP16_16X1_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X2_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - - ACCUM_FP16_16X2_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X4_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_16X4_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X8_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_16X8_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=16) -void SquaredEuclideanDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X16_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_16X16_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X1_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X1_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X2_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X2_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X4_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X4_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X8_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X8_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=16) -void SquaredEuclideanDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X16_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X16_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=32) -void SquaredEuclideanDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X32_AVX512(m, q, dim, out, ) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X32_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=2, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_2X1_AVX(m, q, dim, out, _mm_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=2, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_2X2_AVX(m, q, dim, out, _mm_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_4X1_AVX(m, q, dim, out, _mm_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_4X2_AVX(m, q, dim, out, _mm_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_4X4_AVX(m, q, dim, out, _mm_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_8X1_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_8X2_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_8X4_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_8X8_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X1_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_16X1_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X2_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_16X2_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X4_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_16X4_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X8_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_16X8_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=16) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_16X16_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_16X16_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X1_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X1_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X2_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X2_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X4_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X4_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X8_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X8_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=16) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X16_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X16_AVX(m, q, dim, out, _mm256_sqrt_ps) -} - -//! Compute the distance between matrix and query (FP16, M=32, N=32) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_32X32_AVX512(m, q, dim, out, _mm512_sqrt_ps) - return; - } -#endif // __AVX512F__ - ACCUM_FP16_32X32_AVX(m, q, dim, out, _mm256_sqrt_ps) -} -#endif // !__ARM_NEON -#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp16_avx.cc b/src/ailego/math/euclidean_distance_matrix_fp16_avx.cc new file mode 100644 index 00000000..0adf738c --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp16_avx.cc @@ -0,0 +1,38 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX__) + +void SquaredEuclideanDistanceAVX(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out) { + ACCUM_FP16_1X1_AVX(lhs, rhs, size, out, 0ull, ) +} + +//! EuclideanDistance +void EuclideanDistanceAVX(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out) { + ACCUM_FP16_1X1_AVX(lhs, rhs, size, out, 0ull, std::sqrt) +} + +#endif // __AVX__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp16_avx512.cc b/src/ailego/math/euclidean_distance_matrix_fp16_avx512.cc new file mode 100644 index 00000000..244f5db3 --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp16_avx512.cc @@ -0,0 +1,96 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX512FP16__) +//! Squared Euclidean Distance +float SquaredEuclideanDistanceAVX512FP16(const Float16 *lhs, const Float16 *rhs, + size_t size) { + const Float16 *last = lhs + size; + const Float16 *last_aligned = lhs + ((size >> 6) << 6); + + __m512h zmm_sum_0 = _mm512_setzero_ph(); + __m512h zmm_sum_1 = _mm512_setzero_ph(); + + if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m512h zmm_d_0 = + _mm512_sub_ph(_mm512_load_ph(lhs + 0), _mm512_load_ph(rhs + 0)); + __m512h zmm_d_1 = + _mm512_sub_ph(_mm512_load_ph(lhs + 32), _mm512_load_ph(rhs + 32)); + zmm_sum_0 = _mm512_fmadd_ph(zmm_d_0, zmm_d_0, zmm_sum_0); + zmm_sum_1 = _mm512_fmadd_ph(zmm_d_1, zmm_d_1, zmm_sum_1); + } + + if (last >= last_aligned + 32) { + __m512h zmm_d = _mm512_sub_ph(_mm512_load_ph(lhs), _mm512_load_ph(rhs)); + zmm_sum_0 = _mm512_fmadd_ph(zmm_d, zmm_d, zmm_sum_0); + lhs += 32; + rhs += 32; + } + } else { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m512h zmm_d_0 = + _mm512_sub_ph(_mm512_loadu_ph(lhs + 0), _mm512_loadu_ph(rhs + 0)); + __m512h zmm_d_1 = + _mm512_sub_ph(_mm512_loadu_ph(lhs + 32), _mm512_loadu_ph(rhs + 32)); + zmm_sum_0 = _mm512_fmadd_ph(zmm_d_0, zmm_d_0, zmm_sum_0); + zmm_sum_1 = _mm512_fmadd_ph(zmm_d_1, zmm_d_1, zmm_sum_1); + } + + if (last >= last_aligned + 32) { + __m512h zmm_d = _mm512_sub_ph(_mm512_loadu_ph(lhs), _mm512_loadu_ph(rhs)); + zmm_sum_0 = _mm512_fmadd_ph(zmm_d, zmm_d, zmm_sum_0); + lhs += 32; + rhs += 32; + } + } + + zmm_sum_0 = _mm512_add_ph(zmm_sum_0, zmm_sum_1); + if (lhs != last) { + __mmask32 mask = (__mmask32)((1 << (last - lhs)) - 1); + __m512i zmm_undefined = _mm512_undefined_epi32(); + __m512h zmm_undefined_ph = _mm512_undefined_ph(); + __m512h zmm_d = _mm512_mask_sub_ph( + zmm_undefined_ph, mask, + _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, lhs)), + _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, rhs))); + zmm_sum_0 = _mm512_mask3_fmadd_ph(zmm_d, zmm_d, zmm_sum_0, mask); + } + + return HorizontalAdd_FP16_V512(zmm_sum_0); +} +#endif + +#if defined(__AVX512F__) +void SquaredEuclideanDistanceAVX512(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out) { + ACCUM_FP16_1X1_AVX512(lhs, rhs, size, out, 0ull, ) +} + +//! EuclideanDistance +void EuclideanDistanceAVX512(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out) { + ACCUM_FP16_1X1_AVX512(lhs, rhs, size, out, 0ull, std::sqrt) +} + +#endif +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp16_dispatch.cc b/src/ailego/math/euclidean_distance_matrix_fp16_dispatch.cc new file mode 100644 index 00000000..1d08b8bc --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp16_dispatch.cc @@ -0,0 +1,87 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +void SquaredEuclideanDistanceNEON(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out); +void EuclideanDistanceNEON(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out); +#endif + +#if defined(__AVX512FP16__) +float SquaredEuclideanDistanceAVX512FP16(const Float16 *lhs, const Float16 *rhs, + size_t size); +#endif + +#if defined(__AVX512F__) +void SquaredEuclideanDistanceAVX512(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out); + +void EuclideanDistanceAVX512(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out); +#endif + +#if defined(__AVX__) +void SquaredEuclideanDistanceAVX(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out); +void EuclideanDistanceAVX(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out); +#endif + +#if (defined(__F16C__) && defined(__AVX__)) || \ + (defined(__ARM_NEON) && defined(__aarch64__)) +//! Compute the distance between matrix and query (FP16, M=1, N=1) +void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, + float *out) { +#if defined(__ARM_NEON) + SquaredEuclideanDistanceNEON(m, q, dim, out); +#else +#if defined(__AVX512FP16__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { + *out = SquaredEuclideanDistanceAVX512FP16(m, q, dim); + return; + } +#endif +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + SquaredEuclideanDistanceAVX512(m, q, dim, out); + // ACCUM_FP16_1X1_AVX512(m, q, dim, out, 0ull, ) + return; + } +#endif + SquaredEuclideanDistanceAVX(m, q, dim, out); + // ACCUM_FP16_1X1_AVX(m, q, dim, out, 0ull, ) +#endif //__ARM_NEON +} + +//! Compute the distance between matrix and query (FP16, M=1, N=1) +void EuclideanDistanceMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, float *out) { + SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out); + *out = std::sqrt(*out); +} + +#endif + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp16_neon.cc b/src/ailego/math/euclidean_distance_matrix_fp16_neon.cc new file mode 100644 index 00000000..4527056b --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp16_neon.cc @@ -0,0 +1,35 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +void SquaredEuclideanDistanceNEON(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out) { + ACCUM_FP16_1X1_NEON(lhs, rhs, size, out, 0ull, ) +} + +void EuclideanDistanceNEON(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out) { + ACCUM_FP16_1X1_NEON(lhs, rhs, size, out, 0ull, std::sqrt) +} +#endif + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp16_sse.cc b/src/ailego/math/euclidean_distance_matrix_fp16_sse.cc new file mode 100644 index 00000000..6291346c --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp16_sse.cc @@ -0,0 +1,54 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "distance_matrix_accum_fp16.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#define ACCUM_FP32_STEP_SSE SSD_FP32_SSE +#define ACCUM_FP16_STEP_GENERAL SSD_FP16_GENERAL + +//! Calculate sum of squared difference (SSE) +#define SSD_FP32_SSE(xmm_m, xmm_q, xmm_sum) \ + { \ + __m128 xmm_d = _mm_sub_ps(xmm_m, xmm_q); \ + xmm_sum = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum); \ + } + +//! Calculate sum of squared difference (GENERAL) +#define SSD_FP16_GENERAL(m, q, sum) \ + { \ + float x = m - q; \ + sum += (x * x); \ + } + +//! Calculate sum of squared difference (NEON) +#define SSD_FP16_NEON(v_m, v_q, v_sum) \ + { \ + float16x8_t v_d = vsubq_f16(v_m, v_q); \ + v_sum = vfmaq_f16(v_sum, v_d, v_d); \ + } + +//! Calculate sum of squared difference (NEON) +#define SSD_FP32_NEON(v_m, v_q, v_sum) \ + { \ + float32x4_t v_d = vsubq_f32(v_m, v_q); \ + v_sum = vfmaq_f32(v_sum, v_d, v_d); \ + } + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp32.cc b/src/ailego/math/euclidean_distance_matrix_fp32.cc deleted file mode 100644 index 7a024731..00000000 --- a/src/ailego/math/euclidean_distance_matrix_fp32.cc +++ /dev/null @@ -1,930 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "distance_matrix_accum_fp32.i" -#include "euclidean_distance_matrix.h" - -namespace zvec { -namespace ailego { - -#define ACCUM_FP32_STEP_SSE SSD_FP32_SSE -#define ACCUM_FP32_STEP_AVX SSD_FP32_AVX -#define ACCUM_FP32_STEP_AVX512 SSD_FP32_AVX512 -#define ACCUM_FP32_STEP_NEON SSD_FP32_NEON - -//! Calculate sum of squared difference (GENERAL) -#define SSD_FP32_GENERAL(m, q, sum) \ - { \ - float x = m - q; \ - sum += (x * x); \ - } - -//! Calculate sum of squared difference (SSE) -#define SSD_FP32_SSE(xmm_m, xmm_q, xmm_sum) \ - { \ - __m128 xmm_d = _mm_sub_ps(xmm_m, xmm_q); \ - xmm_sum = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum); \ - } - -//! Calculate sum of squared difference (AVX) -#define SSD_FP32_AVX(ymm_m, ymm_q, ymm_sum) \ - { \ - __m256 ymm_d = _mm256_sub_ps(ymm_m, ymm_q); \ - ymm_sum = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum); \ - } - -//! Calculate sum of squared difference (AVX512) -#define SSD_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ - { \ - __m512 zmm_d = _mm512_sub_ps(zmm_m, zmm_q); \ - zmm_sum = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum); \ - } - -//! Calculate sum of squared difference (NEON) -#define SSD_FP32_NEON(v_m, v_q, v_sum) \ - { \ - float32x4_t v_d = vsubq_f32(v_m, v_q); \ - v_sum = vfmaq_f32(v_sum, v_d, v_d); \ - } - -#if defined(__ARM_NEON) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceNEON(const float *lhs, - const float *rhs, - size_t size) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 3) << 3); - - float32x4_t v_sum_0 = vdupq_n_f32(0); - float32x4_t v_sum_1 = vdupq_n_f32(0); - - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - float32x4_t v_d_0 = vsubq_f32(vld1q_f32(lhs + 0), vld1q_f32(rhs + 0)); - float32x4_t v_d_1 = vsubq_f32(vld1q_f32(lhs + 4), vld1q_f32(rhs + 4)); - v_sum_0 = vfmaq_f32(v_sum_0, v_d_0, v_d_0); - v_sum_1 = vfmaq_f32(v_sum_1, v_d_1, v_d_1); - } - if (last >= last_aligned + 4) { - float32x4_t v_d = vsubq_f32(vld1q_f32(lhs), vld1q_f32(rhs)); - v_sum_0 = vfmaq_f32(v_sum_0, v_d, v_d); - lhs += 4; - rhs += 4; - } - - float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); - switch (last - lhs) { - case 3: - SSD_FP32_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - SSD_FP32_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - SSD_FP32_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __ARM_NEON - -#if defined(__SSE__) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceSSE(const float *lhs, - const float *rhs, size_t size) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 3) << 3); - - __m128 xmm_sum_0 = _mm_setzero_ps(); - __m128 xmm_sum_1 = _mm_setzero_ps(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - __m128 xmm_d_0 = _mm_sub_ps(_mm_load_ps(lhs + 0), _mm_load_ps(rhs + 0)); - __m128 xmm_d_1 = _mm_sub_ps(_mm_load_ps(lhs + 4), _mm_load_ps(rhs + 4)); - xmm_sum_0 = _mm_fmadd_ps(xmm_d_0, xmm_d_0, xmm_sum_0); - xmm_sum_1 = _mm_fmadd_ps(xmm_d_1, xmm_d_1, xmm_sum_1); - } - - if (last >= last_aligned + 4) { - __m128 xmm_d = _mm_sub_ps(_mm_load_ps(lhs), _mm_load_ps(rhs)); - xmm_sum_0 = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum_0); - lhs += 4; - rhs += 4; - } - } else { - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - __m128 xmm_d_0 = _mm_sub_ps(_mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0)); - __m128 xmm_d_1 = _mm_sub_ps(_mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4)); - xmm_sum_0 = _mm_fmadd_ps(xmm_d_0, xmm_d_0, xmm_sum_0); - xmm_sum_1 = _mm_fmadd_ps(xmm_d_1, xmm_d_1, xmm_sum_1); - } - - if (last >= last_aligned + 4) { - __m128 xmm_d = _mm_sub_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs)); - xmm_sum_0 = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum_0); - lhs += 4; - rhs += 4; - } - } - float result = HorizontalAdd_FP32_V128(_mm_add_ps(xmm_sum_0, xmm_sum_1)); - - switch (last - lhs) { - case 3: - SSD_FP32_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - SSD_FP32_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - SSD_FP32_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __SSE__ - -#if defined(__AVX__) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceAVX(const float *lhs, - const float *rhs, size_t size) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 4) << 4); - - __m256 ymm_sum_0 = _mm256_setzero_ps(); - __m256 ymm_sum_1 = _mm256_setzero_ps(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m256 ymm_d_0 = - _mm256_sub_ps(_mm256_load_ps(lhs + 0), _mm256_load_ps(rhs + 0)); - __m256 ymm_d_1 = - _mm256_sub_ps(_mm256_load_ps(lhs + 8), _mm256_load_ps(rhs + 8)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_d_0, ymm_d_0, ymm_sum_0); - ymm_sum_1 = _mm256_fmadd_ps(ymm_d_1, ymm_d_1, ymm_sum_1); - } - - if (last >= last_aligned + 8) { - __m256 ymm_d = _mm256_sub_ps(_mm256_load_ps(lhs), _mm256_load_ps(rhs)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum_0); - lhs += 8; - rhs += 8; - } - } else { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m256 ymm_d_0 = - _mm256_sub_ps(_mm256_loadu_ps(lhs + 0), _mm256_loadu_ps(rhs + 0)); - __m256 ymm_d_1 = - _mm256_sub_ps(_mm256_loadu_ps(lhs + 8), _mm256_loadu_ps(rhs + 8)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_d_0, ymm_d_0, ymm_sum_0); - ymm_sum_1 = _mm256_fmadd_ps(ymm_d_1, ymm_d_1, ymm_sum_1); - } - - if (last >= last_aligned + 8) { - __m256 ymm_d = _mm256_sub_ps(_mm256_loadu_ps(lhs), _mm256_loadu_ps(rhs)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum_0); - lhs += 8; - rhs += 8; - } - } - float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); - - switch (last - lhs) { - case 7: - SSD_FP32_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - SSD_FP32_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - SSD_FP32_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - SSD_FP32_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - SSD_FP32_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - SSD_FP32_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - SSD_FP32_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __AVX__ - -#if defined(__AVX512F__) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceAVX512(const float *lhs, - const float *rhs, - size_t size) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 5) << 5); - - __m512 zmm_sum_0 = _mm512_setzero_ps(); - __m512 zmm_sum_1 = _mm512_setzero_ps(); - - if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m512 zmm_d_0 = - _mm512_sub_ps(_mm512_load_ps(lhs + 0), _mm512_load_ps(rhs + 0)); - __m512 zmm_d_1 = - _mm512_sub_ps(_mm512_load_ps(lhs + 16), _mm512_load_ps(rhs + 16)); - zmm_sum_0 = _mm512_fmadd_ps(zmm_d_0, zmm_d_0, zmm_sum_0); - zmm_sum_1 = _mm512_fmadd_ps(zmm_d_1, zmm_d_1, zmm_sum_1); - } - - if (last >= last_aligned + 16) { - __m512 zmm_d = _mm512_sub_ps(_mm512_load_ps(lhs), _mm512_load_ps(rhs)); - zmm_sum_0 = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum_0); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m512 zmm_d_0 = - _mm512_sub_ps(_mm512_loadu_ps(lhs + 0), _mm512_loadu_ps(rhs + 0)); - __m512 zmm_d_1 = - _mm512_sub_ps(_mm512_loadu_ps(lhs + 16), _mm512_loadu_ps(rhs + 16)); - zmm_sum_0 = _mm512_fmadd_ps(zmm_d_0, zmm_d_0, zmm_sum_0); - zmm_sum_1 = _mm512_fmadd_ps(zmm_d_1, zmm_d_1, zmm_sum_1); - } - - if (last >= last_aligned + 16) { - __m512 zmm_d = _mm512_sub_ps(_mm512_loadu_ps(lhs), _mm512_loadu_ps(rhs)); - zmm_sum_0 = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum_0); - lhs += 16; - rhs += 16; - } - } - - zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1); - if (lhs != last) { - __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1); - __m512 zmm_undefined = _mm512_undefined_ps(); - __m512 zmm_d = _mm512_mask_sub_ps( - zmm_undefined, mask, _mm512_mask_loadu_ps(zmm_undefined, mask, lhs), - _mm512_mask_loadu_ps(zmm_undefined, mask, rhs)); - zmm_sum_0 = _mm512_mask3_fmadd_ps(zmm_d, zmm_d, zmm_sum_0, mask); - } - return HorizontalAdd_FP32_V512(zmm_sum_0); -} -#endif - -#if defined(__SSE__) || defined(__ARM_NEON) -//! Compute the distance between matrix and query (FP32, M=1, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - *out = SquaredEuclideanDistanceNEON(m, q, dim); -#else -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - if (dim > 15) { - *out = SquaredEuclideanDistanceAVX512(m, q, dim); - return; - } - } -#endif // __AVX512F__ -#if defined(__AVX__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { - if (dim > 7) { - *out = SquaredEuclideanDistanceAVX(m, q, dim); - return; - } - } -#endif // __AVX__ - *out = SquaredEuclideanDistanceSSE(m, q, dim); -#endif // __ARM_NEON -} - -//! Compute the distance between matrix and query (FP32, M=2, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_2X1_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_2X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_2X1_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=2, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_2X2_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_2X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_2X2_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X1_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_4X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_4X1_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X2_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_4X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_4X2_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X4_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_4X4_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_4X4_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X1_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_8X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_8X1_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X2_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_8X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_8X2_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X4_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_8X4_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_8X4_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X8_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_8X8_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_8X8_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=16, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X1_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X1_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X1_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X2_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X2_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X2_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X4_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X4_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X4_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X4_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X8_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X8_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X8_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X8_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=16) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X16_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X16_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X16_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X16_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X1_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X1_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X1_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X2_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X2_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X2_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X4_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X4_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X4_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X4_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X8_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X8_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X8_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X8_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=16) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X16_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X16_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X16_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X16_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=32) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X32_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X32_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X32_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X32_SSE(m, q, dim, out, ) -#endif -} -#endif // __SSE__ || __ARM_NEON - -#if defined(__SSE__) || (defined(__ARM_NEON) && defined(__aarch64__)) -//! Compute the distance between matrix and query (FP32, M=1, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - *out = std::sqrt(SquaredEuclideanDistanceNEON(m, q, dim)); -#else -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - if (dim > 15) { - *out = std::sqrt(SquaredEuclideanDistanceAVX512(m, q, dim)); - return; - } - } -#endif // __AVX512F__ - -#if defined(__AVX__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { - if (dim > 7) { - *out = std::sqrt(SquaredEuclideanDistanceAVX(m, q, dim)); - return; - } - } -#endif // __AVX__ - *out = std::sqrt(SquaredEuclideanDistanceSSE(m, q, dim)); -#endif // __ARM_NEON -} - -//! Compute the distance between matrix and query (FP32, M=2, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_2X1_NEON(m, q, dim, out, vsqrt_f32) -#elif defined(__AVX__) - ACCUM_FP32_2X1_AVX(m, q, dim, out, _mm_sqrt_ps) -#else - ACCUM_FP32_2X1_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=2, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_2X2_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX__) - ACCUM_FP32_2X2_AVX(m, q, dim, out, _mm_sqrt_ps) -#else - ACCUM_FP32_2X2_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X1_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX__) - ACCUM_FP32_4X1_AVX(m, q, dim, out, _mm_sqrt_ps) -#else - ACCUM_FP32_4X1_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X2_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX__) - ACCUM_FP32_4X2_AVX(m, q, dim, out, _mm_sqrt_ps) -#else - ACCUM_FP32_4X2_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X4_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX__) - ACCUM_FP32_4X4_AVX(m, q, dim, out, _mm_sqrt_ps) -#else - ACCUM_FP32_4X4_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X1_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX__) - ACCUM_FP32_8X1_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_8X1_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X2_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX__) - ACCUM_FP32_8X2_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_8X2_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X4_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX__) - ACCUM_FP32_8X4_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_8X4_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X8_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX__) - ACCUM_FP32_8X8_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_8X8_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=16, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X1_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X1_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_16X1_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_16X1_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X2_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X2_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_16X2_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_16X2_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X4_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X4_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_16X4_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_16X4_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X8_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X8_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_16X8_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_16X8_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=16) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X16_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X16_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_16X16_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_16X16_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X1_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X1_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_32X1_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_32X1_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X2_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X2_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_32X2_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_32X2_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X4_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X4_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_32X4_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_32X4_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X8_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X8_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_32X8_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_32X8_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=16) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X16_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X16_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_32X16_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_32X16_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=32) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X32_NEON(m, q, dim, out, vsqrtq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X32_AVX512(m, q, dim, out, _mm512_sqrt_ps) -#elif defined(__AVX__) - ACCUM_FP32_32X32_AVX(m, q, dim, out, _mm256_sqrt_ps) -#else - ACCUM_FP32_32X32_SSE(m, q, dim, out, _mm_sqrt_ps) -#endif -} -#endif // __SSE__ || __ARM_NEON && __aarch64__ - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp32_avx.cc b/src/ailego/math/euclidean_distance_matrix_fp32_avx.cc new file mode 100644 index 00000000..3fdcad5a --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp32_avx.cc @@ -0,0 +1,94 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX__) +float SquaredEuclideanDistanceAVX(const float *lhs, const float *rhs, + size_t size) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 4) << 4); + + __m256 ymm_sum_0 = _mm256_setzero_ps(); + __m256 ymm_sum_1 = _mm256_setzero_ps(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m256 ymm_d_0 = + _mm256_sub_ps(_mm256_load_ps(lhs + 0), _mm256_load_ps(rhs + 0)); + __m256 ymm_d_1 = + _mm256_sub_ps(_mm256_load_ps(lhs + 8), _mm256_load_ps(rhs + 8)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_d_0, ymm_d_0, ymm_sum_0); + ymm_sum_1 = _mm256_fmadd_ps(ymm_d_1, ymm_d_1, ymm_sum_1); + } + + if (last >= last_aligned + 8) { + __m256 ymm_d = _mm256_sub_ps(_mm256_load_ps(lhs), _mm256_load_ps(rhs)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum_0); + lhs += 8; + rhs += 8; + } + } else { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m256 ymm_d_0 = + _mm256_sub_ps(_mm256_loadu_ps(lhs + 0), _mm256_loadu_ps(rhs + 0)); + __m256 ymm_d_1 = + _mm256_sub_ps(_mm256_loadu_ps(lhs + 8), _mm256_loadu_ps(rhs + 8)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_d_0, ymm_d_0, ymm_sum_0); + ymm_sum_1 = _mm256_fmadd_ps(ymm_d_1, ymm_d_1, ymm_sum_1); + } + + if (last >= last_aligned + 8) { + __m256 ymm_d = _mm256_sub_ps(_mm256_loadu_ps(lhs), _mm256_loadu_ps(rhs)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum_0); + lhs += 8; + rhs += 8; + } + } + float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); + + switch (last - lhs) { + case 7: + SSD_FP32_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + SSD_FP32_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + SSD_FP32_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + SSD_FP32_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + SSD_FP32_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + SSD_FP32_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + SSD_FP32_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +#endif // __AVX__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp32_avx512.cc b/src/ailego/math/euclidean_distance_matrix_fp32_avx512.cc new file mode 100644 index 00000000..f9a82506 --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp32_avx512.cc @@ -0,0 +1,81 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX512F__) +//! Squared Euclidean Distance +float SquaredEuclideanDistanceAVX512(const float *lhs, const float *rhs, + size_t size) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 5) << 5); + + __m512 zmm_sum_0 = _mm512_setzero_ps(); + __m512 zmm_sum_1 = _mm512_setzero_ps(); + + if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m512 zmm_d_0 = + _mm512_sub_ps(_mm512_load_ps(lhs + 0), _mm512_load_ps(rhs + 0)); + __m512 zmm_d_1 = + _mm512_sub_ps(_mm512_load_ps(lhs + 16), _mm512_load_ps(rhs + 16)); + zmm_sum_0 = _mm512_fmadd_ps(zmm_d_0, zmm_d_0, zmm_sum_0); + zmm_sum_1 = _mm512_fmadd_ps(zmm_d_1, zmm_d_1, zmm_sum_1); + } + + if (last >= last_aligned + 16) { + __m512 zmm_d = _mm512_sub_ps(_mm512_load_ps(lhs), _mm512_load_ps(rhs)); + zmm_sum_0 = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum_0); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m512 zmm_d_0 = + _mm512_sub_ps(_mm512_loadu_ps(lhs + 0), _mm512_loadu_ps(rhs + 0)); + __m512 zmm_d_1 = + _mm512_sub_ps(_mm512_loadu_ps(lhs + 16), _mm512_loadu_ps(rhs + 16)); + zmm_sum_0 = _mm512_fmadd_ps(zmm_d_0, zmm_d_0, zmm_sum_0); + zmm_sum_1 = _mm512_fmadd_ps(zmm_d_1, zmm_d_1, zmm_sum_1); + } + + if (last >= last_aligned + 16) { + __m512 zmm_d = _mm512_sub_ps(_mm512_loadu_ps(lhs), _mm512_loadu_ps(rhs)); + zmm_sum_0 = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum_0); + lhs += 16; + rhs += 16; + } + } + + zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1); + if (lhs != last) { + __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1); + __m512 zmm_undefined = _mm512_undefined_ps(); + __m512 zmm_d = _mm512_mask_sub_ps( + zmm_undefined, mask, _mm512_mask_loadu_ps(zmm_undefined, mask, lhs), + _mm512_mask_loadu_ps(zmm_undefined, mask, rhs)); + zmm_sum_0 = _mm512_mask3_fmadd_ps(zmm_d, zmm_d, zmm_sum_0, mask); + } + return HorizontalAdd_FP32_V512(zmm_sum_0); +} + +#endif + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp32_dispatch.cc b/src/ailego/math/euclidean_distance_matrix_fp32_dispatch.cc new file mode 100644 index 00000000..08d31c6a --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp32_dispatch.cc @@ -0,0 +1,92 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +void SquaredEuclideanDistanceNEON(const float *lhs, const float *rhs, + size_t size, float *out); +#endif + +#if defined(__AVX512F__) +float SquaredEuclideanDistanceAVX512(const float *lhs, const float *rhs, + size_t size); +float EuclideanDistanceAVX512(const float *lhs, const float *rhs, size_t size); +#endif + +#if defined(__AVX__) +float SquaredEuclideanDistanceAVX(const float *lhs, const float *rhs, + size_t size); +float EuclideanDistanceAVX(const float *lhs, const float *rhs, size_t size); +#endif + +#if defined(__SSE__) +float SquaredEuclideanDistanceSSE(const float *lhs, const float *rhs, + size_t size); +float EuclideanDistanceSSE(const float *lhs, const float *rhs, size_t size); +#endif + +//----------------------------------------------------------- +// SquaredEuclideanDistance +//----------------------------------------------------------- +#if defined(__SSE__) || defined(__ARM_NEON) +//! Compute the distance between matrix and query (FP32, M=1, N=1) +void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, + float *out) { +#if defined(__ARM_NEON) + SquaredEuclideanDistanceNEON(m, q, dim, out); +#else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + if (dim > 15) { + *out = SquaredEuclideanDistanceAVX512(m, q, dim); + return; + } + } +#endif // __AVX512F__ +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + if (dim > 7) { + *out = SquaredEuclideanDistanceAVX(m, q, dim); + return; + } + } +#endif // __AVX__ + *out = SquaredEuclideanDistanceSSE(m, q, dim); +#endif // __ARM_NEON +} +#endif // __SSE__ || __ARM_NEON + + +//----------------------------------------------------------- +// EuclideanDistance +//----------------------------------------------------------- +#if defined(__SSE__) || (defined(__ARM_NEON) && defined(__aarch64__)) +//! Compute the distance between matrix and query (FP32, M=1, N=1) +void EuclideanDistanceMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, float *out) { + SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out); + *out = std::sqrt(*out); +} +#endif // __SSE__ || __ARM_NEON && __aarch64__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp32_neon.cc b/src/ailego/math/euclidean_distance_matrix_fp32_neon.cc new file mode 100644 index 00000000..3827fafe --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp32_neon.cc @@ -0,0 +1,62 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +//! Squared Euclidean Distance +void SquaredEuclideanDistanceNEON(const float *lhs, const float *rhs, + size_t size, float *out) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 3) << 3); + + float32x4_t v_sum_0 = vdupq_n_f32(0); + float32x4_t v_sum_1 = vdupq_n_f32(0); + + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + float32x4_t v_d_0 = vsubq_f32(vld1q_f32(lhs + 0), vld1q_f32(rhs + 0)); + float32x4_t v_d_1 = vsubq_f32(vld1q_f32(lhs + 4), vld1q_f32(rhs + 4)); + v_sum_0 = vfmaq_f32(v_sum_0, v_d_0, v_d_0); + v_sum_1 = vfmaq_f32(v_sum_1, v_d_1, v_d_1); + } + if (last >= last_aligned + 4) { + float32x4_t v_d = vsubq_f32(vld1q_f32(lhs), vld1q_f32(rhs)); + v_sum_0 = vfmaq_f32(v_sum_0, v_d, v_d); + lhs += 4; + rhs += 4; + } + + float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); + switch (last - lhs) { + case 3: + SSD_FP32_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + SSD_FP32_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + SSD_FP32_GENERAL(lhs[0], rhs[0], result) + } + *out = result; +} + +#endif // __ARM_NEON + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_fp32_sse.cc b/src/ailego/math/euclidean_distance_matrix_fp32_sse.cc new file mode 100644 index 00000000..a4cf588e --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_fp32_sse.cc @@ -0,0 +1,78 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE__) +float SquaredEuclideanDistanceSSE(const float *lhs, const float *rhs, + size_t size) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 3) << 3); + + __m128 xmm_sum_0 = _mm_setzero_ps(); + __m128 xmm_sum_1 = _mm_setzero_ps(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + __m128 xmm_d_0 = _mm_sub_ps(_mm_load_ps(lhs + 0), _mm_load_ps(rhs + 0)); + __m128 xmm_d_1 = _mm_sub_ps(_mm_load_ps(lhs + 4), _mm_load_ps(rhs + 4)); + xmm_sum_0 = _mm_fmadd_ps(xmm_d_0, xmm_d_0, xmm_sum_0); + xmm_sum_1 = _mm_fmadd_ps(xmm_d_1, xmm_d_1, xmm_sum_1); + } + + if (last >= last_aligned + 4) { + __m128 xmm_d = _mm_sub_ps(_mm_load_ps(lhs), _mm_load_ps(rhs)); + xmm_sum_0 = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum_0); + lhs += 4; + rhs += 4; + } + } else { + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + __m128 xmm_d_0 = _mm_sub_ps(_mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0)); + __m128 xmm_d_1 = _mm_sub_ps(_mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4)); + xmm_sum_0 = _mm_fmadd_ps(xmm_d_0, xmm_d_0, xmm_sum_0); + xmm_sum_1 = _mm_fmadd_ps(xmm_d_1, xmm_d_1, xmm_sum_1); + } + + if (last >= last_aligned + 4) { + __m128 xmm_d = _mm_sub_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs)); + xmm_sum_0 = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum_0); + lhs += 4; + rhs += 4; + } + } + float result = HorizontalAdd_FP32_V128(_mm_add_ps(xmm_sum_0, xmm_sum_1)); + + switch (last - lhs) { + case 3: + SSD_FP32_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + SSD_FP32_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + SSD_FP32_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +#endif // __SSE__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_int4.cc b/src/ailego/math/euclidean_distance_matrix_int4.cc deleted file mode 100644 index bef43213..00000000 --- a/src/ailego/math/euclidean_distance_matrix_int4.cc +++ /dev/null @@ -1,801 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "distance_matrix_accum_int4.i" -#include "euclidean_distance_matrix.h" - -namespace zvec { -namespace ailego { - -#define ACCUM_INT4_STEP_SSE SSD_INT4_SSE -#define ACCUM_INT4_STEP_AVX SSD_INT4_AVX - -#if defined(__SSE4_1__) -static const __m128i MASK_INT4_SSE = _mm_set1_epi32(0xf0f0f0f0); -static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); -#endif // __SSE4_1__ - -#if defined(__AVX2__) -static const __m256i MASK_INT4_AVX = _mm256_set1_epi32(0xf0f0f0f0); -static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); -#endif // __AVX2__ - -//! Calculate sum of squared difference (GENERAL) -#define SSD_INT4_GENERAL(m, q, sum) \ - sum += Int4SquaredDiffTable[(((m) << 4) & 0xf0) | (((q) >> 0) & 0xf)] + \ - Int4SquaredDiffTable[(((m) >> 0) & 0xf0) | (((q) >> 4) & 0xf)]; - -//! Calculate sum of squared difference (SSE) -#define SSD_INT4_SSE(xmm_m, xmm_q, xmm_sum) \ - { \ - __m128i xmm_lhs = \ - _mm_and_si128(_mm_slli_epi32((xmm_m), 4), MASK_INT4_SSE); \ - __m128i xmm_rhs = \ - _mm_and_si128(_mm_slli_epi32((xmm_q), 4), MASK_INT4_SSE); \ - xmm_lhs = _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), \ - _mm_min_epi8(xmm_lhs, xmm_rhs)), \ - 4); \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs, xmm_lhs), ONES_INT16_SSE), \ - xmm_sum); \ - xmm_lhs = _mm_and_si128((xmm_m), MASK_INT4_SSE); \ - xmm_rhs = _mm_and_si128((xmm_q), MASK_INT4_SSE); \ - xmm_lhs = _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), \ - _mm_min_epi8(xmm_lhs, xmm_rhs)), \ - 4); \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs, xmm_lhs), ONES_INT16_SSE), \ - xmm_sum); \ - } - -//! Calculate sum of squared difference (AVX) -#define SSD_INT4_AVX(ymm_m, ymm_q, ymm_sum) \ - { \ - __m256i ymm_lhs = \ - _mm256_and_si256(_mm256_slli_epi32((ymm_m), 4), MASK_INT4_AVX); \ - __m256i ymm_rhs = \ - _mm256_and_si256(_mm256_slli_epi32((ymm_q), 4), MASK_INT4_AVX); \ - ymm_lhs = \ - _mm256_srli_epi32(_mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), \ - _mm256_min_epi8(ymm_lhs, ymm_rhs)), \ - 4); \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs, ymm_lhs), \ - ONES_INT16_AVX), \ - ymm_sum); \ - ymm_lhs = _mm256_and_si256((ymm_m), MASK_INT4_AVX); \ - ymm_rhs = _mm256_and_si256((ymm_q), MASK_INT4_AVX); \ - ymm_lhs = \ - _mm256_srli_epi32(_mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), \ - _mm256_min_epi8(ymm_lhs, ymm_rhs)), \ - 4); \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs, ymm_lhs), \ - ONES_INT16_AVX), \ - ymm_sum); \ - } - -//! Compute the distance between matrix and query -#define SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) \ - { \ - __m128i xmm_lhs_0 = \ - _mm_and_si128(_mm_slli_epi32((xmm_lhs), 4), MASK_INT4_SSE); \ - __m128i xmm_rhs_0 = \ - _mm_and_si128(_mm_slli_epi32((xmm_rhs), 4), MASK_INT4_SSE); \ - __m128i xmm_lhs_1 = _mm_and_si128((xmm_lhs), MASK_INT4_SSE); \ - __m128i xmm_rhs_1 = _mm_and_si128((xmm_rhs), MASK_INT4_SSE); \ - xmm_lhs_0 = \ - _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), \ - _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)), \ - 4); \ - xmm_rhs_0 = \ - _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), \ - _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)), \ - 4); \ - xmm_lhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs_0, xmm_lhs_0), \ - ONES_INT16_SSE); \ - xmm_rhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_rhs_0), \ - ONES_INT16_SSE); \ - xmm_sum = _mm_add_epi32(_mm_add_epi32(xmm_lhs_0, xmm_rhs_0), xmm_sum); \ - } - -//! Compute the distance between matrix and query -#define SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) \ - { \ - __m256i ymm_lhs_0 = \ - _mm256_and_si256(_mm256_slli_epi32((ymm_lhs), 4), MASK_INT4_AVX); \ - __m256i ymm_rhs_0 = \ - _mm256_and_si256(_mm256_slli_epi32((ymm_rhs), 4), MASK_INT4_AVX); \ - __m256i ymm_lhs_1 = _mm256_and_si256((ymm_lhs), MASK_INT4_AVX); \ - __m256i ymm_rhs_1 = _mm256_and_si256((ymm_rhs), MASK_INT4_AVX); \ - ymm_lhs_0 = _mm256_srli_epi32( \ - _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0), \ - _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0)), \ - 4); \ - ymm_rhs_0 = _mm256_srli_epi32( \ - _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1), \ - _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1)), \ - 4); \ - ymm_lhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs_0, ymm_lhs_0), \ - ONES_INT16_AVX); \ - ymm_rhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_rhs_0), \ - ONES_INT16_AVX); \ - ymm_sum = \ - _mm256_add_epi32(_mm256_add_epi32(ymm_lhs_0, ymm_rhs_0), ymm_sum); \ - } - -//! Compute the square root of value (SSE) -#define SQRT_FP32_SSE(v, ...) _mm_sqrt_ps(_mm_cvtepi32_ps(v)) - -//! Compute the square root of value (AVX) -#define SQRT_FP32_AVX(v, ...) _mm256_sqrt_ps(_mm256_cvtepi32_ps(v)) - -//! Compute the square root of value (AVX512) -#define SQRT_FP32_AVX512(v, ...) _mm512_sqrt_ps(_mm512_cvtepi32_ps(v)) - -#if defined(__SSE4_1__) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceSSE(const uint8_t *lhs, - const uint8_t *rhs, - size_t size) { - const uint8_t *last = lhs + size; - const uint8_t *last_aligned = lhs + ((size >> 4) << 4); - - __m128i xmm_sum = _mm_setzero_si128(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)(lhs)); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)(rhs)); - SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) - } - } else { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)(lhs)); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)(rhs)); - SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) - } - } - float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum)); - - switch (last - lhs) { - case 15: - SSD_INT4_GENERAL(lhs[14], rhs[14], result) - /* FALLTHRU */ - case 14: - SSD_INT4_GENERAL(lhs[13], rhs[13], result) - /* FALLTHRU */ - case 13: - SSD_INT4_GENERAL(lhs[12], rhs[12], result) - /* FALLTHRU */ - case 12: - SSD_INT4_GENERAL(lhs[11], rhs[11], result) - /* FALLTHRU */ - case 11: - SSD_INT4_GENERAL(lhs[10], rhs[10], result) - /* FALLTHRU */ - case 10: - SSD_INT4_GENERAL(lhs[9], rhs[9], result) - /* FALLTHRU */ - case 9: - SSD_INT4_GENERAL(lhs[8], rhs[8], result) - /* FALLTHRU */ - case 8: - SSD_INT4_GENERAL(lhs[7], rhs[7], result) - /* FALLTHRU */ - case 7: - SSD_INT4_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - SSD_INT4_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - SSD_INT4_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - SSD_INT4_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - SSD_INT4_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - SSD_INT4_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - SSD_INT4_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __SSE4_1__ - -#if defined(__AVX2__) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceAVX(const uint8_t *lhs, - const uint8_t *rhs, - size_t size) { - const uint8_t *last = lhs + size; - const uint8_t *last_aligned = lhs + ((size >> 5) << 5); - - __m256i ymm_sum = _mm256_setzero_si256(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs)); - __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs)); - SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) - } - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); - __m128i xmm_sum = _mm_setzero_si128(); - SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) - ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), - ymm_sum); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs)); - __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs)); - SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) - } - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - __m128i xmm_sum = _mm_setzero_si128(); - SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) - ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), - ymm_sum); - lhs += 16; - rhs += 16; - } - } - float result = static_cast(HorizontalAdd_INT32_V256(ymm_sum)); - - switch (last - lhs) { - case 15: - SSD_INT4_GENERAL(lhs[14], rhs[14], result) - /* FALLTHRU */ - case 14: - SSD_INT4_GENERAL(lhs[13], rhs[13], result) - /* FALLTHRU */ - case 13: - SSD_INT4_GENERAL(lhs[12], rhs[12], result) - /* FALLTHRU */ - case 12: - SSD_INT4_GENERAL(lhs[11], rhs[11], result) - /* FALLTHRU */ - case 11: - SSD_INT4_GENERAL(lhs[10], rhs[10], result) - /* FALLTHRU */ - case 10: - SSD_INT4_GENERAL(lhs[9], rhs[9], result) - /* FALLTHRU */ - case 9: - SSD_INT4_GENERAL(lhs[8], rhs[8], result) - /* FALLTHRU */ - case 8: - SSD_INT4_GENERAL(lhs[7], rhs[7], result) - /* FALLTHRU */ - case 7: - SSD_INT4_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - SSD_INT4_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - SSD_INT4_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - SSD_INT4_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - SSD_INT4_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - SSD_INT4_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - SSD_INT4_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __AVX2__ - -#if defined(__SSE4_1__) -//! Compute the distance between matrix and query (INT4, M=1, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - if (dim > 63) { - *out = SquaredEuclideanDistanceAVX(m, q, dim >> 1); - return; - } -#endif // __AVX2__ - *out = SquaredEuclideanDistanceSSE(m, q, dim >> 1); -} - -//! Compute the distance between matrix and query (INT4, M=2, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_2X1_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_2X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=2, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_2X2_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_2X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X1_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_4X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X2_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_4X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X4_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_4X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_8X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_8X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_8X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_8X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=16) -void SquaredEuclideanDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X16_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X16_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=16) -void SquaredEuclideanDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X16_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X16_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=32) -void SquaredEuclideanDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X32_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X32_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=1, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - if (dim > 63) { - *out = std::sqrt(SquaredEuclideanDistanceAVX(m, q, dim >> 1)); - return; - } -#endif // __AVX2__ - *out = std::sqrt(SquaredEuclideanDistanceSSE(m, q, dim >> 1)); -} - -//! Compute the distance between matrix and query (INT4, M=2, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_2X1_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT4_2X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=2, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_2X2_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT4_2X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X1_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT4_4X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X2_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT4_4X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X4_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT4_4X4_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X1_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_8X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X2_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_8X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X4_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_8X4_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X8_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_8X8_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X1_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_16X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X2_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_16X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X4_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_16X4_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X8_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_16X8_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=16) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X16_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_16X16_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X1_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_32X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X2_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_32X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X4_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_32X4_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X8_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_32X8_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=16) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X16_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_32X16_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=32) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X32_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT4_32X32_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} -#endif // __SSE4_1__ - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_int4_avx2.cc b/src/ailego/math/euclidean_distance_matrix_int4_avx2.cc new file mode 100644 index 00000000..09232492 --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_int4_avx2.cc @@ -0,0 +1,118 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int4.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +//! Squared Euclidean Distance +float SquaredEuclideanDistanceAVX2(const uint8_t *lhs, const uint8_t *rhs, + size_t size) { + const uint8_t *last = lhs + size; + const uint8_t *last_aligned = lhs + ((size >> 5) << 5); + + __m256i ymm_sum = _mm256_setzero_si256(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs)); + __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs)); + SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) + } + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + __m128i xmm_sum = _mm_setzero_si128(); + SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) + ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), + ymm_sum); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs)); + __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs)); + SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) + } + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + __m128i xmm_sum = _mm_setzero_si128(); + SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) + ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), + ymm_sum); + lhs += 16; + rhs += 16; + } + } + float result = static_cast(HorizontalAdd_INT32_V256(ymm_sum)); + + switch (last - lhs) { + case 15: + SSD_INT4_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + SSD_INT4_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + SSD_INT4_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + SSD_INT4_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + SSD_INT4_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + SSD_INT4_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + SSD_INT4_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + SSD_INT4_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + SSD_INT4_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + SSD_INT4_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + SSD_INT4_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + SSD_INT4_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + SSD_INT4_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + SSD_INT4_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + SSD_INT4_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +#endif // __AVX2__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_int4_dispatch.cc b/src/ailego/math/euclidean_distance_matrix_int4_dispatch.cc new file mode 100644 index 00000000..beeb7a2c --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_int4_dispatch.cc @@ -0,0 +1,60 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +float SquaredEuclideanDistanceAVX2(const uint8_t *lhs, const uint8_t *rhs, + size_t size); +float EuclideanDistanceAVX2(const uint8_t *lhs, const uint8_t *rhs, + size_t size); +#endif + +#if defined(__SSE4_1__) +float SquaredEuclideanDistanceSSE(const uint8_t *lhs, const uint8_t *rhs, + size_t size); +float EuclideanDistanceSSE(const uint8_t *lhs, const uint8_t *rhs, size_t size); +#endif + +#if defined(__SSE4_1__) +//! Compute the distance between matrix and query (INT4, M=1, N=1) +void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, + float *out) { +#if defined(__AVX2__) + if (dim > 63) { + *out = SquaredEuclideanDistanceAVX2(m, q, dim >> 1); + return; + } +#endif // __AVX2__ + *out = SquaredEuclideanDistanceSSE(m, q, dim >> 1); +} + +//! Compute the distance between matrix and query (INT4, M=1, N=1) +void EuclideanDistanceMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, float *out) { + SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out); + *out = std::sqrt(*out); +} + +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_int4_sse.cc b/src/ailego/math/euclidean_distance_matrix_int4_sse.cc new file mode 100644 index 00000000..63e10da5 --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_int4_sse.cc @@ -0,0 +1,98 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int4.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE4_1__) +//! Squared Euclidean Distance +float SquaredEuclideanDistanceSSE(const uint8_t *lhs, const uint8_t *rhs, + size_t size) { + const uint8_t *last = lhs + size; + const uint8_t *last_aligned = lhs + ((size >> 4) << 4); + + __m128i xmm_sum = _mm_setzero_si128(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)(lhs)); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)(rhs)); + SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) + } + } else { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)(lhs)); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)(rhs)); + SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) + } + } + float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum)); + + switch (last - lhs) { + case 15: + SSD_INT4_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + SSD_INT4_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + SSD_INT4_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + SSD_INT4_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + SSD_INT4_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + SSD_INT4_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + SSD_INT4_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + SSD_INT4_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + SSD_INT4_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + SSD_INT4_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + SSD_INT4_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + SSD_INT4_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + SSD_INT4_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + SSD_INT4_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + SSD_INT4_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_int8.cc b/src/ailego/math/euclidean_distance_matrix_int8.cc deleted file mode 100644 index cd14d2d3..00000000 --- a/src/ailego/math/euclidean_distance_matrix_int8.cc +++ /dev/null @@ -1,884 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "distance_matrix_accum_int8.i" -#include "euclidean_distance_matrix.h" - -namespace zvec { -namespace ailego { - -#define ACCUM_INT8_STEP_SSE SSD_INT8_SSE -#define ACCUM_INT8_STEP_AVX SSD_INT8_AVX - -#if defined(__SSE4_1__) -static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); -#endif // __SSE4_1__ - -#if defined(__AVX2__) -static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); -#endif // __AVX2__ - -//! Calculate sum of squared difference (GENERAL) -#define SSD_INT8_GENERAL(m, q, sum) \ - { \ - int32_t x = m - q; \ - sum += static_cast(x * x); \ - } - -//! Calculate sum of squared difference (SSE) -#define SSD_INT8_SSE(xmm_m, xmm_q, xmm_sum) \ - { \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_m), \ - _mm_sign_epi8(xmm_m, xmm_m)), \ - ONES_INT16_SSE), \ - xmm_sum); \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_q), \ - _mm_sign_epi8(xmm_q, xmm_q)), \ - ONES_INT16_SSE), \ - xmm_sum); \ - xmm_sum = _mm_sub_epi32( \ - xmm_sum, \ - _mm_slli_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_q), \ - _mm_sign_epi8(xmm_m, xmm_q)), \ - ONES_INT16_SSE), \ - 1)); \ - } - -//! Calculate sum of squared difference (AVX) -#define SSD_INT8_AVX(ymm_m, ymm_q, ymm_sum) \ - { \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16( \ - _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_m), \ - _mm256_sign_epi8(ymm_m, ymm_m)), \ - ONES_INT16_AVX), \ - ymm_sum); \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16( \ - _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q), \ - _mm256_sign_epi8(ymm_q, ymm_q)), \ - ONES_INT16_AVX), \ - ymm_sum); \ - ymm_sum = _mm256_sub_epi32( \ - ymm_sum, _mm256_slli_epi32( \ - _mm256_madd_epi16( \ - _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q), \ - _mm256_sign_epi8(ymm_m, ymm_q)), \ - ONES_INT16_AVX), \ - 1)); \ - } - -//! Compute the square root of value (SSE) -#define SQRT_FP32_SSE(v, ...) _mm_sqrt_ps(_mm_cvtepi32_ps(v)) - -//! Compute the square root of value (AVX) -#define SQRT_FP32_AVX(v, ...) _mm256_sqrt_ps(_mm256_cvtepi32_ps(v)) - -//! Compute the square root of value (AVX512) -#define SQRT_FP32_AVX512(v, ...) _mm512_sqrt_ps(_mm512_cvtepi32_ps(v)) - -#if defined(__SSE4_1__) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceSSE(const int8_t *lhs, - const int8_t *rhs, - size_t size) { - const int8_t *last = lhs + size; - const int8_t *last_aligned = lhs + ((size >> 5) << 5); - - __m128i xmm_sum_0 = _mm_setzero_si128(); - __m128i xmm_sum_1 = _mm_setzero_si128(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0)); - __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16)); - __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0)); - __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16)); - - __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), - _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)); - xmm_lhs_0 = _mm_cvtepu8_epi16(xmm_d); - xmm_rhs_0 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); - xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), - _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)); - xmm_lhs_1 = _mm_cvtepu8_epi16(xmm_d); - xmm_rhs_1 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); - - xmm_sum_0 = - _mm_add_epi32(_mm_madd_epi16(xmm_lhs_0, xmm_lhs_0), xmm_sum_0); - xmm_sum_1 = - _mm_add_epi32(_mm_madd_epi16(xmm_rhs_0, xmm_rhs_0), xmm_sum_1); - xmm_sum_0 = - _mm_add_epi32(_mm_madd_epi16(xmm_lhs_1, xmm_lhs_1), xmm_sum_0); - xmm_sum_1 = - _mm_add_epi32(_mm_madd_epi16(xmm_rhs_1, xmm_rhs_1), xmm_sum_1); - } - - if (last >= last_aligned + 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); - __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), - _mm_min_epi8(xmm_lhs, xmm_rhs)); - xmm_lhs = _mm_cvtepu8_epi16(xmm_d); - xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); - xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs), xmm_sum_0); - xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs, xmm_rhs), xmm_sum_1); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0)); - __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16)); - __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0)); - __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16)); - - __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), - _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)); - xmm_lhs_0 = _mm_cvtepu8_epi16(xmm_d); - xmm_rhs_0 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); - xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), - _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)); - xmm_lhs_1 = _mm_cvtepu8_epi16(xmm_d); - xmm_rhs_1 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); - - xmm_sum_0 = - _mm_add_epi32(_mm_madd_epi16(xmm_lhs_0, xmm_lhs_0), xmm_sum_0); - xmm_sum_1 = - _mm_add_epi32(_mm_madd_epi16(xmm_rhs_0, xmm_rhs_0), xmm_sum_1); - xmm_sum_0 = - _mm_add_epi32(_mm_madd_epi16(xmm_lhs_1, xmm_lhs_1), xmm_sum_0); - xmm_sum_1 = - _mm_add_epi32(_mm_madd_epi16(xmm_rhs_1, xmm_rhs_1), xmm_sum_1); - } - - if (last >= last_aligned + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), - _mm_min_epi8(xmm_lhs, xmm_rhs)); - xmm_lhs = _mm_cvtepu8_epi16(xmm_d); - xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); - xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs), xmm_sum_0); - xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs, xmm_rhs), xmm_sum_1); - lhs += 16; - rhs += 16; - } - } - float result = static_cast( - HorizontalAdd_INT32_V128(_mm_add_epi32(xmm_sum_0, xmm_sum_1))); - - switch (last - lhs) { - case 15: - SSD_INT8_GENERAL(lhs[14], rhs[14], result) - /* FALLTHRU */ - case 14: - SSD_INT8_GENERAL(lhs[13], rhs[13], result) - /* FALLTHRU */ - case 13: - SSD_INT8_GENERAL(lhs[12], rhs[12], result) - /* FALLTHRU */ - case 12: - SSD_INT8_GENERAL(lhs[11], rhs[11], result) - /* FALLTHRU */ - case 11: - SSD_INT8_GENERAL(lhs[10], rhs[10], result) - /* FALLTHRU */ - case 10: - SSD_INT8_GENERAL(lhs[9], rhs[9], result) - /* FALLTHRU */ - case 9: - SSD_INT8_GENERAL(lhs[8], rhs[8], result) - /* FALLTHRU */ - case 8: - SSD_INT8_GENERAL(lhs[7], rhs[7], result) - /* FALLTHRU */ - case 7: - SSD_INT8_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - SSD_INT8_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - SSD_INT8_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - SSD_INT8_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - SSD_INT8_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - SSD_INT8_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - SSD_INT8_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __SSE4_1__ - -#if defined(__AVX2__) -//! Squared Euclidean Distance -static inline float SquaredEuclideanDistanceAVX(const int8_t *lhs, - const int8_t *rhs, - size_t size) { - const int8_t *last = lhs + size; - const int8_t *last_aligned = lhs + ((size >> 6) << 6); - float result = 0.0; - - __m256i ymm_sum_0 = _mm256_setzero_si256(); - __m256i ymm_sum_1 = _mm256_setzero_si256(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0)); - __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32)); - __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0)); - __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32)); - - __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0), - _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0)); - ymm_lhs_0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); - ymm_rhs_0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); - ymm_sum_0 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_0, ymm_lhs_0), ymm_sum_0); - ymm_sum_1 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_0, ymm_rhs_0), ymm_sum_1); - - ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1), - _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1)); - ymm_lhs_1 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); - ymm_rhs_1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); - ymm_sum_0 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_1, ymm_lhs_1), ymm_sum_0); - ymm_sum_1 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_1, ymm_rhs_1), ymm_sum_1); - } - - if (last >= last_aligned + 32) { - __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); - __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); - __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), - _mm256_min_epi8(ymm_lhs, ymm_rhs)); - ymm_lhs = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); - ymm_rhs = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); - ymm_sum_0 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs, ymm_lhs), ymm_sum_0); - ymm_sum_1 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs, ymm_rhs), ymm_sum_1); - lhs += 32; - rhs += 32; - } - } else { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0)); - __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32)); - __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0)); - __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32)); - - __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0), - _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0)); - ymm_lhs_0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); - ymm_rhs_0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); - ymm_sum_0 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_0, ymm_lhs_0), ymm_sum_0); - ymm_sum_1 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_0, ymm_rhs_0), ymm_sum_1); - - ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1), - _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1)); - ymm_lhs_1 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); - ymm_rhs_1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); - ymm_sum_0 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_1, ymm_lhs_1), ymm_sum_0); - ymm_sum_1 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_1, ymm_rhs_1), ymm_sum_1); - } - - if (last >= last_aligned + 32) { - __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); - __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); - __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), - _mm256_min_epi8(ymm_lhs, ymm_rhs)); - ymm_lhs = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); - ymm_rhs = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); - ymm_sum_0 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs, ymm_lhs), ymm_sum_0); - ymm_sum_1 = - _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs, ymm_rhs), ymm_sum_1); - lhs += 32; - rhs += 32; - } - } - result = static_cast( - HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); - - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - __m128i xmm_sum = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), - _mm_min_epi8(xmm_lhs, xmm_rhs)); - xmm_lhs = _mm_cvtepu8_epi16(xmm_sum); - xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_sum, xmm_sum)); - xmm_sum = _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs), - _mm_madd_epi16(xmm_rhs, xmm_rhs)); - result += static_cast(HorizontalAdd_INT32_V128(xmm_sum)); - lhs += 16; - rhs += 16; - } - switch (last - lhs) { - case 15: - SSD_INT8_GENERAL(lhs[14], rhs[14], result) - /* FALLTHRU */ - case 14: - SSD_INT8_GENERAL(lhs[13], rhs[13], result) - /* FALLTHRU */ - case 13: - SSD_INT8_GENERAL(lhs[12], rhs[12], result) - /* FALLTHRU */ - case 12: - SSD_INT8_GENERAL(lhs[11], rhs[11], result) - /* FALLTHRU */ - case 11: - SSD_INT8_GENERAL(lhs[10], rhs[10], result) - /* FALLTHRU */ - case 10: - SSD_INT8_GENERAL(lhs[9], rhs[9], result) - /* FALLTHRU */ - case 9: - SSD_INT8_GENERAL(lhs[8], rhs[8], result) - /* FALLTHRU */ - case 8: - SSD_INT8_GENERAL(lhs[7], rhs[7], result) - /* FALLTHRU */ - case 7: - SSD_INT8_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - SSD_INT8_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - SSD_INT8_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - SSD_INT8_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - SSD_INT8_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - SSD_INT8_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - SSD_INT8_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __AVX2__ - -#if defined(__SSE4_1__) -//! Compute the distance between matrix and query (INT8, M=1, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - if (dim > 31) { - *out = SquaredEuclideanDistanceAVX(m, q, dim); - return; - } -#endif // __AVX2__ - *out = SquaredEuclideanDistanceSSE(m, q, dim); -} - -//! Compute the distance between matrix and query (INT8, M=2, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_2X1_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_2X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=2, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_2X2_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_2X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X1_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_4X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X2_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_4X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X4_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_4X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_8X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_8X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_8X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_8X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=16) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X16_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X16_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=1) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=2) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=4) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=8) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=16) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X16_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X16_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=32) -void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X32_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X32_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=1, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - if (dim > 31) { - *out = std::sqrt(SquaredEuclideanDistanceAVX(m, q, dim)); - return; - } -#endif // __AVX2__ - *out = std::sqrt(SquaredEuclideanDistanceSSE(m, q, dim)); -} - -//! Compute the distance between matrix and query (INT8, M=2, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_2X1_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT8_2X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=2, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_2X2_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT8_2X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X1_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT8_4X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X2_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT8_4X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X4_AVX(m, q, dim, out, SQRT_FP32_SSE) -#else - ACCUM_INT8_4X4_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X1_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_8X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X2_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_8X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X4_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_8X4_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X8_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_8X8_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X1_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_16X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X2_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_16X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X4_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_16X4_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X8_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_16X8_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=16) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X16_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_16X16_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=1) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X1_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_32X1_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=2) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X2_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_32X2_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=4) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X4_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_32X4_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=8) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X8_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_32X8_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=16) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X16_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_32X16_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=32) -void EuclideanDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X32_AVX(m, q, dim, out, SQRT_FP32_AVX) -#else - ACCUM_INT8_32X32_SSE(m, q, dim, out, SQRT_FP32_SSE) -#endif // __AVX2__ -} -#endif // __SSE4_1__ - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_int8_avx2.cc b/src/ailego/math/euclidean_distance_matrix_int8_avx2.cc new file mode 100644 index 00000000..014281cd --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_int8_avx2.cc @@ -0,0 +1,182 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int8.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +//! Squared Euclidean Distance +float SquaredEuclideanDistanceAVX2(const int8_t *lhs, const int8_t *rhs, + size_t size) { + const int8_t *last = lhs + size; + const int8_t *last_aligned = lhs + ((size >> 6) << 6); + float result = 0.0; + + __m256i ymm_sum_0 = _mm256_setzero_si256(); + __m256i ymm_sum_1 = _mm256_setzero_si256(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0)); + __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32)); + __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0)); + __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32)); + + __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0), + _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0)); + ymm_lhs_0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); + ymm_rhs_0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); + ymm_sum_0 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_0, ymm_lhs_0), ymm_sum_0); + ymm_sum_1 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_0, ymm_rhs_0), ymm_sum_1); + + ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1), + _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1)); + ymm_lhs_1 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); + ymm_rhs_1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); + ymm_sum_0 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_1, ymm_lhs_1), ymm_sum_0); + ymm_sum_1 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_1, ymm_rhs_1), ymm_sum_1); + } + + if (last >= last_aligned + 32) { + __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); + __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), + _mm256_min_epi8(ymm_lhs, ymm_rhs)); + ymm_lhs = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); + ymm_rhs = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); + ymm_sum_0 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs, ymm_lhs), ymm_sum_0); + ymm_sum_1 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs, ymm_rhs), ymm_sum_1); + lhs += 32; + rhs += 32; + } + } else { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0)); + __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32)); + __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0)); + __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32)); + + __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0), + _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0)); + ymm_lhs_0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); + ymm_rhs_0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); + ymm_sum_0 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_0, ymm_lhs_0), ymm_sum_0); + ymm_sum_1 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_0, ymm_rhs_0), ymm_sum_1); + + ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1), + _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1)); + ymm_lhs_1 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); + ymm_rhs_1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); + ymm_sum_0 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_1, ymm_lhs_1), ymm_sum_0); + ymm_sum_1 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_1, ymm_rhs_1), ymm_sum_1); + } + + if (last >= last_aligned + 32) { + __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); + __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), + _mm256_min_epi8(ymm_lhs, ymm_rhs)); + ymm_lhs = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d)); + ymm_rhs = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1)); + ymm_sum_0 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs, ymm_lhs), ymm_sum_0); + ymm_sum_1 = + _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs, ymm_rhs), ymm_sum_1); + lhs += 32; + rhs += 32; + } + } + result = static_cast( + HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + __m128i xmm_sum = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), + _mm_min_epi8(xmm_lhs, xmm_rhs)); + xmm_lhs = _mm_cvtepu8_epi16(xmm_sum); + xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_sum, xmm_sum)); + xmm_sum = _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs), + _mm_madd_epi16(xmm_rhs, xmm_rhs)); + result += static_cast(HorizontalAdd_INT32_V128(xmm_sum)); + lhs += 16; + rhs += 16; + } + switch (last - lhs) { + case 15: + SSD_INT8_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + SSD_INT8_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + SSD_INT8_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + SSD_INT8_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + SSD_INT8_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + SSD_INT8_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + SSD_INT8_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + SSD_INT8_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + SSD_INT8_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + SSD_INT8_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + SSD_INT8_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + SSD_INT8_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + SSD_INT8_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + SSD_INT8_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + SSD_INT8_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +#endif // __AVX2__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_int8_dispatch.cc b/src/ailego/math/euclidean_distance_matrix_int8_dispatch.cc new file mode 100644 index 00000000..54e9a75b --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_int8_dispatch.cc @@ -0,0 +1,59 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +float SquaredEuclideanDistanceAVX2(const int8_t *lhs, const int8_t *rhs, + size_t size); +float EuclideanDistanceAVX2(const int8_t *lhs, const int8_t *rhs, size_t size); +#endif + +#if defined(__SSE4_1__) +float SquaredEuclideanDistanceSSE(const int8_t *lhs, const int8_t *rhs, + size_t size); +float EuclideanDistanceSSE(const int8_t *lhs, const int8_t *rhs, size_t size); +#endif + + +#if defined(__SSE4_1__) +//! Compute the distance between matrix and query (INT8, M=1, N=1) +void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, + float *out) { +#if defined(__AVX2__) + if (dim > 31) { + *out = SquaredEuclideanDistanceAVX2(m, q, dim); + return; + } +#endif // __AVX2__ + *out = SquaredEuclideanDistanceSSE(m, q, dim); +} + +//! Compute the distance between matrix and query (INT8, M=1, N=1) +void EuclideanDistanceMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, float *out) { + SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out); + *out = std::sqrt(*out); +} +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/euclidean_distance_matrix_int8_sse.cc b/src/ailego/math/euclidean_distance_matrix_int8_sse.cc new file mode 100644 index 00000000..ca18ae98 --- /dev/null +++ b/src/ailego/math/euclidean_distance_matrix_int8_sse.cc @@ -0,0 +1,164 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int8.i" +#include "distance_matrix_euclidean_utility.i" +#include "euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE4_1__) +//! Squared Euclidean Distance +float SquaredEuclideanDistanceSSE(const int8_t *lhs, const int8_t *rhs, + size_t size) { + const int8_t *last = lhs + size; + const int8_t *last_aligned = lhs + ((size >> 5) << 5); + + __m128i xmm_sum_0 = _mm_setzero_si128(); + __m128i xmm_sum_1 = _mm_setzero_si128(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0)); + __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16)); + __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0)); + __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16)); + + __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), + _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)); + xmm_lhs_0 = _mm_cvtepu8_epi16(xmm_d); + xmm_rhs_0 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); + xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), + _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)); + xmm_lhs_1 = _mm_cvtepu8_epi16(xmm_d); + xmm_rhs_1 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); + + xmm_sum_0 = + _mm_add_epi32(_mm_madd_epi16(xmm_lhs_0, xmm_lhs_0), xmm_sum_0); + xmm_sum_1 = + _mm_add_epi32(_mm_madd_epi16(xmm_rhs_0, xmm_rhs_0), xmm_sum_1); + xmm_sum_0 = + _mm_add_epi32(_mm_madd_epi16(xmm_lhs_1, xmm_lhs_1), xmm_sum_0); + xmm_sum_1 = + _mm_add_epi32(_mm_madd_epi16(xmm_rhs_1, xmm_rhs_1), xmm_sum_1); + } + + if (last >= last_aligned + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), + _mm_min_epi8(xmm_lhs, xmm_rhs)); + xmm_lhs = _mm_cvtepu8_epi16(xmm_d); + xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); + xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs), xmm_sum_0); + xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs, xmm_rhs), xmm_sum_1); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0)); + __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16)); + __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0)); + __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16)); + + __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), + _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)); + xmm_lhs_0 = _mm_cvtepu8_epi16(xmm_d); + xmm_rhs_0 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); + xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), + _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)); + xmm_lhs_1 = _mm_cvtepu8_epi16(xmm_d); + xmm_rhs_1 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); + + xmm_sum_0 = + _mm_add_epi32(_mm_madd_epi16(xmm_lhs_0, xmm_lhs_0), xmm_sum_0); + xmm_sum_1 = + _mm_add_epi32(_mm_madd_epi16(xmm_rhs_0, xmm_rhs_0), xmm_sum_1); + xmm_sum_0 = + _mm_add_epi32(_mm_madd_epi16(xmm_lhs_1, xmm_lhs_1), xmm_sum_0); + xmm_sum_1 = + _mm_add_epi32(_mm_madd_epi16(xmm_rhs_1, xmm_rhs_1), xmm_sum_1); + } + + if (last >= last_aligned + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), + _mm_min_epi8(xmm_lhs, xmm_rhs)); + xmm_lhs = _mm_cvtepu8_epi16(xmm_d); + xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); + xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs), xmm_sum_0); + xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs, xmm_rhs), xmm_sum_1); + lhs += 16; + rhs += 16; + } + } + float result = static_cast( + HorizontalAdd_INT32_V128(_mm_add_epi32(xmm_sum_0, xmm_sum_1))); + + switch (last - lhs) { + case 15: + SSD_INT8_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + SSD_INT8_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + SSD_INT8_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + SSD_INT8_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + SSD_INT8_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + SSD_INT8_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + SSD_INT8_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + SSD_INT8_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + SSD_INT8_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + SSD_INT8_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + SSD_INT8_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + SSD_INT8_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + SSD_INT8_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + SSD_INT8_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + SSD_INT8_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/hamming_distance_matrix.cc b/src/ailego/math/hamming_distance_matrix.cc index 0e990d4c..5009a788 100644 --- a/src/ailego/math/hamming_distance_matrix.cc +++ b/src/ailego/math/hamming_distance_matrix.cc @@ -314,248 +314,6 @@ void HammingDistanceMatrix::Compute(const ValueType *m, *out = static_cast(HammingDistance(m, q, cnt)); } -#if defined(__SSSE3__) -//! Compute the distance between matrix and query (UINT32, M=2, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_2X1_AVX(m, q, cnt, out, _mm_cvtepi32_ps) -#else - POPCNT_UINT32_2X1_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=2, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_2X2_AVX(m, q, cnt, out, _mm_cvtepi32_ps) -#else - POPCNT_UINT32_2X2_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=4, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_4X1_AVX(m, q, cnt, out, _mm_cvtepi32_ps) -#else - POPCNT_UINT32_4X1_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=4, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_4X2_AVX(m, q, cnt, out, _mm_cvtepi32_ps) -#else - POPCNT_UINT32_4X2_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=4, N=4) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_4X4_AVX(m, q, cnt, out, _mm_cvtepi32_ps) -#else - POPCNT_UINT32_4X4_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=8, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_8X1_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_8X1_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=8, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_8X2_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_8X2_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=8, N=4) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_8X4_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_8X4_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=8, N=8) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_8X8_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_8X8_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X1_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_16X1_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X2_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_16X2_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=4) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X4_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_16X4_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=8) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X8_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_16X8_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=16) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X16_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_16X16_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X1_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_32X1_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X2_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_32X2_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=4) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X4_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_32X4_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=8) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X8_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_32X8_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=16) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X16_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_32X16_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=32) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X32_AVX(m, q, cnt, out, _mm256_cvtepi32_ps) -#else - POPCNT_UINT32_32X32_SSE(m, q, cnt, out, _mm_cvtepi32_ps) -#endif -} -#endif // __SSSE3__ - #if defined(AILEGO_M64) //! Compute the distance between matrix and query (UINT64, M=1, N=1) void HammingDistanceMatrix::Compute(const ValueType *m, @@ -571,167 +329,6 @@ void HammingDistanceMatrix::Compute(const ValueType *m, *out = static_cast(HammingDistance(m, q, cnt)); } -#if defined(__AVX2__) -//! Compute the distance between matrix and query (UINT64, M=2, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_2X1_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=2, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_2X2_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=4, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_4X1_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=4, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_4X2_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=4, N=4) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_4X4_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=8, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_8X1_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=8, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_8X2_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=8, N=4) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_8X4_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=8, N=8) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_8X8_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X1_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X2_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=4) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X4_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=8) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X8_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=16) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X16_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=1) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X1_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=2) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X2_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=4) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X4_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=8) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X8_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=16) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X16_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=32) -void HammingDistanceMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X32_AVX(m, q, cnt, out, CONVERT_UINT64_TO_FP32) -} -#endif // __AVX2__ #endif // AILEGO_M64 //! Compute the distance between matrix and query (UINT32, M=1, N=1) @@ -747,227 +344,6 @@ void HammingSquareRootDistanceMatrix::Compute( *out = std::sqrt(static_cast(HammingDistance(m, q, cnt))); } -#if defined(__SSSE3__) -//! Compute the distance between matrix and query (UINT32, M=2, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_2X1_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#else - POPCNT_UINT32_2X1_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=2, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_2X2_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#else - POPCNT_UINT32_2X2_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=4, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_4X1_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#else - POPCNT_UINT32_4X1_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=4, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_4X2_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#else - POPCNT_UINT32_4X2_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=4, N=4) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_4X4_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#else - POPCNT_UINT32_4X4_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=8, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_8X1_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_8X1_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=8, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_8X2_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_8X2_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=8, N=4) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_8X4_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_8X4_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=8, N=8) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_8X8_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_8X8_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X1_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_16X1_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X2_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_16X2_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=4) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X4_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_16X4_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=8) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X8_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_16X8_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=16, N=16) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_16X16_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_16X16_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X1_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_32X1_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X2_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_32X2_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=4) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X4_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_32X4_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=8) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X8_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_32X8_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=16) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X16_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_32X16_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (UINT32, M=32, N=32) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 5); -#if defined(__AVX2__) - POPCNT_UINT32_32X32_AVX(m, q, cnt, out, SQRT_UINT32_TO_FP32_AVX) -#else - POPCNT_UINT32_32X32_SSE(m, q, cnt, out, SQRT_UINT32_TO_FP32_SSE) -#endif -} -#endif // __SSSE3__ #if defined(AILEGO_M64) //! Compute the distance between matrix and query (UINT64, M=1, N=1) @@ -983,147 +359,6 @@ void HammingSquareRootDistanceMatrix::Compute( *out = std::sqrt(static_cast(HammingDistance(m, q, cnt))); } -#if defined(__AVX2__) -//! Compute the distance between matrix and query (UINT64, M=2, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_2X1_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=2, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_2X2_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=4, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_4X1_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=4, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_4X2_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=4, N=4) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_4X4_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=8, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_8X1_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=8, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_8X2_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=8, N=4) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_8X4_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=8, N=8) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_8X8_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X1_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X2_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=4) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X4_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=8) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X8_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=16, N=16) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_16X16_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=1) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X1_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=2) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X2_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=4) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X4_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=8) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X8_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=16) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X16_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} - -//! Compute the distance between matrix and query (UINT64, M=32, N=32) -void HammingSquareRootDistanceMatrix::Compute( - const ValueType *m, const ValueType *q, size_t dim, float *out) { - size_t cnt = (dim >> 6); - POPCNT_UINT64_32X32_AVX(m, q, cnt, out, SQRT_UINT64_TO_FP32) -} -#endif // __AVX2__ #endif // AILEGO_M64 } // namespace ailego diff --git a/src/ailego/math/hamming_distance_matrix.h b/src/ailego/math/hamming_distance_matrix.h index 5178e161..1ee67a50 100644 --- a/src/ailego/math/hamming_distance_matrix.h +++ b/src/ailego/math/hamming_distance_matrix.h @@ -83,248 +83,6 @@ struct HammingDistanceMatrix { float *out); }; -#if defined(__SSSE3__) -/*! Hamming Distance Matrix (UINT32, M=2, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=2, N=2) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=4, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=4, N=2) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=4, N=4) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=8, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=8, N=2) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=8, N=4) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=8, N=8) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=16, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=16, N=2) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=16, N=4) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=16, N=8) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=16, N=16) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=32, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=32, N=2) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=32, N=4) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=32, N=8) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=32, N=16) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT32, M=32, N=32) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -#endif // __SSSE3__ - #if defined(AILEGO_M64) /*! Hamming Distance Matrix (UINT64) */ @@ -381,71 +139,46 @@ struct HammingDistanceMatrix { float *out); }; -#if defined(__AVX2__) -/*! Hamming Distance Matrix (UINT64, M=2, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; +#endif // AILEGO_M64 -/*! Hamming Distance Matrix (UINT64, M=2, N=2) +/*! Hamming Square Root Distance Matrix */ -template <> -struct HammingDistanceMatrix { +template +struct HammingSquareRootDistanceMatrix { //! Type of value - using ValueType = uint64_t; + using ValueType = typename std::remove_cv::type; //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=4, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; + static inline void Compute(const ValueType *m, const ValueType *q, size_t dim, + float *out) { + ailego_assert(m && q && dim && out); - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); + HammingDistanceMatrix::Compute(m, q, dim, out); + for (size_t i = 0; i < N * M; ++i) { + float val = *out; + *out++ = std::sqrt(val); + } + } }; -/*! Hamming Distance Matrix (UINT64, M=4, N=2) +/*! Hamming Square Root Distance Matrix (UINT32, M=1, N=1) */ template <> -struct HammingDistanceMatrix { +struct HammingSquareRootDistanceMatrix { //! Type of value - using ValueType = uint64_t; + using ValueType = uint32_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Hamming Distance Matrix (UINT64, M=4, N=4) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -/*! Hamming Distance Matrix (UINT64, M=8, N=1) +#if defined(AILEGO_M64) +/*! Hamming Square Root Distance Matrix (UINT64, M=1, N=1) */ template <> -struct HammingDistanceMatrix { +struct HammingSquareRootDistanceMatrix { //! Type of value using ValueType = uint64_t; @@ -454,704 +187,6 @@ struct HammingDistanceMatrix { float *out); }; -/*! Hamming Distance Matrix (UINT64, M=8, N=2) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=8, N=4) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=8, N=8) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=16, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=16, N=2) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=16, N=4) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=16, N=8) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=16, N=16) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=32, N=1) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=32, N=2) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=32, N=4) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=32, N=8) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=32, N=16) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Distance Matrix (UINT64, M=32, N=32) - */ -template <> -struct HammingDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -#endif // __AVX2__ -#endif // AILEGO_M64 - -/*! Hamming Square Root Distance Matrix - */ -template -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = typename std::remove_cv::type; - - //! Compute the distance between matrix and query - static inline void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out) { - ailego_assert(m && q && dim && out); - - HammingDistanceMatrix::Compute(m, q, dim, out); - for (size_t i = 0; i < N * M; ++i) { - float val = *out; - *out++ = std::sqrt(val); - } - } -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=1, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -#if defined(__SSSE3__) -/*! Hamming Square Root Distance Matrix (UINT32, M=2, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=2, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=4, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=4, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=4, N=4) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=8, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=8, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=8, N=4) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=8, N=8) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=16, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=16, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=16, N=4) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=16, N=8) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=16, N=16) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=32, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=32, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=32, N=4) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=32, N=8) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=32, N=16) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT32, M=32, N=32) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint32_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -#endif // __SSSE3__ - -#if defined(AILEGO_M64) -/*! Hamming Square Root Distance Matrix (UINT64, M=1, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -#if defined(__AVX2__) -/*! Hamming Square Root Distance Matrix (UINT64, M=2, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=2, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=4, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=4, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=4, N=4) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=8, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=8, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=8, N=4) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=8, N=8) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=16, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=16, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=16, N=4) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=16, N=8) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=16, N=16) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=32, N=1) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=32, N=2) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=32, N=4) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=32, N=8) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=32, N=16) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Hamming Square Root Distance Matrix (UINT64, M=32, N=32) - */ -template <> -struct HammingSquareRootDistanceMatrix { - //! Type of value - using ValueType = uint64_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -#endif // __AVX2__ #endif // AILEGO_M64 } // namespace ailego diff --git a/src/ailego/math/inner_product_matrix.h b/src/ailego/math/inner_product_matrix.h index e1e87183..d141722b 100644 --- a/src/ailego/math/inner_product_matrix.h +++ b/src/ailego/math/inner_product_matrix.h @@ -735,1757 +735,77 @@ struct InnerProductMatrix { float *out); }; -/*! Inner Product Matrix (FP32, M=2, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=2, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=4, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=4, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=4, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=8, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=8, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=8, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=8, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=16, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=16, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=16, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=16, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=16, N=16) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=32, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=32, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=32, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=32, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=32, N=16) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP32, M=32, N=32) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=1, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=2, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=2, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=4, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=4, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=4, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=8, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=8, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=8, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=8, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=16, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=16, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=16, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=16, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=16, N=16) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=32, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=32, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=32, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=32, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=32, N=16) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP32, M=32, N=32) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = float; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -#endif // __SSE__ || __ARM_NEON - -#if (defined(__F16C__) && defined(__AVX__)) || \ - (defined(__ARM_NEON) && defined(__aarch64__)) -/*! Inner Product Matrix (FP16, M=1, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=1, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -#if !defined(__ARM_NEON) -/*! Inner Product Matrix (FP16, M=2, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=2, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=4, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=4, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=4, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=8, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=8, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=8, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=8, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=16, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=16, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=16, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=16, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=16, N=16) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=32, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=32, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=32, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=32, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=32, N=16) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (FP16, M=32, N=32) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=2, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=2, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=4, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=4, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=4, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=8, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=8, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=8, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=8, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=16, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=16, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=16, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=16, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=16, N=16) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=32, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=32, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=32, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=32, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=32, N=16) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (FP16, M=32, N=32) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = Float16; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -#endif // !__ARM_NEON -#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) - -#if defined(__SSE4_1__) -/*! Inner Product Matrix (INT8, M=1, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=2, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=2, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=4, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=4, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=4, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=8, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=8, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=8, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=8, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=16, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=16, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=16, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=16, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=16, N=16) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=32, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=32, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=32, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=32, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=32, N=16) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT8, M=32, N=32) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=1, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=2, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=2, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=4, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=4, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=4, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=8, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=8, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=8, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=8, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=16, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=16, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=16, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=16, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=16, N=16) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=32, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=32, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=32, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=32, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=32, N=16) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT8, M=32, N=32) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = int8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=1, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=2, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=2, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=4, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=4, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=4, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=8, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=8, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=8, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=8, N=8) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=16, N=1) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=16, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=16, N=4) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Inner Product Matrix (INT4, M=16, N=8) +/*! Minus Inner Product Matrix (FP32, M=1, N=1) */ template <> -struct InnerProductMatrix { +struct MinusInnerProductMatrix { //! Type of value - using ValueType = uint8_t; + using ValueType = float; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; +#endif // __SSE__ || __ARM_NEON -/*! Inner Product Matrix (INT4, M=16, N=16) +#if (defined(__F16C__) && defined(__AVX__)) || \ + (defined(__ARM_NEON) && defined(__aarch64__)) +/*! Inner Product Matrix (FP16, M=1, N=1) */ template <> -struct InnerProductMatrix { +struct InnerProductMatrix { //! Type of value - using ValueType = uint8_t; + using ValueType = Float16; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Inner Product Matrix (INT4, M=32, N=1) +/*! Minus Inner Product Matrix (FP16, M=1, N=1) */ template <> -struct InnerProductMatrix { +struct MinusInnerProductMatrix { //! Type of value - using ValueType = uint8_t; + using ValueType = Float16; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Inner Product Matrix (INT4, M=32, N=2) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; +#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) -/*! Inner Product Matrix (INT4, M=32, N=4) +#if defined(__SSE4_1__) +/*! Inner Product Matrix (INT8, M=1, N=1) */ template <> -struct InnerProductMatrix { +struct InnerProductMatrix { //! Type of value - using ValueType = uint8_t; + using ValueType = int8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Inner Product Matrix (INT4, M=32, N=8) +/*! Minus Inner Product Matrix (INT8, M=1, N=1) */ template <> -struct InnerProductMatrix { +struct MinusInnerProductMatrix { //! Type of value - using ValueType = uint8_t; + using ValueType = int8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; -/*! Inner Product Matrix (INT4, M=32, N=16) - */ -template <> -struct InnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; -/*! Inner Product Matrix (INT4, M=32, N=32) +/*! Inner Product Matrix (INT4, M=1, N=1) */ template <> -struct InnerProductMatrix { +struct InnerProductMatrix { //! Type of value using ValueType = uint8_t; @@ -2505,246 +825,6 @@ struct MinusInnerProductMatrix { static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; - -/*! Minus Inner Product Matrix (INT4, M=2, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=2, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=4, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=4, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=4, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=8, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=8, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=8, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=8, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=16, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=16, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=16, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=16, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=16, N=16) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=32, N=1) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=32, N=2) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=32, N=4) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=32, N=8) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=32, N=16) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; - -/*! Minus Inner Product Matrix (INT4, M=32, N=32) - */ -template <> -struct MinusInnerProductMatrix { - //! Type of value - using ValueType = uint8_t; - - //! Compute the distance between matrix and query - static void Compute(const ValueType *m, const ValueType *q, size_t dim, - float *out); -}; #endif // __SSE4_1__ template diff --git a/src/ailego/math/inner_product_matrix_fp16.cc b/src/ailego/math/inner_product_matrix_fp16.cc deleted file mode 100644 index 682cc918..00000000 --- a/src/ailego/math/inner_product_matrix_fp16.cc +++ /dev/null @@ -1,1948 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "distance_matrix_accum_fp16.i" -#include "inner_product_matrix.h" - -namespace zvec { -namespace ailego { - -#define ACCUM_FP32_STEP_SSE FMA_FP32_SSE -#define ACCUM_FP32_STEP_AVX FMA_FP32_AVX -#define ACCUM_FP32_STEP_AVX512 FMA_FP32_AVX512 -#define ACCUM_FP32_STEP_NEON FMA_FP32_NEON -#define ACCUM_FP16_STEP_GENERAL FMA_FP16_GENERAL -#define ACCUM_FP16_STEP_NEON FMA_FP16_NEON - -#if defined(__AVX512F__) && !defined(__AVX512DQ__) -#define _mm512_xor_ps(a, b) \ - _mm512_castsi512_ps( \ - _mm512_xor_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b))) -#endif // __AVX512DQ__ - -#if defined(__SSE__) -static const __m128 NEGZEROS_FP32_SSE = _mm_set1_ps(-0.0f); -#endif // __SSE__ - -#if defined(__AVX__) -static const __m256 NEGZEROS_FP32_AVX = _mm256_set1_ps(-0.0f); -#endif // __AVX__ - -#if defined(__AVX512F__) -static const __m512 NEGZEROS_FP32_AVX512 = _mm512_set1_ps(-0.0f); -#endif // __AVX512F__ - -//! Reverse sign of value (GENERAL) -#define NEGATE_FP32_GENERAL(v) -(v) - -#define NEGATE_FP32_SSE(v, ...) _mm_xor_ps(v, NEGZEROS_FP32_SSE) - -//! Reverse sign of value (AVX) -#define NEGATE_FP32_AVX(v, ...) _mm256_xor_ps(v, NEGZEROS_FP32_AVX) - -//! Reverse sign of value (AVX512) -#define NEGATE_FP32_AVX512(v, ...) _mm512_xor_ps(v, NEGZEROS_FP32_AVX512) - -//! Calculate Fused-Multiply-Add (SSE) -#define FMA_FP32_SSE(xmm_m, xmm_q, xmm_sum) \ - xmm_sum = _mm_fmadd_ps(xmm_m, xmm_q, xmm_sum); - -//! Calculate Fused-Multiply-Add (AVX) -#define FMA_FP32_AVX(ymm_m, ymm_q, ymm_sum) \ - ymm_sum = _mm256_fmadd_ps(ymm_m, ymm_q, ymm_sum); - -//! Calculate Fused-Multiply-Add (AVX512) -#define FMA_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ - zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_q, zmm_sum); - -//! Calculate Fused-Multiply-Add (AVX512FP16) -#define FMA_FP16_AVX512FP16(zmm_m, zmm_q, zmm_sum) \ - zmm_sum = _mm512_fmadd_ph(zmm_m, zmm_q, zmm_sum); - -//! Calculate Fused-Multiply-Add (GENERAL) -#define FMA_FP16_GENERAL(m, q, sum) sum += (m * q); - -//! Calculate Fused-Multiply-Add (NEON) -#define FMA_FP16_NEON(v_m, v_q, v_sum) v_sum = vfmaq_f16(v_sum, v_m, v_q); - -//! Calculate Fused-Multiply-Add (NEON) -#define FMA_FP32_NEON(v_m, v_q, v_sum) v_sum = vfmaq_f32(v_sum, v_m, v_q); - -#if (defined(__F16C__) && defined(__AVX__)) || \ - (defined(__ARM_NEON) && defined(__aarch64__)) - -#if defined(__AVX512FP16__) -//! Inner Product -static inline float InnerProductAVX512FP16(const Float16 *lhs, - const Float16 *rhs, size_t size) { - const Float16 *last = lhs + size; - const Float16 *last_aligned = lhs + ((size >> 6) << 6); - - __m512h zmm_sum_0 = _mm512_setzero_ph(); - __m512h zmm_sum_1 = _mm512_setzero_ph(); - - if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - FMA_FP16_AVX512FP16(_mm512_load_ph(lhs + 0), _mm512_load_ph(rhs + 0), - zmm_sum_0) - - FMA_FP16_AVX512FP16(_mm512_load_ph(lhs + 32), _mm512_load_ph(rhs + 32), - zmm_sum_1) - } - - if (last >= last_aligned + 32) { - FMA_FP16_AVX512FP16(_mm512_load_ph(lhs), _mm512_load_ph(rhs), zmm_sum_0) - lhs += 32; - rhs += 32; - } - } else { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs + 0), _mm512_loadu_ph(rhs + 0), - zmm_sum_0) - - FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs + 32), _mm512_loadu_ph(rhs + 32), - zmm_sum_1) - } - - if (last >= last_aligned + 32) { - FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs), _mm512_loadu_ph(rhs), zmm_sum_0) - lhs += 32; - rhs += 32; - } - } - - zmm_sum_0 = _mm512_add_ph(zmm_sum_0, zmm_sum_1); - - if (lhs != last) { - __mmask32 mask = (__mmask32)((1 << (last - lhs)) - 1); - __m512i zmm_undefined = _mm512_undefined_epi32(); - zmm_sum_0 = _mm512_mask3_fmadd_ph( - _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, lhs)), - _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, rhs)), - zmm_sum_0, mask); - } - - return HorizontalAdd_FP16_V512(zmm_sum_0); -} - -#endif - -//! Compute the distance between matrix and query (FP16, M=1, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP16_1X1_NEON(m, q, dim, out, 0ull, ) -#else -#if defined(__AVX512FP16__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { - *out = InnerProductAVX512FP16(m, q, dim); - return; - } -#endif //__AVX512FP16__ -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_1X1_AVX512(m, q, dim, out, 0ull, ) - return; - } -#endif //__AVX512F__ - ACCUM_FP16_1X1_AVX(m, q, dim, out, 0ull, ) -#endif //__ARM_NEON -} - -//! Compute the distance between matrix and query (FP16, M=1, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP16_1X1_NEON(m, q, dim, out, 0ull, NEGATE_FP32_GENERAL) -#else -#if defined(__AVX512FP16__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { - *out = -InnerProductAVX512FP16(m, q, dim); - return; - } -#endif //__AVX512FP16__ -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - ACCUM_FP16_1X1_AVX512(m, q, dim, out, 0ull, NEGATE_FP32_GENERAL) - return; - } -#endif //__AVX512F__ - ACCUM_FP16_1X1_AVX(m, q, dim, out, 0ull, NEGATE_FP32_GENERAL) -#endif //__ARM_NEON -} - -#if !defined(__ARM_NEON) -//! Compute the distance between matrix and query (FP16, M=2, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_2X1_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=2, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_2X2_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_4X1_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_4X2_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_4X4_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_8X1_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_8X2_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_8X4_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { - ACCUM_FP16_8X8_AVX(m, q, dim, out, ) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X1_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_16X1_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=16, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X2_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_16X2_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=16, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X4_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_16X4_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=16, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X8_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_16X8_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=16, N=16) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X16_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_16X16_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X1_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_32X1_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X2_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_32X2_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X4_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_32X4_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X8_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_32X8_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=16) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X16_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_32X16_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=32) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X32_AVX512(m, q, dim, out, ) -#else - ACCUM_FP16_32X32_AVX(m, q, dim, out, ) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=2, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_2X1_AVX(m, q, dim, out, NEGATE_FP32_SSE) -} - -//! Compute the distance between matrix and query (FP16, M=2, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_2X2_AVX(m, q, dim, out, NEGATE_FP32_SSE) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_4X1_AVX(m, q, dim, out, NEGATE_FP32_SSE) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_4X2_AVX(m, q, dim, out, NEGATE_FP32_SSE) -} - -//! Compute the distance between matrix and query (FP16, M=4, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_4X4_AVX(m, q, dim, out, NEGATE_FP32_SSE) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_8X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_8X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_8X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -} - -//! Compute the distance between matrix and query (FP16, M=8, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { - ACCUM_FP16_8X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -} - -//! Compute the distance between matrix and query (FP16, M=16, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X1_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_16X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=16, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X2_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_16X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=16, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X4_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_16X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=16, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X8_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_16X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=16, N=16) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_16X16_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_16X16_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X1_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_32X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X2_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_32X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X4_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_32X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X8_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_32X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=16) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X16_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_32X16_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} - -//! Compute the distance between matrix and query (FP16, M=32, N=32) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX512F__) - ACCUM_FP16_32X32_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#else - ACCUM_FP16_32X32_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#endif // __AVX512F__ -} -#endif // !__ARM_NEON -#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) - -// sparse -#if defined(__AVX__) -const static __m128i SHUFFLE_MASK256[256] = { - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, -127, -127), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, - 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, - 6, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, - 6, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, - 6, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, - 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, - 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 11, 10), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 13, 12), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 11, 10), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, - 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 15, 14), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 11, 10), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, - 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 13, 12), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, - 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 11, 10), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, - 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, - 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, - 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, - 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, - 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, - 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, - 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, - 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, - 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, - 4), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), - _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), -}; - -constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; - -float InnerProductSparseInSegmentAVX(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const Float16 *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const Float16 *q_sparse_value) { - float sum = 0.0f; - - // handle if the first dim is zero - bool m_zero = false; - Float16 m_zero_value{0.0f}; - if (m_sparse_count > 0 && m_sparse_index[0] == 0) { - m_sparse_count--; - m_sparse_index++; - m_zero_value = *m_sparse_value++; - m_zero = true; - } - - bool q_zero = false; - Float16 q_zero_value{0.0f}; - if (q_sparse_count > 0 && q_sparse_index[0] == 0) { - q_sparse_count--; - q_sparse_index++; - q_zero_value = *q_sparse_value++; - q_zero = true; - } - - if (m_zero && q_zero) { - sum = m_zero_value * q_zero_value; - } - - size_t i1 = 0, i2 = 0; - size_t end1 = m_sparse_count / 8 * 8; - size_t end2 = q_sparse_count / 8 * 8; - - uint16_t fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; - uint16_t fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; - - Float16 *val_start_1 = reinterpret_cast(fixed_buffer_1); - Float16 *val_start_2 = reinterpret_cast(fixed_buffer_2); - - Float16 *val_1 = val_start_1; - Float16 *val_2 = val_start_2; - - if (i1 < end1 && i2 < end2) { - while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { - i1 += 8; - if (i1 >= end1) goto do_scalar; - } - - while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { - i2 += 8; - if (i2 >= end2) goto do_scalar; - } - - __m128i mm_index_m = - _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); - __m128i mm_index_q = - _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); - - while (true) { -#ifdef DEBUG_PRINT - std::cout << "index 1: " << std::endl; - print_data16(&mm_index_m); - - std::cout << "index 2: " << std::endl; - print_data16(&mm_index_q); -#endif - - __m128i mm_cmp_res = - _mm_cmpistrm(mm_index_q, mm_index_m, - _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); - -#ifdef DEBUG_PRINT - std::cout << "cmp res: " << std::endl; - print_data16(&mm_cmp_res); -#endif - - int r = _mm_extract_epi32(mm_cmp_res, 0); - - if (r) { - int r1 = r; - - __m128i v = _mm_loadu_si128( - reinterpret_cast(&m_sparse_value[i1])); - __m128i vs = _mm_shuffle_epi8(v, SHUFFLE_MASK256[r1]); - - _mm_storeu_si128(reinterpret_cast<__m128i *>(val_1), vs); - val_1 += _mm_popcnt_u32(r1); - - mm_cmp_res = _mm_cmpistrm( - mm_index_m, mm_index_q, - _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); - r = _mm_extract_epi32(mm_cmp_res, 0); - - r1 = r; - - v = _mm_loadu_si128( - reinterpret_cast(&q_sparse_value[i2])); - vs = _mm_shuffle_epi8(v, SHUFFLE_MASK256[r1]); - - _mm_storeu_si128(reinterpret_cast<__m128i *>(val_2), vs); - val_2 += _mm_popcnt_u32(r1); - } - - const uint16_t id1_max = m_sparse_index[i1 + 7]; - - if (id1_max <= q_sparse_index[i2 + 7]) { - i1 += 8; - if (i1 >= end1) goto do_scalar; - mm_index_m = _mm_loadu_si128( - reinterpret_cast(&m_sparse_index[i1])); - } - - if (id1_max >= q_sparse_index[i2 + 7]) { - i2 += 8; - if (i2 >= end2) goto do_scalar; - mm_index_q = _mm_loadu_si128( - reinterpret_cast(&q_sparse_index[i2])); - } - } - } - -do_scalar: - while (i1 < m_sparse_count && i2 < q_sparse_count) { - if (m_sparse_index[i1] == q_sparse_index[i2]) { - *val_1++ = m_sparse_value[i1]; - *val_2++ = q_sparse_value[i2]; - - ++i1; - ++i2; - } else if (m_sparse_index[i1] < q_sparse_index[i2]) { - ++i1; - } else { - ++i2; - } - } - - size_t res_num = val_1 - val_start_1; - - size_t res_num8 = res_num / 8 * 8; - - if (res_num8) { - __m256 sum256 = _mm256_setzero_ps(); - - for (size_t k = 0; k < res_num8; k += 8) { - __m256 ymm_1 = - _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(val_start_1 + k))); - __m256 ymm_2 = - _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(val_start_2 + k))); - ACCUM_FP32_STEP_AVX(ymm_1, ymm_2, sum256); - } - - sum += HorizontalAdd_FP32_V256(sum256); - } - - for (size_t k = res_num8; k < res_num; ++k) - sum += val_start_1[k] * val_start_2[k]; - - return sum; -} - -#elif defined(__AVX512FP16__) -const static __m128i SHUFFLE_MASK256[256] = { - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, -127, -127), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, - 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, - 6, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, - 6, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, - 6, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, - 8, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, - 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, - 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 11, 10), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, - 10, 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 13, 12), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 13, 12, 11, 10), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, - 12, 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, - 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, - 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, 15, 14), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 11, 10), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, - 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, - 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 13, 12), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 5, 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 5, 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 7, 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 7, 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 9, 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 9, 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 3, 2, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 9, 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 9, 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, - 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, - 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, - 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, - 14, 13, 12, 11, 10), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, - 4, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, - 4, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, - 6, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, - 6, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, - 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, - 8, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, - 8, 3, 2), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, - 8, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, - 8, 7, 6), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, - 2), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, - 4), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), - _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), - _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), -}; - -constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; - -float InnerProductSparseInSegmentAVX512FP16(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const Float16 *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const Float16 *q_sparse_value) { - float sum = 0.0f; - - // handle if the first dim is zero - bool m_zero = false; - Float16 m_zero_value{0.0f}; - if (m_sparse_count > 0 && m_sparse_index[0] == 0) { - m_sparse_count--; - m_sparse_index++; - m_zero_value = *m_sparse_value++; - m_zero = true; - } - - bool q_zero = false; - Float16 q_zero_value{0.0f}; - if (q_sparse_count > 0 && q_sparse_index[0] == 0) { - q_sparse_count--; - q_sparse_index++; - q_zero_value = *q_sparse_value++; - q_zero = true; - } - - if (m_zero && q_zero) { - sum = m_zero_value * q_zero_value; - } - - size_t i1 = 0, i2 = 0; - size_t end1 = m_sparse_count / 8 * 8; - size_t end2 = q_sparse_count / 8 * 8; - - uint16_t fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; - uint16_t fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; - - Float16 *val_start_1 = reinterpret_cast(fixed_buffer_1); - Float16 *val_start_2 = reinterpret_cast(fixed_buffer_2); - - Float16 *val_1 = val_start_1; - Float16 *val_2 = val_start_2; - - if (i1 < end1 && i2 < end2) { - while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { - i1 += 8; - if (i1 >= end1) goto do_scalar; - } - - while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { - i2 += 8; - if (i2 >= end2) goto do_scalar; - } - - __m128i mm_index_m = - _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); - __m128i mm_index_q = - _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); - - while (true) { -#ifdef DEBUG_PRINT - std::cout << "index 1: " << std::endl; - print_data16(&mm_index_m); - - std::cout << "index 2: " << std::endl; - print_data16(&mm_index_q); -#endif - - __m128i mm_cmp_res = - _mm_cmpistrm(mm_index_q, mm_index_m, - _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); - -#ifdef DEBUG_PRINT - std::cout << "cmp res: " << std::endl; - print_data16(&mm_cmp_res); -#endif - - int r = _mm_extract_epi32(mm_cmp_res, 0); - - if (r) { - int r1 = r; - - __m128i v = _mm_loadu_si128( - reinterpret_cast(&m_sparse_value[i1])); - __m128h vs = _mm_castsi128_ph(_mm_shuffle_epi8(v, SHUFFLE_MASK256[r1])); - - _mm_storeu_ph(val_1, vs); - val_1 += _mm_popcnt_u32(r1); - - mm_cmp_res = _mm_cmpistrm( - mm_index_m, mm_index_q, - _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); - r = _mm_extract_epi32(mm_cmp_res, 0); - - r1 = r; - - v = _mm_loadu_si128( - reinterpret_cast(&q_sparse_value[i2])); - vs = _mm_castsi128_ph(_mm_shuffle_epi8(v, SHUFFLE_MASK256[r1])); - - _mm_storeu_ph(val_2, vs); - val_2 += _mm_popcnt_u32(r1); - } - - const uint16_t id1_max = m_sparse_index[i1 + 7]; - - if (id1_max <= q_sparse_index[i2 + 7]) { - i1 += 8; - if (i1 >= end1) goto do_scalar; - mm_index_m = _mm_loadu_si128( - reinterpret_cast(&m_sparse_index[i1])); - } - - if (id1_max >= q_sparse_index[i2 + 7]) { - i2 += 8; - if (i2 >= end2) goto do_scalar; - mm_index_q = _mm_loadu_si128( - reinterpret_cast(&q_sparse_index[i2])); - } - } - } - -do_scalar: - while (i1 < m_sparse_count && i2 < q_sparse_count) { - if (m_sparse_index[i1] == q_sparse_index[i2]) { - *val_1++ = m_sparse_value[i1]; - *val_2++ = q_sparse_value[i2]; - - ++i1; - ++i2; - } else if (m_sparse_index[i1] < q_sparse_index[i2]) { - ++i1; - } else { - ++i2; - } - } - - size_t res_num = val_1 - val_start_1; - - size_t res_num8 = res_num / 8 * 8; - - if (res_num8) { - __m128h sum128 = _mm_set1_ph(0); - - for (size_t k = 0; k < res_num8; k += 8) { - sum128 = _mm_add_ph(sum128, _mm_mul_ph(_mm_loadu_ph(val_start_1 + k), - _mm_loadu_ph(val_start_2 + k))); - } - - Float16 __attribute__((aligned(16))) tmp_res[8]; - _mm_store_ph(tmp_res, sum128); - sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3] + tmp_res[4] + - tmp_res[5] + tmp_res[6] + tmp_res[7]); - } - - for (size_t k = res_num8; k < res_num; ++k) - sum += val_start_1[k] * val_start_2[k]; - - return sum; -} - -#else -float InnerProductSparseInSegment(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const Float16 *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const Float16 *q_sparse_value) { - float sum = 0.0f; - - size_t m_i = 0; - size_t q_i = 0; - while (m_i < m_sparse_count && q_i < q_sparse_count) { - if (m_sparse_index[m_i] == q_sparse_index[q_i]) { - sum += m_sparse_value[m_i] * q_sparse_value[q_i]; - - ++m_i; - ++q_i; - } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { - ++m_i; - } else { - ++q_i; - } - } - - return sum; -} -#endif // __AVX512FP16__ - -template <> -float MinusInnerProductSparseMatrix:: - ComputeInnerProductSparseInSegment(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const ValueType *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const ValueType *q_sparse_value) { -#if defined(__AVX__) - return InnerProductSparseInSegmentAVX(m_sparse_count, m_sparse_index, - m_sparse_value, q_sparse_count, - q_sparse_index, q_sparse_value); -#elif defined(__AVX512FP16__) - return InnerProductSparseInSegmentAVX512FP16(m_sparse_count, m_sparse_index, - m_sparse_value, q_sparse_count, - q_sparse_index, q_sparse_value); -#else - return InnerProductSparseInSegment(m_sparse_count, m_sparse_index, - m_sparse_value, q_sparse_count, - q_sparse_index, q_sparse_value); -#endif -} - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_fp16_avx.cc b/src/ailego/math/inner_product_matrix_fp16_avx.cc new file mode 100644 index 00000000..a68b1fb0 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp16_avx.cc @@ -0,0 +1,706 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +// sparse +#if defined(__AVX__) +const static __m128i SHUFFLE_MASK256[256] = { + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, -127, -127), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, + 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, + 6, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, + 6, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, + 6, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, + 8, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, + 8, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, + 8, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, + 8, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, + 8, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, + 8, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, + 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, + 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 11, 10), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, + 10, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 13, 12), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 11, 10), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 11, 10, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 11, 10, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 11, 10, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 11, 10, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, + 12, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, + 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, + 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, + 2), + _mm_set_epi8(-127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 15, 14), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 11, 10), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 11, 10, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 11, 10, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 11, 10, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 11, 10, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, + 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, + 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, + 2), + _mm_set_epi8(-127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 13, 12), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 13, 12, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 13, 12, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 13, 12, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 13, 12, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 13, 12, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, + 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, + 2), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, + 14, 13, 12, 11, 10), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 11, 10, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 11, 10, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 11, 10, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, + 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, + 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 11, 10, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, + 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, + 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, + 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, + 2), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, + 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, + 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, + 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, + 2), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, + 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, + 2), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, + 4), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), + _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), +}; + +constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; + +float InnerProductSparseInSegmentAVX(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const Float16 *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const Float16 *q_sparse_value) { + float sum = 0.0f; + + // handle if the first dim is zero + bool m_zero = false; + Float16 m_zero_value{0.0f}; + if (m_sparse_count > 0 && m_sparse_index[0] == 0) { + m_sparse_count--; + m_sparse_index++; + m_zero_value = *m_sparse_value++; + m_zero = true; + } + + bool q_zero = false; + Float16 q_zero_value{0.0f}; + if (q_sparse_count > 0 && q_sparse_index[0] == 0) { + q_sparse_count--; + q_sparse_index++; + q_zero_value = *q_sparse_value++; + q_zero = true; + } + + if (m_zero && q_zero) { + sum = m_zero_value * q_zero_value; + } + + size_t i1 = 0, i2 = 0; + size_t end1 = m_sparse_count / 8 * 8; + size_t end2 = q_sparse_count / 8 * 8; + + uint16_t fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; + uint16_t fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; + + Float16 *val_start_1 = reinterpret_cast(fixed_buffer_1); + Float16 *val_start_2 = reinterpret_cast(fixed_buffer_2); + + Float16 *val_1 = val_start_1; + Float16 *val_2 = val_start_2; + + if (i1 < end1 && i2 < end2) { + while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { + i1 += 8; + if (i1 >= end1) goto do_scalar; + } + + while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { + i2 += 8; + if (i2 >= end2) goto do_scalar; + } + + __m128i mm_index_m = + _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); + __m128i mm_index_q = + _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); + + while (true) { +#ifdef DEBUG_PRINT + std::cout << "index 1: " << std::endl; + print_data16(&mm_index_m); + + std::cout << "index 2: " << std::endl; + print_data16(&mm_index_q); +#endif + + __m128i mm_cmp_res = + _mm_cmpistrm(mm_index_q, mm_index_m, + _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); + +#ifdef DEBUG_PRINT + std::cout << "cmp res: " << std::endl; + print_data16(&mm_cmp_res); +#endif + + int r = _mm_extract_epi32(mm_cmp_res, 0); + + if (r) { + int r1 = r; + + __m128i v = _mm_loadu_si128( + reinterpret_cast(&m_sparse_value[i1])); + __m128i vs = _mm_shuffle_epi8(v, SHUFFLE_MASK256[r1]); + + _mm_storeu_si128(reinterpret_cast<__m128i *>(val_1), vs); + val_1 += _mm_popcnt_u32(r1); + + mm_cmp_res = _mm_cmpistrm( + mm_index_m, mm_index_q, + _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); + r = _mm_extract_epi32(mm_cmp_res, 0); + + r1 = r; + + v = _mm_loadu_si128( + reinterpret_cast(&q_sparse_value[i2])); + vs = _mm_shuffle_epi8(v, SHUFFLE_MASK256[r1]); + + _mm_storeu_si128(reinterpret_cast<__m128i *>(val_2), vs); + val_2 += _mm_popcnt_u32(r1); + } + + const uint16_t id1_max = m_sparse_index[i1 + 7]; + + if (id1_max <= q_sparse_index[i2 + 7]) { + i1 += 8; + if (i1 >= end1) goto do_scalar; + mm_index_m = _mm_loadu_si128( + reinterpret_cast(&m_sparse_index[i1])); + } + + if (id1_max >= q_sparse_index[i2 + 7]) { + i2 += 8; + if (i2 >= end2) goto do_scalar; + mm_index_q = _mm_loadu_si128( + reinterpret_cast(&q_sparse_index[i2])); + } + } + } + +do_scalar: + while (i1 < m_sparse_count && i2 < q_sparse_count) { + if (m_sparse_index[i1] == q_sparse_index[i2]) { + *val_1++ = m_sparse_value[i1]; + *val_2++ = q_sparse_value[i2]; + + ++i1; + ++i2; + } else if (m_sparse_index[i1] < q_sparse_index[i2]) { + ++i1; + } else { + ++i2; + } + } + + size_t res_num = val_1 - val_start_1; + + size_t res_num8 = res_num / 8 * 8; + + if (res_num8) { + __m256 sum256 = _mm256_setzero_ps(); + + for (size_t k = 0; k < res_num8; k += 8) { + __m256 ymm_1 = + _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(val_start_1 + k))); + __m256 ymm_2 = + _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(val_start_2 + k))); + ACCUM_FP32_STEP_AVX(ymm_1, ymm_2, sum256); + } + + sum += HorizontalAdd_FP32_V256(sum256); + } + + for (size_t k = res_num8; k < res_num; ++k) + sum += val_start_1[k] * val_start_2[k]; + + return sum; +} + +#endif // __AVX__ + + +#if defined(__AVX__) +void InnerProductAVX(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out) { + ACCUM_FP16_1X1_AVX(lhs, rhs, size, out, 0ull, ) +} + +void MinusInnerProductAVX(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out) { + ACCUM_FP16_1X1_AVX(lhs, rhs, size, out, 0ull, NEGATE_FP32_GENERAL) +} +#endif +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_fp16_avx512.cc b/src/ailego/math/inner_product_matrix_fp16_avx512.cc new file mode 100644 index 00000000..7e07952e --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp16_avx512.cc @@ -0,0 +1,766 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX512FP16__) +//! Inner Product +float InnerProductAVX512FP16(const Float16 *lhs, const Float16 *rhs, + size_t size) { + const Float16 *last = lhs + size; + const Float16 *last_aligned = lhs + ((size >> 6) << 6); + + __m512h zmm_sum_0 = _mm512_setzero_ph(); + __m512h zmm_sum_1 = _mm512_setzero_ph(); + + if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + FMA_FP16_AVX512FP16(_mm512_load_ph(lhs + 0), _mm512_load_ph(rhs + 0), + zmm_sum_0) + + FMA_FP16_AVX512FP16(_mm512_load_ph(lhs + 32), _mm512_load_ph(rhs + 32), + zmm_sum_1) + } + + if (last >= last_aligned + 32) { + FMA_FP16_AVX512FP16(_mm512_load_ph(lhs), _mm512_load_ph(rhs), zmm_sum_0) + lhs += 32; + rhs += 32; + } + } else { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs + 0), _mm512_loadu_ph(rhs + 0), + zmm_sum_0) + + FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs + 32), _mm512_loadu_ph(rhs + 32), + zmm_sum_1) + } + + if (last >= last_aligned + 32) { + FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs), _mm512_loadu_ph(rhs), zmm_sum_0) + lhs += 32; + rhs += 32; + } + } + + zmm_sum_0 = _mm512_add_ph(zmm_sum_0, zmm_sum_1); + + if (lhs != last) { + __mmask32 mask = (__mmask32)((1 << (last - lhs)) - 1); + __m512i zmm_undefined = _mm512_undefined_epi32(); + zmm_sum_0 = _mm512_mask3_fmadd_ph( + _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, lhs)), + _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, rhs)), + zmm_sum_0, mask); + } + + return HorizontalAdd_FP16_V512(zmm_sum_0); +} + +#endif + +// sparse +#if defined(__AVX512FP16__) +constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; + +float InnerProductSparseInSegmentAVX512FP16(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const Float16 *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const Float16 *q_sparse_value) { + const static __m128i SHUFFLE_MASK256[256] = { + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, -127, -127), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 7, 6, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 7, 6, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 7, 6, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 9, 8, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 9, 8, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 9, 8, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 9, 8, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 9, 8, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 9, 8, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 11, 10), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 11, 10, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 13, 12), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 13, 12, 11, 10), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 11, 10, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 11, 10, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 11, 10, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 11, 10, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, + 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, + 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 13, 12, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, + 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, + 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, + 10, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, + 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, + 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, + 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, + 2), + _mm_set_epi8(-127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, 15, 14), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 11, 10), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 11, 10, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 11, 10, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 11, 10, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 11, 10, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, + 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, + 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, + 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, + 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, + 10, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, + 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, + 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, + 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, + 2), + _mm_set_epi8(-127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 13, 12), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 13, 12, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 13, 12, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 13, 12, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 13, 12, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, + 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, + 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 13, 12, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, + 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, + 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, + 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, + 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, + 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, + 2), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + 15, 14, 13, 12, 11, 10), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 11, 10, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 11, 10, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 11, 10, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 5, 4, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 5, 4, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 11, 10, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 7, 6, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 7, 6, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, + 12, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 9, 8, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 9, 8, 3, 2), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 9, 8, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, + 3, 2), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, + 9, 8, 7, 6), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, + 1, 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, + 3, 2), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, + 5, 4), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), + _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), + _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + }; + + float sum = 0.0f; + + // handle if the first dim is zero + bool m_zero = false; + Float16 m_zero_value{0.0f}; + if (m_sparse_count > 0 && m_sparse_index[0] == 0) { + m_sparse_count--; + m_sparse_index++; + m_zero_value = *m_sparse_value++; + m_zero = true; + } + + bool q_zero = false; + Float16 q_zero_value{0.0f}; + if (q_sparse_count > 0 && q_sparse_index[0] == 0) { + q_sparse_count--; + q_sparse_index++; + q_zero_value = *q_sparse_value++; + q_zero = true; + } + + if (m_zero && q_zero) { + sum = m_zero_value * q_zero_value; + } + + size_t i1 = 0, i2 = 0; + size_t end1 = m_sparse_count / 8 * 8; + size_t end2 = q_sparse_count / 8 * 8; + + uint16_t fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; + uint16_t fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; + + Float16 *val_start_1 = reinterpret_cast(fixed_buffer_1); + Float16 *val_start_2 = reinterpret_cast(fixed_buffer_2); + + Float16 *val_1 = val_start_1; + Float16 *val_2 = val_start_2; + + if (i1 < end1 && i2 < end2) { + while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { + i1 += 8; + if (i1 >= end1) goto do_scalar; + } + + while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { + i2 += 8; + if (i2 >= end2) goto do_scalar; + } + + __m128i mm_index_m = + _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); + __m128i mm_index_q = + _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); + + while (true) { +#ifdef DEBUG_PRINT + std::cout << "index 1: " << std::endl; + print_data16(&mm_index_m); + + std::cout << "index 2: " << std::endl; + print_data16(&mm_index_q); +#endif + + __m128i mm_cmp_res = + _mm_cmpistrm(mm_index_q, mm_index_m, + _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); + +#ifdef DEBUG_PRINT + std::cout << "cmp res: " << std::endl; + print_data16(&mm_cmp_res); +#endif + + int r = _mm_extract_epi32(mm_cmp_res, 0); + + if (r) { + int r1 = r; + + __m128i v = _mm_loadu_si128( + reinterpret_cast(&m_sparse_value[i1])); + __m128h vs = _mm_castsi128_ph(_mm_shuffle_epi8(v, SHUFFLE_MASK256[r1])); + + _mm_storeu_ph(val_1, vs); + val_1 += _mm_popcnt_u32(r1); + + mm_cmp_res = _mm_cmpistrm( + mm_index_m, mm_index_q, + _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); + r = _mm_extract_epi32(mm_cmp_res, 0); + + r1 = r; + + v = _mm_loadu_si128( + reinterpret_cast(&q_sparse_value[i2])); + vs = _mm_castsi128_ph(_mm_shuffle_epi8(v, SHUFFLE_MASK256[r1])); + + _mm_storeu_ph(val_2, vs); + val_2 += _mm_popcnt_u32(r1); + } + + const uint16_t id1_max = m_sparse_index[i1 + 7]; + + if (id1_max <= q_sparse_index[i2 + 7]) { + i1 += 8; + if (i1 >= end1) goto do_scalar; + mm_index_m = _mm_loadu_si128( + reinterpret_cast(&m_sparse_index[i1])); + } + + if (id1_max >= q_sparse_index[i2 + 7]) { + i2 += 8; + if (i2 >= end2) goto do_scalar; + mm_index_q = _mm_loadu_si128( + reinterpret_cast(&q_sparse_index[i2])); + } + } + } + +do_scalar: + while (i1 < m_sparse_count && i2 < q_sparse_count) { + if (m_sparse_index[i1] == q_sparse_index[i2]) { + *val_1++ = m_sparse_value[i1]; + *val_2++ = q_sparse_value[i2]; + + ++i1; + ++i2; + } else if (m_sparse_index[i1] < q_sparse_index[i2]) { + ++i1; + } else { + ++i2; + } + } + + size_t res_num = val_1 - val_start_1; + + size_t res_num8 = res_num / 8 * 8; + + if (res_num8) { + __m128h sum128 = _mm_set1_ph(0); + + for (size_t k = 0; k < res_num8; k += 8) { + sum128 = _mm_add_ph(sum128, _mm_mul_ph(_mm_loadu_ph(val_start_1 + k), + _mm_loadu_ph(val_start_2 + k))); + } + + Float16 __attribute__((aligned(16))) tmp_res[8]; + _mm_store_ph(tmp_res, sum128); + sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3] + tmp_res[4] + + tmp_res[5] + tmp_res[6] + tmp_res[7]); + } + + for (size_t k = res_num8; k < res_num; ++k) + sum += val_start_1[k] * val_start_2[k]; + + return sum; +} + +#endif // __AVX512FP16__ + +#if defined(__AVX512F__) +void InnerProductAVX512(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out) { + ACCUM_FP16_1X1_AVX512(lhs, rhs, size, out, 0ull, ) +} + +void MinusInnerProductAVX512(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out) { + ACCUM_FP16_1X1_AVX512(lhs, rhs, size, out, 0ull, NEGATE_FP32_GENERAL) +} +#endif //__AVX512F__ + + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_fp16_dispatch.cc b/src/ailego/math/inner_product_matrix_fp16_dispatch.cc new file mode 100644 index 00000000..86760130 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp16_dispatch.cc @@ -0,0 +1,162 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +float InnerProductNEON(const Float16 *lhs, const Float16 *rhs, size_t size); +float MinusInnerProductNEON(const Float16 *lhs, const Float16 *rhs, + size_t size); +#endif + +#if defined(__AVX__) +void InnerProductAVX(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out); +void MinusInnerProductAVX(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out); +float InnerProductSparseInSegmentAVX(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const Float16 *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const Float16 *q_sparse_value); +#endif + +#if defined(__AVX512F__) +void InnerProductAVX512(const Float16 *lhs, const Float16 *rhs, size_t size, + float *out); +void MinusInnerProductAVX512(const Float16 *lhs, const Float16 *rhs, + size_t size, float *out); +#endif + +#if defined(__AVX512FP16__) +float InnerProductAVX512FP16(const Float16 *lhs, const Float16 *rhs, + size_t size); +float InnerProductSparseInSegmentAVX512FP16(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const Float16 *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const Float16 *q_sparse_value); +#endif + +#if (defined(__F16C__) && defined(__AVX__)) || \ + (defined(__ARM_NEON) && defined(__aarch64__)) +//! Compute the distance between matrix and query (FP16, M=1, N=1) +void InnerProductMatrix::Compute(const ValueType *m, + const ValueType *q, size_t dim, + float *out) { +#if defined(__ARM_NEON) + *out = InnerProductNEON(m, q, dim); +#else +#if defined(__AVX512FP16__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { + *out = InnerProductAVX512FP16(m, q, dim); + return; + } +#endif //__AVX512FP16__ +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + InnerProductAVX512(m, q, dim, out); + return; + } +#endif //__AVX512F__ + InnerProductAVX(m, q, dim, out); +#endif //__ARM_NEON +} + +//! Compute the distance between matrix and query (FP16, M=1, N=1) +void MinusInnerProductMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, float *out) { +#if defined(__ARM_NEON) + *out = MinusInnerProductNEON(m, q, dim); +#else +#if defined(__AVX512FP16__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { + *out = -InnerProductAVX512FP16(m, q, dim); + return; + } +#endif //__AVX512FP16__ +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + MinusInnerProductAVX512(m, q, dim, out); + return; + } +#endif //__AVX512F__ + + MinusInnerProductAVX(m, q, dim, out); + +#endif //__ARM_NEON +} + +#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) + +// sparse +float InnerProductSparseInSegment(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const Float16 *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const Float16 *q_sparse_value) { + float sum = 0.0f; + + size_t m_i = 0; + size_t q_i = 0; + while (m_i < m_sparse_count && q_i < q_sparse_count) { + if (m_sparse_index[m_i] == q_sparse_index[q_i]) { + sum += m_sparse_value[m_i] * q_sparse_value[q_i]; + + ++m_i; + ++q_i; + } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { + ++m_i; + } else { + ++q_i; + } + } + + return sum; +} + +template <> +float MinusInnerProductSparseMatrix:: + ComputeInnerProductSparseInSegment(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const ValueType *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const ValueType *q_sparse_value) { +#if defined(__AVX512FP16__) + return InnerProductSparseInSegmentAVX512FP16(m_sparse_count, m_sparse_index, + m_sparse_value, q_sparse_count, + q_sparse_index, q_sparse_value); +#elif defined(__AVX__) + return InnerProductSparseInSegmentAVX(m_sparse_count, m_sparse_index, + m_sparse_value, q_sparse_count, + q_sparse_index, q_sparse_value); + +#else + return InnerProductSparseInSegment(m_sparse_count, m_sparse_index, + m_sparse_value, q_sparse_count, + q_sparse_index, q_sparse_value); +#endif +} + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_fp16_neon.cc b/src/ailego/math/inner_product_matrix_fp16_neon.cc new file mode 100644 index 00000000..a7c3090d --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp16_neon.cc @@ -0,0 +1,42 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +float InnerProductNEON(const Float16 *lhs, const Float16 *rhs, size_t size) { + float score; + + ACCUM_FP16_1X1_NEON(lhs, rhs, size, &score, 0ull, ) + + return score; +} + +float MinusInnerProductNEON(const Float16 *lhs, const Float16 *rhs, + size_t size) { + float score; + + ACCUM_FP16_1X1_NEON(lhs, rhs, size, &score, 0ull, NEGATE_FP32_GENERAL) + + return score; +} +#endif + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_fp32.cc b/src/ailego/math/inner_product_matrix_fp32.cc deleted file mode 100644 index 78e260d0..00000000 --- a/src/ailego/math/inner_product_matrix_fp32.cc +++ /dev/null @@ -1,1180 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "distance_matrix_accum_fp32.i" -#include "inner_product_matrix.h" - -namespace zvec { -namespace ailego { - -#define ACCUM_FP32_STEP_SSE FMA_FP32_SSE -#define ACCUM_FP32_STEP_AVX FMA_FP32_AVX -#define ACCUM_FP32_STEP_AVX512 FMA_FP32_AVX512 -#define ACCUM_FP32_STEP_NEON FMA_FP32_NEON - -#if defined(__AVX512F__) && !defined(__AVX512DQ__) -#define _mm512_xor_ps(a, b) \ - _mm512_castsi512_ps( \ - _mm512_xor_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b))) -#endif // __AVX512DQ__ - -#if defined(__SSE__) -static const __m128 NEGZEROS_FP32_SSE = _mm_set1_ps(-0.0f); -#endif // __SSE__ - -#if defined(__AVX__) -static const __m256 NEGZEROS_FP32_AVX = _mm256_set1_ps(-0.0f); -#endif // __AVX__ - -#if defined(__AVX512F__) -static const __m512 NEGZEROS_FP32_AVX512 = _mm512_set1_ps(-0.0f); -#endif // __AVX512F__ - -//! Reverse sign of value (SSE) -#define NEGATE_FP32_SSE(v, ...) _mm_xor_ps(v, NEGZEROS_FP32_SSE) - -//! Reverse sign of value (AVX) -#define NEGATE_FP32_AVX(v, ...) _mm256_xor_ps(v, NEGZEROS_FP32_AVX) - -//! Reverse sign of value (AVX512) -#define NEGATE_FP32_AVX512(v, ...) _mm512_xor_ps(v, NEGZEROS_FP32_AVX512) - -//! Calculate Fused-Multiply-Add (GENERAL) -#define FMA_FP32_GENERAL(m, q, sum) sum += (m * q); - -//! Calculate Fused-Multiply-Add (SSE) -#define FMA_FP32_SSE(xmm_m, xmm_q, xmm_sum) \ - xmm_sum = _mm_fmadd_ps(xmm_m, xmm_q, xmm_sum); - -//! Calculate Fused-Multiply-Add (AVX) -#define FMA_FP32_AVX(ymm_m, ymm_q, ymm_sum) \ - ymm_sum = _mm256_fmadd_ps(ymm_m, ymm_q, ymm_sum); - -//! Calculate Fused-Multiply-Add (AVX512) -#define FMA_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ - zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_q, zmm_sum); - -//! Calculate Fused-Multiply-Add (NEON) -#define FMA_FP32_NEON(v_m, v_q, v_sum) v_sum = vfmaq_f32(v_sum, v_m, v_q); - -#if defined(__ARM_NEON) -//! Inner Product -static inline float InnerProductNEON(const float *lhs, const float *rhs, - size_t size) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 3) << 3); - - float32x4_t v_sum_0 = vdupq_n_f32(0); - float32x4_t v_sum_1 = vdupq_n_f32(0); - - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - v_sum_0 = vfmaq_f32(v_sum_0, vld1q_f32(lhs + 0), vld1q_f32(rhs + 0)); - v_sum_1 = vfmaq_f32(v_sum_1, vld1q_f32(lhs + 4), vld1q_f32(rhs + 4)); - } - if (last >= last_aligned + 4) { - v_sum_0 = vfmaq_f32(v_sum_0, vld1q_f32(lhs), vld1q_f32(rhs)); - lhs += 4; - rhs += 4; - } - - float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); - switch (last - lhs) { - case 3: - FMA_FP32_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - FMA_FP32_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - FMA_FP32_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __ARM_NEON - -#if defined(__SSE__) -//! Inner Product -static inline float InnerProductSSE(const float *lhs, const float *rhs, - size_t size) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 3) << 3); - - __m128 xmm_sum_0 = _mm_setzero_ps(); - __m128 xmm_sum_1 = _mm_setzero_ps(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - __m128 xmm_lhs_0 = _mm_load_ps(lhs + 0); - __m128 xmm_lhs_1 = _mm_load_ps(lhs + 4); - __m128 xmm_rhs_0 = _mm_load_ps(rhs + 0); - __m128 xmm_rhs_1 = _mm_load_ps(rhs + 4); - xmm_sum_0 = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); - xmm_sum_1 = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum_1); - } - - if (last >= last_aligned + 4) { - xmm_sum_0 = _mm_fmadd_ps(_mm_load_ps(lhs), _mm_load_ps(rhs), xmm_sum_0); - lhs += 4; - rhs += 4; - } - } else { - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - __m128 xmm_lhs_0 = _mm_loadu_ps(lhs + 0); - __m128 xmm_lhs_1 = _mm_loadu_ps(lhs + 4); - __m128 xmm_rhs_0 = _mm_loadu_ps(rhs + 0); - __m128 xmm_rhs_1 = _mm_loadu_ps(rhs + 4); - xmm_sum_0 = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); - xmm_sum_1 = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum_1); - } - - if (last >= last_aligned + 4) { - xmm_sum_0 = _mm_fmadd_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs), xmm_sum_0); - lhs += 4; - rhs += 4; - } - } - float result = HorizontalAdd_FP32_V128(_mm_add_ps(xmm_sum_0, xmm_sum_1)); - - switch (last - lhs) { - case 3: - FMA_FP32_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - FMA_FP32_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - FMA_FP32_GENERAL(lhs[0], rhs[0], result) - } - return result; -} - -#endif // __SSE__ - -// #if 1 -#if defined(__SSE4_1__) -const static __m128i SHUFFLE_MASK16[16] = { - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, -127, -127), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 13, 12), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, - 4), - _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), -}; - -constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; - -float InnerProductSparseInSegmentSSE(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const float *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const float *q_sparse_value) { - float sum = 0.0f; - - // handle if the first dim is zero - bool m_zero = false; - float m_zero_value = 0.0f; - if (m_sparse_count > 0 && m_sparse_index[0] == 0) { - m_sparse_count--; - m_sparse_index++; - m_zero_value = *m_sparse_value++; - m_zero = true; - } - - bool q_zero = false; - float q_zero_value = 0.0f; - if (q_sparse_count > 0 && q_sparse_index[0] == 0) { - q_sparse_count--; - q_sparse_index++; - q_zero_value = *q_sparse_value++; - q_zero = true; - } - - if (m_zero && q_zero) { - sum = m_zero_value * q_zero_value; - } - - size_t i1 = 0, i2 = 0; - size_t end1 = m_sparse_count / 8 * 8; - size_t end2 = q_sparse_count / 8 * 8; - - // std::vector mem1; - // std::vector mem2; - - float fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; - float fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; - - float *val_start_1 = fixed_buffer_1; - float *val_start_2 = fixed_buffer_2; - - // uint32_t max_count = std::max(m_sparse_count, q_sparse_count); - - // if (MAX_SPARSE_BUFFER_LENGTH < max_count) { - // mem1.reserve(max_count); - // mem2.reserve(max_count); - - // val_start_1 = mem1.data(); - // val_start_2 = mem2.data(); - // } - - float *val_1 = val_start_1; - float *val_2 = val_start_2; - - if (i1 < end1 && i2 < end2) { - while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { - i1 += 8; - if (i1 >= end1) goto do_scalar; - } - - while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { - i2 += 8; - if (i2 >= end2) goto do_scalar; - } - - __m128i mm_index_m = - _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); - __m128i mm_index_q = - _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); - - while (true) { -#ifdef DEBUG_PRINT - std::cout << "index 1: " << std::endl; - print_data16(&mm_index_m); - - std::cout << "index 2: " << std::endl; - print_data16(&mm_index_q); -#endif - - __m128i mm_cmp_res = - _mm_cmpistrm(mm_index_q, mm_index_m, - _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); - -#ifdef DEBUG_PRINT - std::cout << "cmp res: " << std::endl; - print_data16(&mm_cmp_res); -#endif - - int r = _mm_extract_epi32(mm_cmp_res, 0); - - if (r) { - int r1 = r & 15; - - __m128i v = _mm_loadu_si128( - reinterpret_cast(&m_sparse_value[i1])); - __m128 vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); - - _mm_storeu_ps(val_1, vs); - val_1 += _mm_popcnt_u32(r1); - - int r2 = (r >> 4) & 15; - v = _mm_loadu_si128( - reinterpret_cast(&m_sparse_value[i1 + 4])); - vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); - _mm_storeu_ps(val_1, vs); - val_1 += _mm_popcnt_u32(r2); - - mm_cmp_res = _mm_cmpistrm( - mm_index_m, mm_index_q, - _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); - r = _mm_extract_epi32(mm_cmp_res, 0); - - r1 = r & 15; - - v = _mm_loadu_si128( - reinterpret_cast(&q_sparse_value[i2])); - vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); - _mm_storeu_ps(val_2, vs); - val_2 += _mm_popcnt_u32(r1); - - r2 = (r >> 4) & 15; - v = _mm_loadu_si128( - reinterpret_cast(&q_sparse_value[i2 + 4])); - vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); - _mm_storeu_ps(val_2, vs); - val_2 += _mm_popcnt_u32(r2); - } - - const uint16_t id1_max = m_sparse_index[i1 + 7]; - - if (id1_max <= q_sparse_index[i2 + 7]) { - i1 += 8; - if (i1 >= end1) goto do_scalar; - mm_index_m = _mm_loadu_si128( - reinterpret_cast(&m_sparse_index[i1])); - } - - if (id1_max >= q_sparse_index[i2 + 7]) { - i2 += 8; - if (i2 >= end2) goto do_scalar; - mm_index_q = _mm_loadu_si128( - reinterpret_cast(&q_sparse_index[i2])); - } - } - } - -do_scalar: - while (i1 < m_sparse_count && i2 < q_sparse_count) { - if (m_sparse_index[i1] == q_sparse_index[i2]) { - *val_1++ = m_sparse_value[i1]; - *val_2++ = q_sparse_value[i2]; - - ++i1; - ++i2; - } else if (m_sparse_index[i1] < q_sparse_index[i2]) { - ++i1; - } else { - ++i2; - } - } - - size_t res_num = val_1 - val_start_1; - - // if (res_num != val_2 - val_start_2) { - // std::cerr << "size mismatch!" << std::endl; - // } - - size_t res_num4 = res_num / 4 * 4; - - if (res_num4) { - __m128 sum128 = _mm_set1_ps(0); - - for (size_t k = 0; k < res_num4; k += 4) { - sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(val_start_1 + k), - _mm_loadu_ps(val_start_2 + k))); - } - - float __attribute__((aligned(16))) tmp_res[4]; - _mm_store_ps(tmp_res, sum128); - sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3]); - } - - for (size_t k = res_num4; k < res_num; ++k) - sum += val_start_1[k] * val_start_2[k]; - - return sum; -} -#else -float InnerProductSparseInSegment(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const float *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const float *q_sparse_value) { - float sum = 0.0f; - - size_t m_i = 0; - size_t q_i = 0; - while (m_i < m_sparse_count && q_i < q_sparse_count) { - if (m_sparse_index[m_i] == q_sparse_index[q_i]) { - sum += m_sparse_value[m_i] * q_sparse_value[q_i]; - - ++m_i; - ++q_i; - } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { - ++m_i; - } else { - ++q_i; - } - } - - return sum; -} -#endif // __SSE4_1__ - -template <> -float MinusInnerProductSparseMatrix::ComputeInnerProductSparseInSegment( - uint32_t m_sparse_count, const uint16_t *m_sparse_index, - const ValueType *m_sparse_value, uint32_t q_sparse_count, - const uint16_t *q_sparse_index, const ValueType *q_sparse_value) { -#if defined(__SSE4_1__) - return InnerProductSparseInSegmentSSE(m_sparse_count, m_sparse_index, - m_sparse_value, q_sparse_count, - q_sparse_index, q_sparse_value); -#else - return InnerProductSparseInSegment(m_sparse_count, m_sparse_index, - m_sparse_value, q_sparse_count, - q_sparse_index, q_sparse_value); -#endif -} - -#if defined(__AVX__) -//! Inner Product -static inline float InnerProductAVX(const float *lhs, const float *rhs, - size_t size) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 4) << 4); - - __m256 ymm_sum_0 = _mm256_setzero_ps(); - __m256 ymm_sum_1 = _mm256_setzero_ps(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m256 ymm_lhs_0 = _mm256_load_ps(lhs + 0); - __m256 ymm_lhs_1 = _mm256_load_ps(lhs + 8); - __m256 ymm_rhs_0 = _mm256_load_ps(rhs + 0); - __m256 ymm_rhs_1 = _mm256_load_ps(rhs + 8); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); - } - - if (last >= last_aligned + 8) { - ymm_sum_0 = - _mm256_fmadd_ps(_mm256_load_ps(lhs), _mm256_load_ps(rhs), ymm_sum_0); - lhs += 8; - rhs += 8; - } - } else { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs + 0); - __m256 ymm_lhs_1 = _mm256_loadu_ps(lhs + 8); - __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs + 0); - __m256 ymm_rhs_1 = _mm256_loadu_ps(rhs + 8); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); - } - - if (last >= last_aligned + 8) { - ymm_sum_0 = _mm256_fmadd_ps(_mm256_loadu_ps(lhs), _mm256_loadu_ps(rhs), - ymm_sum_0); - lhs += 8; - rhs += 8; - } - } - float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); - - switch (last - lhs) { - case 7: - FMA_FP32_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - FMA_FP32_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - FMA_FP32_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - FMA_FP32_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - FMA_FP32_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - FMA_FP32_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - FMA_FP32_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __AVX__ - -#if defined(__AVX512F__) -//! Inner Product -static inline float InnerProductAVX512(const float *lhs, const float *rhs, - size_t size) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 5) << 5); - - __m512 zmm_sum_0 = _mm512_setzero_ps(); - __m512 zmm_sum_1 = _mm512_setzero_ps(); - - if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - FMA_FP32_AVX512(_mm512_load_ps(lhs + 0), _mm512_load_ps(rhs + 0), - zmm_sum_0) - - FMA_FP32_AVX512(_mm512_load_ps(lhs + 16), _mm512_load_ps(rhs + 16), - zmm_sum_1) - } - - if (last >= last_aligned + 16) { - FMA_FP32_AVX512(_mm512_load_ps(lhs), _mm512_load_ps(rhs), zmm_sum_0) - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - FMA_FP32_AVX512(_mm512_loadu_ps(lhs + 0), _mm512_loadu_ps(rhs + 0), - zmm_sum_0) - - FMA_FP32_AVX512(_mm512_loadu_ps(lhs + 16), _mm512_loadu_ps(rhs + 16), - zmm_sum_1) - } - - if (last >= last_aligned + 16) { - FMA_FP32_AVX512(_mm512_loadu_ps(lhs), _mm512_loadu_ps(rhs), zmm_sum_0) - lhs += 16; - rhs += 16; - } - } - - zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1); - if (lhs != last) { - __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1); - __m512 zmm_undefined = _mm512_undefined_ps(); - zmm_sum_0 = _mm512_mask3_fmadd_ps( - _mm512_mask_loadu_ps(zmm_undefined, mask, lhs), - _mm512_mask_loadu_ps(zmm_undefined, mask, rhs), zmm_sum_0, mask); - } - return HorizontalAdd_FP32_V512(zmm_sum_0); -} -#endif - -#if defined(__SSE__) || defined(__ARM_NEON) -//! Compute the distance between matrix and query (FP32, M=1, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - *out = InnerProductNEON(m, q, dim); -#else -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - if (dim > 15) { - *out = InnerProductAVX512(m, q, dim); - return; - } - } -#endif // __AVX512F__ -#if defined(__AVX__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { - if (dim > 7) { - *out = InnerProductAVX(m, q, dim); - return; - } - } -#endif // __AVX__ - *out = InnerProductSSE(m, q, dim); -#endif // __ARM_NEON -} - -//! Compute the distance between matrix and query (FP32, M=2, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_2X1_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_2X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_2X1_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=2, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_2X2_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_2X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_2X2_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X1_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_4X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_4X1_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X2_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_4X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_4X2_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X4_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_4X4_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_4X4_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X1_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_8X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_8X1_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X2_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_8X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_8X2_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X4_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_8X4_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_8X4_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X8_NEON(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_8X8_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_8X8_SSE(m, q, dim, out, ) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=16, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X1_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X1_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X1_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X2_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X2_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X2_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X4_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X4_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X4_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X4_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X8_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X8_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X8_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X8_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=16) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X16_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_16X16_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_16X16_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_16X16_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X1_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X1_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X1_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X1_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X2_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X2_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X2_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X2_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X4_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X4_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X4_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X4_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X8_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X8_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X8_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X8_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=16) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X16_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X16_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X16_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X16_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=32) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X32_NEON(m, q, dim, out, ) -#elif defined(__AVX512F__) - ACCUM_FP32_32X32_AVX512(m, q, dim, out, ) -#elif defined(__AVX__) - ACCUM_FP32_32X32_AVX(m, q, dim, out, ) -#else - ACCUM_FP32_32X32_SSE(m, q, dim, out, ) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=1, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - *out = -InnerProductNEON(m, q, dim); -#else -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - if (dim > 15) { - *out = -InnerProductAVX512(m, q, dim); - return; - } - } -#endif // __AVX512F__ -#if defined(__AVX__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { - if (dim > 7) { - *out = -InnerProductAVX(m, q, dim); - return; - } - } -#endif // __AVX__ - *out = -InnerProductSSE(m, q, dim); -#endif // __ARM_NEON -} - -//! Compute the distance between matrix and query (FP32, M=2, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_2X1_NEON(m, q, dim, out, vneg_f32) -#elif defined(__AVX__) - ACCUM_FP32_2X1_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_FP32_2X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=2, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_2X2_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX__) - ACCUM_FP32_2X2_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_FP32_2X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X1_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX__) - ACCUM_FP32_4X1_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_FP32_4X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X2_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX__) - ACCUM_FP32_4X2_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_FP32_4X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=4, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_4X4_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX__) - ACCUM_FP32_4X4_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_FP32_4X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X1_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX__) - ACCUM_FP32_8X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_8X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X2_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX__) - ACCUM_FP32_8X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_8X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X4_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX__) - ACCUM_FP32_8X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_8X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=8, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_8X8_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX__) - ACCUM_FP32_8X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_8X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX__ -} - -//! Compute the distance between matrix and query (FP32, M=16, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X1_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X1_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_16X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_16X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X2_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X2_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_16X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_16X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X4_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X4_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_16X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_16X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X8_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X8_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_16X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_16X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=16, N=16) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_16X16_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_16X16_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_16X16_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_16X16_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X1_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X1_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_32X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_32X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X2_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X2_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_32X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_32X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X4_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X4_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_32X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_32X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X8_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X8_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_32X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_32X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=16) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X16_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X16_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_32X16_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_32X16_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} - -//! Compute the distance between matrix and query (FP32, M=32, N=32) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__ARM_NEON) - ACCUM_FP32_32X32_NEON(m, q, dim, out, vnegq_f32) -#elif defined(__AVX512F__) - ACCUM_FP32_32X32_AVX512(m, q, dim, out, NEGATE_FP32_AVX512) -#elif defined(__AVX__) - ACCUM_FP32_32X32_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_FP32_32X32_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif -} -#endif // __SSE__ || __ARM_NEON - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_fp32_avx.cc b/src/ailego/math/inner_product_matrix_fp32_avx.cc new file mode 100644 index 00000000..128adfdf --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp32_avx.cc @@ -0,0 +1,94 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX__) +//! Inner Product +float InnerProductAVX(const float *lhs, const float *rhs, size_t size) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 4) << 4); + + __m256 ymm_sum_0 = _mm256_setzero_ps(); + __m256 ymm_sum_1 = _mm256_setzero_ps(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m256 ymm_lhs_0 = _mm256_load_ps(lhs + 0); + __m256 ymm_lhs_1 = _mm256_load_ps(lhs + 8); + __m256 ymm_rhs_0 = _mm256_load_ps(rhs + 0); + __m256 ymm_rhs_1 = _mm256_load_ps(rhs + 8); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); + } + + if (last >= last_aligned + 8) { + ymm_sum_0 = + _mm256_fmadd_ps(_mm256_load_ps(lhs), _mm256_load_ps(rhs), ymm_sum_0); + lhs += 8; + rhs += 8; + } + } else { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs + 0); + __m256 ymm_lhs_1 = _mm256_loadu_ps(lhs + 8); + __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs + 0); + __m256 ymm_rhs_1 = _mm256_loadu_ps(rhs + 8); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); + } + + if (last >= last_aligned + 8) { + ymm_sum_0 = _mm256_fmadd_ps(_mm256_loadu_ps(lhs), _mm256_loadu_ps(rhs), + ymm_sum_0); + lhs += 8; + rhs += 8; + } + } + float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); + + switch (last - lhs) { + case 7: + FMA_FP32_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + FMA_FP32_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + FMA_FP32_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + FMA_FP32_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + FMA_FP32_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + FMA_FP32_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + FMA_FP32_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +#endif // __AVX__ + +} // namespace ailego +} // namespace zvec diff --git a/src/ailego/math/inner_product_matrix_fp32_avx512.cc b/src/ailego/math/inner_product_matrix_fp32_avx512.cc new file mode 100644 index 00000000..af3bf74c --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp32_avx512.cc @@ -0,0 +1,75 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX512F__) +//! Inner Product +float InnerProductAVX512(const float *lhs, const float *rhs, size_t size) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 5) << 5); + + __m512 zmm_sum_0 = _mm512_setzero_ps(); + __m512 zmm_sum_1 = _mm512_setzero_ps(); + + if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + FMA_FP32_AVX512(_mm512_load_ps(lhs + 0), _mm512_load_ps(rhs + 0), + zmm_sum_0) + + FMA_FP32_AVX512(_mm512_load_ps(lhs + 16), _mm512_load_ps(rhs + 16), + zmm_sum_1) + } + + if (last >= last_aligned + 16) { + FMA_FP32_AVX512(_mm512_load_ps(lhs), _mm512_load_ps(rhs), zmm_sum_0) + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + FMA_FP32_AVX512(_mm512_loadu_ps(lhs + 0), _mm512_loadu_ps(rhs + 0), + zmm_sum_0) + + FMA_FP32_AVX512(_mm512_loadu_ps(lhs + 16), _mm512_loadu_ps(rhs + 16), + zmm_sum_1) + } + + if (last >= last_aligned + 16) { + FMA_FP32_AVX512(_mm512_loadu_ps(lhs), _mm512_loadu_ps(rhs), zmm_sum_0) + lhs += 16; + rhs += 16; + } + } + + zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1); + if (lhs != last) { + __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1); + __m512 zmm_undefined = _mm512_undefined_ps(); + zmm_sum_0 = _mm512_mask3_fmadd_ps( + _mm512_mask_loadu_ps(zmm_undefined, mask, lhs), + _mm512_mask_loadu_ps(zmm_undefined, mask, rhs), zmm_sum_0, mask); + } + return HorizontalAdd_FP32_V512(zmm_sum_0); +} + +#endif + +} // namespace ailego +} // namespace zvec diff --git a/src/ailego/math/inner_product_matrix_fp32_dispatch.cc b/src/ailego/math/inner_product_matrix_fp32_dispatch.cc new file mode 100644 index 00000000..57acef21 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp32_dispatch.cc @@ -0,0 +1,97 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +float InnerProductNEON(const float *lhs, const float *rhs, size_t size); +float MinusInnerProductNEON(const float *lhs, const float *rhs, size_t size); +#endif + +#if defined(__AVX512F__) +float InnerProductAVX512(const float *lhs, const float *rhs, size_t size); +#endif + +#if defined(__AVX__) +float InnerProductAVX(const float *lhs, const float *rhs, size_t size); +float MinusInnerProductAVX(const float *lhs, const float *rhs, size_t size); +#endif + +#if defined(__SSE__) +float InnerProductSSE(const float *lhs, const float *rhs, size_t size); +float MinusInnerProductSSE(const float *lhs, const float *rhs, size_t size); +#endif + +#if defined(__SSE__) || defined(__ARM_NEON) +//! Compute the distance between matrix and query (FP32, M=1, N=1) +void InnerProductMatrix::Compute(const ValueType *m, + const ValueType *q, size_t dim, + float *out) { +#if defined(__ARM_NEON) + *out = InnerProductNEON(m, q, dim); +#else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + if (dim > 15) { + *out = InnerProductAVX512(m, q, dim); + return; + } + } +#endif // __AVX512F__ +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + if (dim > 7) { + *out = InnerProductAVX(m, q, dim); + return; + } + } +#endif // __AVX__ + *out = InnerProductSSE(m, q, dim); +#endif // __ARM_NEON +} + +//! Compute the distance between matrix and query (FP32, M=1, N=1) +void MinusInnerProductMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, float *out) { +#if defined(__ARM_NEON) + *out = -InnerProductNEON(m, q, dim); +#else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + if (dim > 15) { + *out = -InnerProductAVX512(m, q, dim); + return; + } + } +#endif // __AVX512F__ +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + if (dim > 7) { + *out = -InnerProductAVX(m, q, dim); + return; + } + } +#endif // __AVX__ + *out = -InnerProductSSE(m, q, dim); +#endif // __ARM_NEON +} + +#endif +} // namespace ailego +} // namespace zvec diff --git a/src/ailego/math/inner_product_matrix_fp32_neon.cc b/src/ailego/math/inner_product_matrix_fp32_neon.cc new file mode 100644 index 00000000..e8626a3b --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp32_neon.cc @@ -0,0 +1,57 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +//! Inner Product +float InnerProductNEON(const float *lhs, const float *rhs, size_t size) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 3) << 3); + + float32x4_t v_sum_0 = vdupq_n_f32(0); + float32x4_t v_sum_1 = vdupq_n_f32(0); + + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + v_sum_0 = vfmaq_f32(v_sum_0, vld1q_f32(lhs + 0), vld1q_f32(rhs + 0)); + v_sum_1 = vfmaq_f32(v_sum_1, vld1q_f32(lhs + 4), vld1q_f32(rhs + 4)); + } + if (last >= last_aligned + 4) { + v_sum_0 = vfmaq_f32(v_sum_0, vld1q_f32(lhs), vld1q_f32(rhs)); + lhs += 4; + rhs += 4; + } + + float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); + switch (last - lhs) { + case 3: + FMA_FP32_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + FMA_FP32_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + FMA_FP32_GENERAL(lhs[0], rhs[0], result) + } + return result; +} +#endif // __ARM_NEON + +} // namespace ailego +} // namespace zvec diff --git a/src/ailego/math/inner_product_matrix_fp32_sse.cc b/src/ailego/math/inner_product_matrix_fp32_sse.cc new file mode 100644 index 00000000..8a302bf9 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_fp32_sse.cc @@ -0,0 +1,351 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE__) +//! Inner Product +float InnerProductSSE(const float *lhs, const float *rhs, size_t size) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 3) << 3); + + __m128 xmm_sum_0 = _mm_setzero_ps(); + __m128 xmm_sum_1 = _mm_setzero_ps(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + __m128 xmm_lhs_0 = _mm_load_ps(lhs + 0); + __m128 xmm_lhs_1 = _mm_load_ps(lhs + 4); + __m128 xmm_rhs_0 = _mm_load_ps(rhs + 0); + __m128 xmm_rhs_1 = _mm_load_ps(rhs + 4); + xmm_sum_0 = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); + xmm_sum_1 = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum_1); + } + + if (last >= last_aligned + 4) { + xmm_sum_0 = _mm_fmadd_ps(_mm_load_ps(lhs), _mm_load_ps(rhs), xmm_sum_0); + lhs += 4; + rhs += 4; + } + } else { + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + __m128 xmm_lhs_0 = _mm_loadu_ps(lhs + 0); + __m128 xmm_lhs_1 = _mm_loadu_ps(lhs + 4); + __m128 xmm_rhs_0 = _mm_loadu_ps(rhs + 0); + __m128 xmm_rhs_1 = _mm_loadu_ps(rhs + 4); + xmm_sum_0 = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); + xmm_sum_1 = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum_1); + } + + if (last >= last_aligned + 4) { + xmm_sum_0 = _mm_fmadd_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs), xmm_sum_0); + lhs += 4; + rhs += 4; + } + } + float result = HorizontalAdd_FP32_V128(_mm_add_ps(xmm_sum_0, xmm_sum_1)); + + switch (last - lhs) { + case 3: + FMA_FP32_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + FMA_FP32_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + FMA_FP32_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +#endif // __SSE__ + +// #if 1 +#if defined(__SSE4_1__) +const static __m128i SHUFFLE_MASK16[16] = { + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, -127, -127), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 13, 12), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, + 4), + _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), +}; + +constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; + +float InnerProductSparseInSegmentSSE(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const float *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const float *q_sparse_value) { + float sum = 0.0f; + + // handle if the first dim is zero + bool m_zero = false; + float m_zero_value = 0.0f; + if (m_sparse_count > 0 && m_sparse_index[0] == 0) { + m_sparse_count--; + m_sparse_index++; + m_zero_value = *m_sparse_value++; + m_zero = true; + } + + bool q_zero = false; + float q_zero_value = 0.0f; + if (q_sparse_count > 0 && q_sparse_index[0] == 0) { + q_sparse_count--; + q_sparse_index++; + q_zero_value = *q_sparse_value++; + q_zero = true; + } + + if (m_zero && q_zero) { + sum = m_zero_value * q_zero_value; + } + + size_t i1 = 0, i2 = 0; + size_t end1 = m_sparse_count / 8 * 8; + size_t end2 = q_sparse_count / 8 * 8; + + // std::vector mem1; + // std::vector mem2; + + float fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; + float fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; + + float *val_start_1 = fixed_buffer_1; + float *val_start_2 = fixed_buffer_2; + + // uint32_t max_count = std::max(m_sparse_count, q_sparse_count); + + // if (MAX_SPARSE_BUFFER_LENGTH < max_count) { + // mem1.reserve(max_count); + // mem2.reserve(max_count); + + // val_start_1 = mem1.data(); + // val_start_2 = mem2.data(); + // } + + float *val_1 = val_start_1; + float *val_2 = val_start_2; + + if (i1 < end1 && i2 < end2) { + while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { + i1 += 8; + if (i1 >= end1) goto do_scalar; + } + + while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { + i2 += 8; + if (i2 >= end2) goto do_scalar; + } + + __m128i mm_index_m = + _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); + __m128i mm_index_q = + _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); + + while (true) { +#ifdef DEBUG_PRINT + std::cout << "index 1: " << std::endl; + print_data16(&mm_index_m); + + std::cout << "index 2: " << std::endl; + print_data16(&mm_index_q); +#endif + + __m128i mm_cmp_res = + _mm_cmpistrm(mm_index_q, mm_index_m, + _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); + +#ifdef DEBUG_PRINT + std::cout << "cmp res: " << std::endl; + print_data16(&mm_cmp_res); +#endif + + int r = _mm_extract_epi32(mm_cmp_res, 0); + + if (r) { + int r1 = r & 15; + + __m128i v = _mm_loadu_si128( + reinterpret_cast(&m_sparse_value[i1])); + __m128 vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); + + _mm_storeu_ps(val_1, vs); + val_1 += _mm_popcnt_u32(r1); + + int r2 = (r >> 4) & 15; + v = _mm_loadu_si128( + reinterpret_cast(&m_sparse_value[i1 + 4])); + vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); + _mm_storeu_ps(val_1, vs); + val_1 += _mm_popcnt_u32(r2); + + mm_cmp_res = _mm_cmpistrm( + mm_index_m, mm_index_q, + _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); + r = _mm_extract_epi32(mm_cmp_res, 0); + + r1 = r & 15; + + v = _mm_loadu_si128( + reinterpret_cast(&q_sparse_value[i2])); + vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); + _mm_storeu_ps(val_2, vs); + val_2 += _mm_popcnt_u32(r1); + + r2 = (r >> 4) & 15; + v = _mm_loadu_si128( + reinterpret_cast(&q_sparse_value[i2 + 4])); + vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); + _mm_storeu_ps(val_2, vs); + val_2 += _mm_popcnt_u32(r2); + } + + const uint16_t id1_max = m_sparse_index[i1 + 7]; + + if (id1_max <= q_sparse_index[i2 + 7]) { + i1 += 8; + if (i1 >= end1) goto do_scalar; + mm_index_m = _mm_loadu_si128( + reinterpret_cast(&m_sparse_index[i1])); + } + + if (id1_max >= q_sparse_index[i2 + 7]) { + i2 += 8; + if (i2 >= end2) goto do_scalar; + mm_index_q = _mm_loadu_si128( + reinterpret_cast(&q_sparse_index[i2])); + } + } + } + +do_scalar: + while (i1 < m_sparse_count && i2 < q_sparse_count) { + if (m_sparse_index[i1] == q_sparse_index[i2]) { + *val_1++ = m_sparse_value[i1]; + *val_2++ = q_sparse_value[i2]; + + ++i1; + ++i2; + } else if (m_sparse_index[i1] < q_sparse_index[i2]) { + ++i1; + } else { + ++i2; + } + } + + size_t res_num = val_1 - val_start_1; + + // if (res_num != val_2 - val_start_2) { + // std::cerr << "size mismatch!" << std::endl; + // } + + size_t res_num4 = res_num / 4 * 4; + + if (res_num4) { + __m128 sum128 = _mm_set1_ps(0); + + for (size_t k = 0; k < res_num4; k += 4) { + sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(val_start_1 + k), + _mm_loadu_ps(val_start_2 + k))); + } + + float __attribute__((aligned(16))) tmp_res[4]; + _mm_store_ps(tmp_res, sum128); + sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3]); + } + + for (size_t k = res_num4; k < res_num; ++k) + sum += val_start_1[k] * val_start_2[k]; + + return sum; +} +#else +float InnerProductSparseInSegment(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const float *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const float *q_sparse_value) { + float sum = 0.0f; + + size_t m_i = 0; + size_t q_i = 0; + while (m_i < m_sparse_count && q_i < q_sparse_count) { + if (m_sparse_index[m_i] == q_sparse_index[q_i]) { + sum += m_sparse_value[m_i] * q_sparse_value[q_i]; + + ++m_i; + ++q_i; + } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { + ++m_i; + } else { + ++q_i; + } + } + + return sum; +} +#endif // __SSE4_1__ + +template <> +float MinusInnerProductSparseMatrix::ComputeInnerProductSparseInSegment( + uint32_t m_sparse_count, const uint16_t *m_sparse_index, + const ValueType *m_sparse_value, uint32_t q_sparse_count, + const uint16_t *q_sparse_index, const ValueType *q_sparse_value) { +#if defined(__SSE4_1__) + return InnerProductSparseInSegmentSSE(m_sparse_count, m_sparse_index, + m_sparse_value, q_sparse_count, + q_sparse_index, q_sparse_value); +#else + return InnerProductSparseInSegment(m_sparse_count, m_sparse_index, + m_sparse_value, q_sparse_count, + q_sparse_index, q_sparse_value); +#endif +} + +} // namespace ailego +} // namespace zvec diff --git a/src/ailego/math/inner_product_matrix_int4.cc b/src/ailego/math/inner_product_matrix_int4.cc deleted file mode 100644 index 87a82e80..00000000 --- a/src/ailego/math/inner_product_matrix_int4.cc +++ /dev/null @@ -1,803 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "distance_matrix_accum_int4.i" -#include "inner_product_matrix.h" - -namespace zvec { -namespace ailego { - -#define ACCUM_INT4_STEP_SSE FMA_INT4_SSE -#define ACCUM_INT4_STEP_AVX FMA_INT4_AVX - -#if defined(__SSE4_1__) -//! Four-bits Convert Table -static const AILEGO_ALIGNED(32) int8_t Int4ConvertTable[32] = { - 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1, - 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1}; -#endif // __SSE4_1__ - -#if defined(__SSE4_1__) -static const __m128 NEGZEROS_FP32_SSE = _mm_set1_ps(-0.0f); -static const __m128i MASK_INT4_SSE = _mm_set1_epi32(0x0f0f0f0f); -static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); -static const __m128i INT4_LOOKUP_SSE = - _mm_load_si128((const __m128i *)Int4ConvertTable); -#endif // __SSE4_1__ - -#if defined(__AVX2__) -static const __m256 NEGZEROS_FP32_AVX = _mm256_set1_ps(-0.0f); -static const __m256i MASK_INT4_AVX = _mm256_set1_epi32(0x0f0f0f0f); -static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); -static const __m256i INT4_LOOKUP_AVX = - _mm256_load_si256((const __m256i *)Int4ConvertTable); -#endif // __AVX2__ - -//! Calculate Fused-Multiply-Add (GENERAL) -#define FMA_INT4_GENERAL(m, q, sum) \ - sum += Int4MulTable[(((m) << 4) & 0xf0) | (((q) >> 0) & 0xf)] + \ - Int4MulTable[(((m) >> 0) & 0xf0) | (((q) >> 4) & 0xf)]; - -//! Calculate Fused-Multiply-Add (SSE) -#define FMA_INT4_SSE(xmm_m, xmm_q, xmm_sum) \ - { \ - __m128i xmm_lhs = _mm_shuffle_epi8(INT4_LOOKUP_SSE, \ - _mm_and_si128((xmm_m), MASK_INT4_SSE)); \ - __m128i xmm_rhs = _mm_shuffle_epi8(INT4_LOOKUP_SSE, \ - _mm_and_si128((xmm_q), MASK_INT4_SSE)); \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ - _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ - ONES_INT16_SSE), \ - xmm_sum); \ - xmm_lhs = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, \ - _mm_and_si128(_mm_srli_epi32((xmm_m), 4), MASK_INT4_SSE)); \ - xmm_rhs = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, \ - _mm_and_si128(_mm_srli_epi32((xmm_q), 4), MASK_INT4_SSE)); \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ - _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ - ONES_INT16_SSE), \ - xmm_sum); \ - } - -//! Calculate Fused-Multiply-Add (AVX) -#define FMA_INT4_AVX(ymm_m, ymm_q, ymm_sum) \ - { \ - __m256i ymm_lhs = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, _mm256_and_si256((ymm_m), MASK_INT4_AVX)); \ - __m256i ymm_rhs = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, _mm256_and_si256((ymm_q), MASK_INT4_AVX)); \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16( \ - _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_rhs), \ - _mm256_sign_epi8(ymm_lhs, ymm_rhs)), \ - ONES_INT16_AVX), \ - ymm_sum); \ - ymm_lhs = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, \ - _mm256_and_si256(_mm256_srli_epi32((ymm_m), 4), MASK_INT4_AVX)); \ - ymm_rhs = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, \ - _mm256_and_si256(_mm256_srli_epi32((ymm_q), 4), MASK_INT4_AVX)); \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16( \ - _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_rhs), \ - _mm256_sign_epi8(ymm_lhs, ymm_rhs)), \ - ONES_INT16_AVX), \ - ymm_sum); \ - } - -//! Compute the distance between matrix and query -#define FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) \ - { \ - __m128i xmm_lhs_0 = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, _mm_and_si128((xmm_lhs), MASK_INT4_SSE)); \ - __m128i xmm_rhs_0 = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, _mm_and_si128((xmm_rhs), MASK_INT4_SSE)); \ - __m128i xmm_lhs_1 = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, \ - _mm_and_si128(_mm_srli_epi32((xmm_lhs), 4), MASK_INT4_SSE)); \ - __m128i xmm_rhs_1 = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, \ - _mm_and_si128(_mm_srli_epi32((xmm_rhs), 4), MASK_INT4_SSE)); \ - xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0); \ - xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1); \ - xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0); \ - xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1); \ - xmm_lhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0), \ - ONES_INT16_SSE); \ - xmm_lhs_1 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1), \ - ONES_INT16_SSE); \ - xmm_sum = _mm_add_epi32(_mm_add_epi32(xmm_lhs_0, xmm_lhs_1), xmm_sum); \ - } - -//! Compute the distance between matrix and query -#define FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) \ - { \ - __m256i ymm_lhs_0 = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, _mm256_and_si256((ymm_lhs), MASK_INT4_AVX)); \ - __m256i ymm_rhs_0 = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, _mm256_and_si256((ymm_rhs), MASK_INT4_AVX)); \ - __m256i ymm_lhs_1 = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, \ - _mm256_and_si256(_mm256_srli_epi32((ymm_lhs), 4), MASK_INT4_AVX)); \ - __m256i ymm_rhs_1 = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, \ - _mm256_and_si256(_mm256_srli_epi32((ymm_rhs), 4), MASK_INT4_AVX)); \ - ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); \ - ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); \ - ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); \ - ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); \ - ymm_lhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), \ - ONES_INT16_AVX); \ - ymm_lhs_1 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), \ - ONES_INT16_AVX); \ - ymm_sum = \ - _mm256_add_epi32(_mm256_add_epi32(ymm_lhs_0, ymm_lhs_1), ymm_sum); \ - } - -//! Reverse sign of value (SSE) -#define NEGATE_FP32_SSE(v, ...) \ - _mm_xor_ps(_mm_cvtepi32_ps(v), NEGZEROS_FP32_SSE) - -//! Reverse sign of value (AVX) -#define NEGATE_FP32_AVX(v, ...) \ - _mm256_xor_ps(_mm256_cvtepi32_ps(v), NEGZEROS_FP32_AVX) - -//! Reverse sign of value (AVX512) -#define NEGATE_FP32_AVX512(v, ...) \ - _mm512_xor_ps(_mm512_cvtepi32_ps(v), NEGZEROS_FP32_AVX512) - -#if defined(__SSE4_1__) -//! Inner Product -static inline float InnerProductSSE(const uint8_t *lhs, const uint8_t *rhs, - size_t size) { - const uint8_t *last = lhs + size; - const uint8_t *last_aligned = lhs + ((size >> 4) << 4); - __m128i xmm_sum = _mm_setzero_si128(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)(lhs)); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)(rhs)); - FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) - } - } else { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)(lhs)); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)(rhs)); - FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) - } - } - float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum)); - - switch (last - lhs) { - case 15: - FMA_INT4_GENERAL(lhs[14], rhs[14], result) - /* FALLTHRU */ - case 14: - FMA_INT4_GENERAL(lhs[13], rhs[13], result) - /* FALLTHRU */ - case 13: - FMA_INT4_GENERAL(lhs[12], rhs[12], result) - /* FALLTHRU */ - case 12: - FMA_INT4_GENERAL(lhs[11], rhs[11], result) - /* FALLTHRU */ - case 11: - FMA_INT4_GENERAL(lhs[10], rhs[10], result) - /* FALLTHRU */ - case 10: - FMA_INT4_GENERAL(lhs[9], rhs[9], result) - /* FALLTHRU */ - case 9: - FMA_INT4_GENERAL(lhs[8], rhs[8], result) - /* FALLTHRU */ - case 8: - FMA_INT4_GENERAL(lhs[7], rhs[7], result) - /* FALLTHRU */ - case 7: - FMA_INT4_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - FMA_INT4_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - FMA_INT4_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - FMA_INT4_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - FMA_INT4_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - FMA_INT4_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - FMA_INT4_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __SSE4_1__ - -#if defined(__AVX2__) -//! Inner Product -static inline float InnerProductAVX(const uint8_t *lhs, const uint8_t *rhs, - size_t size) { - const uint8_t *last = lhs + size; - const uint8_t *last_aligned = lhs + ((size >> 5) << 5); - __m256i ymm_sum = _mm256_setzero_si256(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs)); - __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs)); - FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) - } - - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); - __m128i xmm_sum = _mm_setzero_si128(); - FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) - ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), - ymm_sum); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs)); - __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs)); - FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) - } - - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - __m128i xmm_sum = _mm_setzero_si128(); - FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) - ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), - ymm_sum); - lhs += 16; - rhs += 16; - } - } - float result = static_cast(HorizontalAdd_INT32_V256(ymm_sum)); - - switch (last - lhs) { - case 15: - FMA_INT4_GENERAL(lhs[14], rhs[14], result) - /* FALLTHRU */ - case 14: - FMA_INT4_GENERAL(lhs[13], rhs[13], result) - /* FALLTHRU */ - case 13: - FMA_INT4_GENERAL(lhs[12], rhs[12], result) - /* FALLTHRU */ - case 12: - FMA_INT4_GENERAL(lhs[11], rhs[11], result) - /* FALLTHRU */ - case 11: - FMA_INT4_GENERAL(lhs[10], rhs[10], result) - /* FALLTHRU */ - case 10: - FMA_INT4_GENERAL(lhs[9], rhs[9], result) - /* FALLTHRU */ - case 9: - FMA_INT4_GENERAL(lhs[8], rhs[8], result) - /* FALLTHRU */ - case 8: - FMA_INT4_GENERAL(lhs[7], rhs[7], result) - /* FALLTHRU */ - case 7: - FMA_INT4_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - FMA_INT4_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - FMA_INT4_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - FMA_INT4_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - FMA_INT4_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - FMA_INT4_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - FMA_INT4_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __AVX2__ - -#if defined(__SSE4_1__) -//! Compute the distance between matrix and query (INT4, M=1, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - if (dim > 63) { - *out = InnerProductAVX(m, q, dim >> 1); - return; - } -#endif // __AVX2__ - *out = InnerProductSSE(m, q, dim >> 1); -} - -//! Compute the distance between matrix and query (INT4, M=2, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_2X1_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_2X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=2, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_2X2_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_2X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X1_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_4X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X2_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_4X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X4_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT4_4X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_8X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_8X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_8X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_8X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=16) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X16_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_16X16_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=16) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X16_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X16_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=32) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X32_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT4_32X32_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=1, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - if (dim > 63) { - *out = -InnerProductAVX(m, q, dim >> 1); - return; - } -#endif // __AVX2__ - *out = -InnerProductSSE(m, q, dim >> 1); -} - -//! Compute the distance between matrix and query (INT4, M=2, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_2X1_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT4_2X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=2, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_2X2_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT4_2X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X1_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT4_4X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X2_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT4_4X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=4, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_4X4_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT4_4X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_8X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_8X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_8X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=8, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_8X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_8X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_16X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_16X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_16X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_16X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=16, N=16) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_16X16_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_16X16_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_32X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_32X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_32X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_32X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=16) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X16_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_32X16_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT4, M=32, N=32) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT4_32X32_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT4_32X32_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} -#endif // __SSE4_1__ - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_int4_avx2.cc b/src/ailego/math/inner_product_matrix_int4_avx2.cc new file mode 100644 index 00000000..f69864aa --- /dev/null +++ b/src/ailego/math/inner_product_matrix_int4_avx2.cc @@ -0,0 +1,123 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int4.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +//! Inner Product +float InnerProductAVX2(const uint8_t *lhs, const uint8_t *rhs, size_t size) { + const uint8_t *last = lhs + size; + const uint8_t *last_aligned = lhs + ((size >> 5) << 5); + __m256i ymm_sum = _mm256_setzero_si256(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs)); + __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs)); + FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) + } + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + __m128i xmm_sum = _mm_setzero_si128(); + FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) + ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), + ymm_sum); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs)); + __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs)); + FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) + } + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + __m128i xmm_sum = _mm_setzero_si128(); + FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) + ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), + ymm_sum); + lhs += 16; + rhs += 16; + } + } + float result = static_cast(HorizontalAdd_INT32_V256(ymm_sum)); + + switch (last - lhs) { + case 15: + FMA_INT4_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + FMA_INT4_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + FMA_INT4_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + FMA_INT4_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + FMA_INT4_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + FMA_INT4_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + FMA_INT4_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + FMA_INT4_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + FMA_INT4_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + FMA_INT4_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + FMA_INT4_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + FMA_INT4_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + FMA_INT4_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + FMA_INT4_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + FMA_INT4_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +float MinusInnerProductAVX2(const uint8_t *lhs, const uint8_t *rhs, + size_t size) { + return -InnerProductAVX2(lhs, rhs, size); +} + +#endif // __AVX2__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_int4_dispatch.cc b/src/ailego/math/inner_product_matrix_int4_dispatch.cc new file mode 100644 index 00000000..f26946d3 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_int4_dispatch.cc @@ -0,0 +1,62 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +float InnerProductAVX2(const uint8_t *lhs, const uint8_t *rhs, size_t size); +float MinusInnerProductAVX2(const uint8_t *lhs, const uint8_t *rhs, + size_t size); +#endif + +#if defined(__SSE4_1__) +float InnerProductSSE(const uint8_t *lhs, const uint8_t *rhs, size_t size); +float MinusInnerProductSSE(const uint8_t *lhs, const uint8_t *rhs, size_t size); +#endif + +#if defined(__SSE4_1__) +//! Compute the distance between matrix and query (INT4, M=1, N=1) +void InnerProductMatrix::Compute(const ValueType *m, + const ValueType *q, size_t dim, + float *out) { +#if defined(__AVX2__) + if (dim > 63) { + *out = InnerProductAVX2(m, q, dim >> 1); + return; + } +#endif // __AVX2__ + *out = InnerProductSSE(m, q, dim >> 1); +} + +//! Compute the distance between matrix and query (INT4, M=1, N=1) +void MinusInnerProductMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, float *out) { +#if defined(__AVX2__) + if (dim > 63) { + *out = MinusInnerProductAVX2(m, q, dim >> 1); + return; + } +#endif // __AVX2__ + *out = MinusInnerProductSSE(m, q, dim >> 1); +} + +#endif //__SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_int4_sse.cc b/src/ailego/math/inner_product_matrix_int4_sse.cc new file mode 100644 index 00000000..11590bd5 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_int4_sse.cc @@ -0,0 +1,101 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int4.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE4_1__) +//! Inner Product +float InnerProductSSE(const uint8_t *lhs, const uint8_t *rhs, size_t size) { + const uint8_t *last = lhs + size; + const uint8_t *last_aligned = lhs + ((size >> 4) << 4); + __m128i xmm_sum = _mm_setzero_si128(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)(lhs)); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)(rhs)); + FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) + } + } else { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)(lhs)); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)(rhs)); + FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) + } + } + float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum)); + + switch (last - lhs) { + case 15: + FMA_INT4_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + FMA_INT4_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + FMA_INT4_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + FMA_INT4_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + FMA_INT4_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + FMA_INT4_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + FMA_INT4_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + FMA_INT4_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + FMA_INT4_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + FMA_INT4_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + FMA_INT4_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + FMA_INT4_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + FMA_INT4_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + FMA_INT4_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + FMA_INT4_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +float MinusInnerProductSSE(const uint8_t *lhs, const uint8_t *rhs, + size_t size) { + return -InnerProductSSE(lhs, rhs, size); +} + +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_int8.cc b/src/ailego/math/inner_product_matrix_int8.cc deleted file mode 100644 index a307f696..00000000 --- a/src/ailego/math/inner_product_matrix_int8.cc +++ /dev/null @@ -1,841 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "distance_matrix_accum_int8.i" -#include "inner_product_matrix.h" - -namespace zvec { -namespace ailego { - -#define ACCUM_INT8_STEP_SSE FMA_INT8_SSE -#define ACCUM_INT8_STEP_AVX FMA_INT8_AVX - -#if defined(__AVX512F__) && !defined(__AVX512DQ__) -#define _mm512_xor_ps(a, b) \ - _mm512_castsi512_ps( \ - _mm512_xor_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b))) -#endif // __AVX512DQ__ - -#if defined(__SSE__) -static const __m128 NEGZEROS_FP32_SSE = _mm_set1_ps(-0.0f); -#endif // __SSE__ - -#if defined(__AVX__) -static const __m256 NEGZEROS_FP32_AVX = _mm256_set1_ps(-0.0f); -#endif // __AVX__ - -#if defined(__AVX512F__) -static const __m512 NEGZEROS_FP32_AVX512 = _mm512_set1_ps(-0.0f); -#endif // __AVX512F__ - -#if defined(__SSE4_1__) -static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); -#endif // __SSE4_1__ - -#if defined(__AVX2__) -static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); -#endif // __AVX2__ - -//! Reverse sign of value (SSE) -#define NEGATE_FP32_SSE(v, ...) \ - _mm_xor_ps(_mm_cvtepi32_ps(v), NEGZEROS_FP32_SSE) - -//! Reverse sign of value (AVX) -#define NEGATE_FP32_AVX(v, ...) \ - _mm256_xor_ps(_mm256_cvtepi32_ps(v), NEGZEROS_FP32_AVX) - -//! Reverse sign of value (AVX512) -#define NEGATE_FP32_AVX512(v, ...) \ - _mm512_xor_ps(_mm512_cvtepi32_ps(v), NEGZEROS_FP32_AVX512) - -//! Calculate Fused-Multiply-Add (GENERAL) -#define FMA_INT8_GENERAL(m, q, sum) sum += static_cast(m * q); - -//! Calculate Fused-Multiply-Add (SSE) -#define FMA_INT8_SSE(xmm_m, xmm_q, xmm_sum) \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16( \ - _mm_maddubs_epi16(_mm_abs_epi8(xmm_q), _mm_sign_epi8(xmm_m, xmm_q)), \ - ONES_INT16_SSE), \ - xmm_sum); - -//! Calculate Fused-Multiply-Add (AVX) -#define FMA_INT8_AVX(ymm_m, ymm_q, ymm_sum) \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16(_mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q), \ - _mm256_sign_epi8(ymm_m, ymm_q)), \ - ONES_INT16_AVX), \ - ymm_sum); - -#if defined(__SSE4_1__) -//! Inner Product -static inline float InnerProductSSE(const int8_t *lhs, const int8_t *rhs, - size_t size) { - const int8_t *last = lhs + size; - const int8_t *last_aligned = lhs + ((size >> 5) << 5); - - __m128i xmm_sum_0 = _mm_setzero_si128(); - __m128i xmm_sum_1 = _mm_setzero_si128(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0)); - __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16)); - __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0)); - __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16)); - - xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0); - xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1); - xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0); - xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1); - xmm_sum_0 = - _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0), - ONES_INT16_SSE), - xmm_sum_0); - xmm_sum_1 = - _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1), - ONES_INT16_SSE), - xmm_sum_1); - } - - if (last >= last_aligned + 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); - - xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); - xmm_rhs = _mm_abs_epi8(xmm_rhs); - xmm_sum_0 = _mm_add_epi32( - _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), ONES_INT16_SSE), - xmm_sum_0); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0)); - __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16)); - __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0)); - __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16)); - - xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0); - xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1); - xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0); - xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1); - xmm_sum_0 = - _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0), - ONES_INT16_SSE), - xmm_sum_0); - xmm_sum_1 = - _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1), - ONES_INT16_SSE), - xmm_sum_1); - } - - if (last >= last_aligned + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - - xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); - xmm_rhs = _mm_abs_epi8(xmm_rhs); - xmm_sum_0 = _mm_add_epi32( - _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), ONES_INT16_SSE), - xmm_sum_0); - lhs += 16; - rhs += 16; - } - } - float result = static_cast( - HorizontalAdd_INT32_V128(_mm_add_epi32(xmm_sum_0, xmm_sum_1))); - - switch (last - lhs) { - case 15: - FMA_INT8_GENERAL(lhs[14], rhs[14], result) - /* FALLTHRU */ - case 14: - FMA_INT8_GENERAL(lhs[13], rhs[13], result) - /* FALLTHRU */ - case 13: - FMA_INT8_GENERAL(lhs[12], rhs[12], result) - /* FALLTHRU */ - case 12: - FMA_INT8_GENERAL(lhs[11], rhs[11], result) - /* FALLTHRU */ - case 11: - FMA_INT8_GENERAL(lhs[10], rhs[10], result) - /* FALLTHRU */ - case 10: - FMA_INT8_GENERAL(lhs[9], rhs[9], result) - /* FALLTHRU */ - case 9: - FMA_INT8_GENERAL(lhs[8], rhs[8], result) - /* FALLTHRU */ - case 8: - FMA_INT8_GENERAL(lhs[7], rhs[7], result) - /* FALLTHRU */ - case 7: - FMA_INT8_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - FMA_INT8_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - FMA_INT8_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - FMA_INT8_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - FMA_INT8_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - FMA_INT8_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - FMA_INT8_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __SSE4_1__ - -#if defined(__AVX2__) -//! Inner Product -static inline float InnerProductAVX(const int8_t *lhs, const int8_t *rhs, - size_t size) { - const int8_t *last = lhs + size; - const int8_t *last_aligned = lhs + ((size >> 6) << 6); - float result = 0.0; - - __m256i ymm_sum_0 = _mm256_setzero_si256(); - __m256i ymm_sum_1 = _mm256_setzero_si256(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0)); - __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32)); - __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0)); - __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32)); - - ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); - ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); - ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); - ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); - - ymm_sum_0 = _mm256_add_epi32( - _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), - ONES_INT16_AVX), - ymm_sum_0); - ymm_sum_1 = _mm256_add_epi32( - _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), - ONES_INT16_AVX), - ymm_sum_1); - } - - if (last >= last_aligned + 32) { - __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); - __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); - ymm_lhs = _mm256_sign_epi8(ymm_lhs, ymm_rhs); - ymm_rhs = _mm256_abs_epi8(ymm_rhs); - ymm_sum_0 = _mm256_add_epi32( - _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs, ymm_lhs), - ONES_INT16_AVX), - ymm_sum_0); - lhs += 32; - rhs += 32; - } - - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); - xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); - xmm_rhs = _mm_abs_epi8(xmm_rhs); - ymm_sum_0 = _mm256_add_epi32( - _mm256_set_m128i(_mm_setzero_si128(), - _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), - ONES_INT16_SSE)), - ymm_sum_0); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0)); - __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32)); - __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0)); - __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32)); - - ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); - ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); - ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); - ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); - - ymm_sum_0 = _mm256_add_epi32( - _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), - ONES_INT16_AVX), - ymm_sum_0); - ymm_sum_1 = _mm256_add_epi32( - _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), - ONES_INT16_AVX), - ymm_sum_1); - } - - if (last >= last_aligned + 32) { - __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); - __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); - ymm_lhs = _mm256_sign_epi8(ymm_lhs, ymm_rhs); - ymm_rhs = _mm256_abs_epi8(ymm_rhs); - ymm_sum_0 = _mm256_add_epi32( - _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs, ymm_lhs), - ONES_INT16_AVX), - ymm_sum_0); - lhs += 32; - rhs += 32; - } - - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); - xmm_rhs = _mm_abs_epi8(xmm_rhs); - ymm_sum_0 = _mm256_add_epi32( - _mm256_set_m128i(_mm_setzero_si128(), - _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), - ONES_INT16_SSE)), - ymm_sum_0); - lhs += 16; - rhs += 16; - } - } - result = static_cast( - HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); - - switch (last - lhs) { - case 15: - FMA_INT8_GENERAL(lhs[14], rhs[14], result) - /* FALLTHRU */ - case 14: - FMA_INT8_GENERAL(lhs[13], rhs[13], result) - /* FALLTHRU */ - case 13: - FMA_INT8_GENERAL(lhs[12], rhs[12], result) - /* FALLTHRU */ - case 12: - FMA_INT8_GENERAL(lhs[11], rhs[11], result) - /* FALLTHRU */ - case 11: - FMA_INT8_GENERAL(lhs[10], rhs[10], result) - /* FALLTHRU */ - case 10: - FMA_INT8_GENERAL(lhs[9], rhs[9], result) - /* FALLTHRU */ - case 9: - FMA_INT8_GENERAL(lhs[8], rhs[8], result) - /* FALLTHRU */ - case 8: - FMA_INT8_GENERAL(lhs[7], rhs[7], result) - /* FALLTHRU */ - case 7: - FMA_INT8_GENERAL(lhs[6], rhs[6], result) - /* FALLTHRU */ - case 6: - FMA_INT8_GENERAL(lhs[5], rhs[5], result) - /* FALLTHRU */ - case 5: - FMA_INT8_GENERAL(lhs[4], rhs[4], result) - /* FALLTHRU */ - case 4: - FMA_INT8_GENERAL(lhs[3], rhs[3], result) - /* FALLTHRU */ - case 3: - FMA_INT8_GENERAL(lhs[2], rhs[2], result) - /* FALLTHRU */ - case 2: - FMA_INT8_GENERAL(lhs[1], rhs[1], result) - /* FALLTHRU */ - case 1: - FMA_INT8_GENERAL(lhs[0], rhs[0], result) - } - return result; -} -#endif // __AVX2__ - -#if defined(__SSE4_1__) -//! Compute the distance between matrix and query (INT8, M=1, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - if (dim > 31) { - *out = InnerProductAVX(m, q, dim); - return; - } -#endif // __AVX2__ - *out = InnerProductSSE(m, q, dim); -} - -//! Compute the distance between matrix and query (INT8, M=2, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_2X1_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_2X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=2, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_2X2_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_2X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X1_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_4X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X2_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_4X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X4_AVX(m, q, dim, out, _mm_cvtepi32_ps) -#else - ACCUM_INT8_4X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_8X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_8X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_8X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_8X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=16) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X16_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_16X16_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=1) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X1_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X1_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=2) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X2_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X2_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=4) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X4_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X4_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=8) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X8_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X8_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=16) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X16_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X16_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=32) -void InnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, size_t dim, - float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X32_AVX(m, q, dim, out, _mm256_cvtepi32_ps) -#else - ACCUM_INT8_32X32_SSE(m, q, dim, out, _mm_cvtepi32_ps) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=1, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - if (dim > 31) { - *out = -InnerProductAVX(m, q, dim); - return; - } -#endif // __AVX2__ - *out = -InnerProductSSE(m, q, dim); -} - -//! Compute the distance between matrix and query (INT8, M=2, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_2X1_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT8_2X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=2, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_2X2_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT8_2X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X1_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT8_4X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X2_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT8_4X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=4, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_4X4_AVX(m, q, dim, out, NEGATE_FP32_SSE) -#else - ACCUM_INT8_4X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_8X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_8X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_8X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=8, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_8X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_8X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_16X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_16X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_16X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_16X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=16, N=16) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_16X16_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_16X16_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=1) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X1_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_32X1_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=2) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X2_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_32X2_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=4) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X4_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_32X4_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=8) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X8_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_32X8_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=16) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X16_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_32X16_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} - -//! Compute the distance between matrix and query (INT8, M=32, N=32) -void MinusInnerProductMatrix::Compute(const ValueType *m, - const ValueType *q, - size_t dim, float *out) { -#if defined(__AVX2__) - ACCUM_INT8_32X32_AVX(m, q, dim, out, NEGATE_FP32_AVX) -#else - ACCUM_INT8_32X32_SSE(m, q, dim, out, NEGATE_FP32_SSE) -#endif // __AVX2__ -} -#endif // __SSE4_1__ - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_int8_avx2.cc b/src/ailego/math/inner_product_matrix_int8_avx2.cc new file mode 100644 index 00000000..c32d6987 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_int8_avx2.cc @@ -0,0 +1,189 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int8.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +//! Inner Product +float InnerProductAVX2(const int8_t *lhs, const int8_t *rhs, size_t size) { + const int8_t *last = lhs + size; + const int8_t *last_aligned = lhs + ((size >> 6) << 6); + float result = 0.0; + + __m256i ymm_sum_0 = _mm256_setzero_si256(); + __m256i ymm_sum_1 = _mm256_setzero_si256(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0)); + __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32)); + __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0)); + __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32)); + + ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); + ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); + ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); + ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); + + ymm_sum_0 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), + ONES_INT16_AVX), + ymm_sum_0); + ymm_sum_1 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), + ONES_INT16_AVX), + ymm_sum_1); + } + + if (last >= last_aligned + 32) { + __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); + ymm_lhs = _mm256_sign_epi8(ymm_lhs, ymm_rhs); + ymm_rhs = _mm256_abs_epi8(ymm_rhs); + ymm_sum_0 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs, ymm_lhs), + ONES_INT16_AVX), + ymm_sum_0); + lhs += 32; + rhs += 32; + } + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); + xmm_rhs = _mm_abs_epi8(xmm_rhs); + ymm_sum_0 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), + _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), + ONES_INT16_SSE)), + ymm_sum_0); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0)); + __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32)); + __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0)); + __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32)); + + ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); + ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); + ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); + ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); + + ymm_sum_0 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), + ONES_INT16_AVX), + ymm_sum_0); + ymm_sum_1 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), + ONES_INT16_AVX), + ymm_sum_1); + } + + if (last >= last_aligned + 32) { + __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); + ymm_lhs = _mm256_sign_epi8(ymm_lhs, ymm_rhs); + ymm_rhs = _mm256_abs_epi8(ymm_rhs); + ymm_sum_0 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs, ymm_lhs), + ONES_INT16_AVX), + ymm_sum_0); + lhs += 32; + rhs += 32; + } + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); + xmm_rhs = _mm_abs_epi8(xmm_rhs); + ymm_sum_0 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), + _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), + ONES_INT16_SSE)), + ymm_sum_0); + lhs += 16; + rhs += 16; + } + } + result = static_cast( + HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); + + switch (last - lhs) { + case 15: + FMA_INT8_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + FMA_INT8_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + FMA_INT8_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + FMA_INT8_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + FMA_INT8_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + FMA_INT8_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + FMA_INT8_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + FMA_INT8_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + FMA_INT8_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + FMA_INT8_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + FMA_INT8_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + FMA_INT8_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + FMA_INT8_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + FMA_INT8_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + FMA_INT8_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +float MinusInnerProductAVX2(const int8_t *lhs, const int8_t *rhs, size_t size) { + return -InnerProductAVX2(lhs, rhs, size); +} + +#endif // __AVX2__ + + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_int8_dispatch.cc b/src/ailego/math/inner_product_matrix_int8_dispatch.cc new file mode 100644 index 00000000..5b756333 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_int8_dispatch.cc @@ -0,0 +1,60 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +float InnerProductAVX2(const int8_t *lhs, const int8_t *rhs, size_t size); +float MinusInnerProductAVX2(const int8_t *lhs, const int8_t *rhs, size_t size); +#endif + +#if defined(__SSE4_1__) +float InnerProductSSE(const int8_t *lhs, const int8_t *rhs, size_t size); +float MinusInnerProductSSE(const int8_t *lhs, const int8_t *rhs, size_t size); +#endif + +#if defined(__SSE4_1__) +//! Compute the distance between matrix and query (INT8, M=1, N=1) +void InnerProductMatrix::Compute(const ValueType *m, + const ValueType *q, size_t dim, + float *out) { +#if defined(__AVX2__) + if (dim > 31) { + *out = InnerProductAVX2(m, q, dim); + return; + } +#endif // __AVX2__ + *out = InnerProductSSE(m, q, dim); +} + +//! Compute the distance between matrix and query (INT8, M=1, N=1) +void MinusInnerProductMatrix::Compute(const ValueType *m, + const ValueType *q, + size_t dim, float *out) { +#if defined(__AVX2__) + if (dim > 31) { + *out = MinusInnerProductAVX2(m, q, dim); + return; + } +#endif // __AVX2__ + *out = MinusInnerProductSSE(m, q, dim); +} +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/inner_product_matrix_int8_sse.cc b/src/ailego/math/inner_product_matrix_int8_sse.cc new file mode 100644 index 00000000..da0923c4 --- /dev/null +++ b/src/ailego/math/inner_product_matrix_int8_sse.cc @@ -0,0 +1,157 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int8.i" +#include "distance_matrix_inner_product_utility.i" +#include "inner_product_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE4_1__) +//! Inner Product +float InnerProductSSE(const int8_t *lhs, const int8_t *rhs, size_t size) { + const int8_t *last = lhs + size; + const int8_t *last_aligned = lhs + ((size >> 5) << 5); + + __m128i xmm_sum_0 = _mm_setzero_si128(); + __m128i xmm_sum_1 = _mm_setzero_si128(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0)); + __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16)); + __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0)); + __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16)); + + xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0); + xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1); + xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0); + xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1); + xmm_sum_0 = + _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0), + ONES_INT16_SSE), + xmm_sum_0); + xmm_sum_1 = + _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1), + ONES_INT16_SSE), + xmm_sum_1); + } + + if (last >= last_aligned + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + + xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); + xmm_rhs = _mm_abs_epi8(xmm_rhs); + xmm_sum_0 = _mm_add_epi32( + _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), ONES_INT16_SSE), + xmm_sum_0); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0)); + __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16)); + __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0)); + __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16)); + + xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0); + xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1); + xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0); + xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1); + xmm_sum_0 = + _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0), + ONES_INT16_SSE), + xmm_sum_0); + xmm_sum_1 = + _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1), + ONES_INT16_SSE), + xmm_sum_1); + } + + if (last >= last_aligned + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + + xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); + xmm_rhs = _mm_abs_epi8(xmm_rhs); + xmm_sum_0 = _mm_add_epi32( + _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), ONES_INT16_SSE), + xmm_sum_0); + lhs += 16; + rhs += 16; + } + } + float result = static_cast( + HorizontalAdd_INT32_V128(_mm_add_epi32(xmm_sum_0, xmm_sum_1))); + + switch (last - lhs) { + case 15: + FMA_INT8_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + FMA_INT8_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + FMA_INT8_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + FMA_INT8_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + FMA_INT8_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + FMA_INT8_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + FMA_INT8_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + FMA_INT8_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + FMA_INT8_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + FMA_INT8_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + FMA_INT8_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + FMA_INT8_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + FMA_INT8_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + FMA_INT8_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + FMA_INT8_GENERAL(lhs[0], rhs[0], result) + } + return result; +} + +float MinusInnerProductSSE(const int8_t *lhs, const int8_t *rhs, size_t size) { + return -InnerProductSSE(lhs, rhs, size); +} + +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16.cc deleted file mode 100644 index f4ede598..00000000 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp16.cc +++ /dev/null @@ -1,409 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "distance_matrix_accum_fp16.i" -#include "mips_euclidean_distance_matrix.h" - -namespace zvec { -namespace ailego { - -//! Calculate Fused-Multiply-Add (AVX512) -#define FMA_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ - zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_q, zmm_sum); -#define FMA_MASK_FP32_AVX512(zmm_m, zmm_q, zmm_sum, mask) \ - zmm_sum = _mm512_mask3_fmadd_ps(zmm_m, zmm_q, zmm_sum, mask); - -#define HorizontalAdd_FP16_NEON(v) \ - vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(v)), vcvt_high_f32_f16(v))) - -#define HorizontalAdd_FP32_V512_TO_V256(zmm) \ - _mm256_add_ps( \ - _mm512_castps512_ps256(zmm), \ - _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(zmm), 1))) - -//! Calculate Fused-Multiply-Add (AVX, FP16) -#define FMA_FP16_GENERAL(lhs, rhs, sum, norm1, norm2) \ - { \ - float v1 = lhs; \ - float v2 = rhs; \ - sum += v1 * v2; \ - norm1 += v1 * v1; \ - norm2 += v2 * v2; \ - } - -#if defined(__ARM_NEON) && defined(__aarch64__) -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormNEON(const Float16 *lhs, - const Float16 *rhs, - size_t size, float *sql, - float *sqr) { - const Float16 *last = lhs + size; - const Float16 *last_aligned = lhs + ((size >> 3) << 3); - float16x8_t v_sum = vdupq_n_f16(0); - float16x8_t v_sum_norm1 = vdupq_n_f16(0); - float16x8_t v_sum_norm2 = vdupq_n_f16(0); - - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - float16x8_t v_lhs = vld1q_f16((const float16_t *)lhs); - float16x8_t v_rhs = vld1q_f16((const float16_t *)rhs); - v_sum = vfmaq_f16(v_sum, v_lhs, v_rhs); - v_sum_norm1 = vfmaq_f16(v_sum_norm1, v_lhs, v_lhs); - v_sum_norm2 = vfmaq_f16(v_sum_norm2, v_rhs, v_rhs); - } - if (last >= last_aligned + 4) { - float16x8_t v_lhs = vcombine_f16(vld1_f16((const float16_t *)lhs), - vreinterpret_f16_u64(vdup_n_u64(0ul))); - float16x8_t v_rhs = vcombine_f16(vld1_f16((const float16_t *)rhs), - vreinterpret_f16_u64(vdup_n_u64(0ul))); - v_sum = vfmaq_f16(v_sum, v_lhs, v_rhs); - v_sum_norm1 = vfmaq_f16(v_sum_norm1, v_lhs, v_lhs); - v_sum_norm2 = vfmaq_f16(v_sum_norm2, v_rhs, v_rhs); - lhs += 4; - rhs += 4; - } - - float result = HorizontalAdd_FP16_NEON(v_sum); - float norm1 = HorizontalAdd_FP16_NEON(v_sum_norm1); - float norm2 = HorizontalAdd_FP16_NEON(v_sum_norm2); - - switch (last - lhs) { - case 3: - FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2); - /* FALLTHRU */ - case 2: - FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2); - /* FALLTHRU */ - case 1: - FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2); - } - *sql = norm1; - *sqr = norm2; - return result; -} -#else -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormNEON(const Float16 *lhs, - const Float16 *rhs, - size_t size, float *sql, - float *sqr) { - const Float16 *last = lhs + size; - const Float16 *last_aligned = lhs + ((size >> 3) << 3); - float32x4_t v_sum_0 = vdupq_n_f32(0); - float32x4_t v_sum_1 = vdupq_n_f32(0); - float32x4_t v_sum_norm1 = vdupq_n_f32(0); - float32x4_t v_sum_norm2 = vdupq_n_f32(0); - - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - float16x8_t v_lhs = vld1q_f16((const float16_t *)lhs); - float16x8_t v_rhs = vld1q_f16((const float16_t *)rhs); - float32x4_t v_lhs_0 = vcvt_f32_f16(vget_low_f16(v_lhs)); - float32x4_t v_rhs_0 = vcvt_f32_f16(vget_low_f16(v_rhs)); - float32x4_t v_lhs_1 = vcvt_high_f32_f16(v_lhs); - float32x4_t v_rhs_1 = vcvt_high_f32_f16(v_rhs); - v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); - v_sum_1 = vfmaq_f32(v_sum_1, v_lhs_1, v_rhs_1); - v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); - v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_1, v_lhs_1); - v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); - v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_1, v_rhs_1); - } - if (last >= last_aligned + 4) { - float32x4_t v_lhs_0 = vcvt_f32_f16(vld1_f16((const float16_t *)lhs)); - float32x4_t v_rhs_0 = vcvt_f32_f16(vld1_f16((const float16_t *)rhs)); - v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); - v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); - v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); - lhs += 4; - rhs += 4; - } - - float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); - float norm1 = vaddvq_f32(v_sum_norm1); - float norm2 = vaddvq_f32(v_sum_norm2); - switch (last - lhs) { - case 3: - FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2); - /* FALLTHRU */ - case 2: - FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2); - /* FALLTHRU */ - case 1: - FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2); - } - *sql = norm1; - *sqr = norm2; - return result; -} -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#endif // __ARM_NEON && __aarch64__ - -#if defined(__AVX__) && defined(__F16C__) -#if defined(__AVX512F__) -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormAVX512(const Float16 *lhs, - const Float16 *rhs, - size_t size, float *sql, - float *sqr) { - __m512 zmm_sum_0 = _mm512_setzero_ps(); - __m512 zmm_sum_1 = _mm512_setzero_ps(); - __m512 zmm_sum_norm1 = _mm512_setzero_ps(); - __m512 zmm_sum_norm2 = _mm512_setzero_ps(); - - const Float16 *last = lhs + size; - const Float16 *last_aligned = lhs + ((size >> 5) << 5); - if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m512i zmm_lhs = _mm512_load_si512((const __m512i *)lhs); - __m512i zmm_rhs = _mm512_load_si512((const __m512i *)rhs); - __m512 zmm_lhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_lhs)); - __m512 zmm_lhs_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_lhs, 1)); - __m512 zmm_rhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_rhs)); - __m512 zmm_rhs_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_rhs, 1)); - FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) - FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1) - FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) - FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2) - } - if (last >= last_aligned + 16) { - __m512 zmm_lhs_0 = - _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)lhs)); - __m512 zmm_rhs_0 = - _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)rhs)); - FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) - FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m512i zmm_lhs = _mm512_loadu_si512((const __m512i *)lhs); - __m512i zmm_rhs = _mm512_loadu_si512((const __m512i *)rhs); - __m512 zmm_lhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_lhs)); - __m512 zmm_lhs_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_lhs, 1)); - __m512 zmm_rhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_rhs)); - __m512 zmm_rhs_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_rhs, 1)); - FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) - FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1) - FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) - FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2) - } - if (last >= last_aligned + 16) { - __m512 zmm_lhs_0 = - _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)lhs)); - __m512 zmm_rhs_0 = - _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)rhs)); - FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) - FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) - lhs += 16; - rhs += 16; - } - } - - __m256 ymm_sum_0 = - HorizontalAdd_FP32_V512_TO_V256(_mm512_add_ps(zmm_sum_0, zmm_sum_1)); - __m256 ymm_sum_norm1 = HorizontalAdd_FP32_V512_TO_V256(zmm_sum_norm1); - __m256 ymm_sum_norm2 = HorizontalAdd_FP32_V512_TO_V256(zmm_sum_norm2); - if (last >= lhs + 8) { - __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)lhs)); - __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)rhs)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - lhs += 8; - rhs += 8; - } - - float result = HorizontalAdd_FP32_V256(ymm_sum_0); - float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1); - float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2); - switch (last - lhs) { - case 7: - FMA_FP16_GENERAL(lhs[6], rhs[6], result, norm1, norm2); - /* FALLTHRU */ - case 6: - FMA_FP16_GENERAL(lhs[5], rhs[5], result, norm1, norm2); - /* FALLTHRU */ - case 5: - FMA_FP16_GENERAL(lhs[4], rhs[4], result, norm1, norm2); - /* FALLTHRU */ - case 4: - FMA_FP16_GENERAL(lhs[3], rhs[3], result, norm1, norm2); - /* FALLTHRU */ - case 3: - FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2); - /* FALLTHRU */ - case 2: - FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2); - /* FALLTHRU */ - case 1: - FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2); - } - - *sql = norm1; - *sqr = norm2; - return result; -} -#else -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormAVX(const Float16 *lhs, - const Float16 *rhs, - size_t size, float *sql, - float *sqr) { - __m256 ymm_sum_0 = _mm256_setzero_ps(); - __m256 ymm_sum_1 = _mm256_setzero_ps(); - __m256 ymm_sum_norm1 = _mm256_setzero_ps(); - __m256 ymm_sum_norm2 = _mm256_setzero_ps(); - - const Float16 *last = lhs + size; - const Float16 *last_aligned = lhs + ((size >> 4) << 4); - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); - __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); - __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_lhs)); - __m256 ymm_lhs_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_lhs, 1)); - __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_rhs)); - __m256 ymm_rhs_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_rhs, 1)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); - } - if (last >= last_aligned + 8) { - __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)lhs)); - __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)rhs)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - lhs += 8; - rhs += 8; - } - } else { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); - __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); - __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_lhs)); - __m256 ymm_lhs_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_lhs, 1)); - __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_rhs)); - __m256 ymm_rhs_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_rhs, 1)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); - } - if (last >= last_aligned + 8) { - __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)lhs)); - __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)rhs)); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - lhs += 8; - rhs += 8; - } - } - - float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); - float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1); - float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2); - switch (last - lhs) { - case 7: - FMA_FP16_GENERAL(lhs[6], rhs[6], result, norm1, norm2); - /* FALLTHRU */ - case 6: - FMA_FP16_GENERAL(lhs[5], rhs[5], result, norm1, norm2); - /* FALLTHRU */ - case 5: - FMA_FP16_GENERAL(lhs[4], rhs[4], result, norm1, norm2); - /* FALLTHRU */ - case 4: - FMA_FP16_GENERAL(lhs[3], rhs[3], result, norm1, norm2); - /* FALLTHRU */ - case 3: - FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2); - /* FALLTHRU */ - case 2: - FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2); - /* FALLTHRU */ - case 1: - FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2); - } - - *sql = norm1; - *sqr = norm2; - return result; -} -#endif // __AVX512F__ -#endif // __AVX__ && __F16C__ - -#if (defined(__F16C__) && defined(__AVX__)) || \ - (defined(__ARM_NEON) && defined(__aarch64__)) -//! Compute the distance between matrix and query by SphericalInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2; - float v2; - float sum; - -#if defined(__ARM_NEON) - sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); -#elif defined(__AVX512F__) - sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); -#else - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); -#endif - - *out = ComputeSphericalInjection(sum, u2, v2, e2); -} - -//! Compute the distance between matrix and query by RepeatedQuadraticInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, - float *out) { - float u2; - float v2; - float sum; - -#if defined(__ARM_NEON) - sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); -#elif defined(__AVX512F__) - sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); -#else - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); -#endif - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; - } - *out = sum; -} -#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc new file mode 100644 index 00000000..c93edc1c --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc @@ -0,0 +1,116 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX__) && defined(__F16C__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormAVX(const Float16 *lhs, const Float16 *rhs, + size_t size, float *sql, float *sqr) { + __m256 ymm_sum_0 = _mm256_setzero_ps(); + __m256 ymm_sum_1 = _mm256_setzero_ps(); + __m256 ymm_sum_norm1 = _mm256_setzero_ps(); + __m256 ymm_sum_norm2 = _mm256_setzero_ps(); + + const Float16 *last = lhs + size; + const Float16 *last_aligned = lhs + ((size >> 4) << 4); + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); + __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_lhs)); + __m256 ymm_lhs_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_lhs, 1)); + __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_rhs)); + __m256 ymm_rhs_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_rhs, 1)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); + } + if (last >= last_aligned + 8) { + __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)lhs)); + __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)rhs)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + lhs += 8; + rhs += 8; + } + } else { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); + __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_lhs)); + __m256 ymm_lhs_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_lhs, 1)); + __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_rhs)); + __m256 ymm_rhs_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_rhs, 1)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); + } + if (last >= last_aligned + 8) { + __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)lhs)); + __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)rhs)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + lhs += 8; + rhs += 8; + } + } + + float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); + float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1); + float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2); + switch (last - lhs) { + case 7: + FMA_FP16_GENERAL(lhs[6], rhs[6], result, norm1, norm2); + /* FALLTHRU */ + case 6: + FMA_FP16_GENERAL(lhs[5], rhs[5], result, norm1, norm2); + /* FALLTHRU */ + case 5: + FMA_FP16_GENERAL(lhs[4], rhs[4], result, norm1, norm2); + /* FALLTHRU */ + case 4: + FMA_FP16_GENERAL(lhs[3], rhs[3], result, norm1, norm2); + /* FALLTHRU */ + case 3: + FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2); + /* FALLTHRU */ + case 2: + FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2); + /* FALLTHRU */ + case 1: + FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2); + } + + *sql = norm1; + *sqr = norm2; + return result; +} +#endif // __AVX__ && __F16C__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc new file mode 100644 index 00000000..51ce4fc4 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc @@ -0,0 +1,134 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX512F__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormAVX512(const Float16 *lhs, const Float16 *rhs, + size_t size, float *sql, float *sqr) { + __m512 zmm_sum_0 = _mm512_setzero_ps(); + __m512 zmm_sum_1 = _mm512_setzero_ps(); + __m512 zmm_sum_norm1 = _mm512_setzero_ps(); + __m512 zmm_sum_norm2 = _mm512_setzero_ps(); + + const Float16 *last = lhs + size; + const Float16 *last_aligned = lhs + ((size >> 5) << 5); + if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m512i zmm_lhs = _mm512_load_si512((const __m512i *)lhs); + __m512i zmm_rhs = _mm512_load_si512((const __m512i *)rhs); + __m512 zmm_lhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_lhs)); + __m512 zmm_lhs_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_lhs, 1)); + __m512 zmm_rhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_rhs)); + __m512 zmm_rhs_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_rhs, 1)); + FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) + FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1) + FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) + FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2) + } + if (last >= last_aligned + 16) { + __m512 zmm_lhs_0 = + _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)lhs)); + __m512 zmm_rhs_0 = + _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)rhs)); + FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) + FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m512i zmm_lhs = _mm512_loadu_si512((const __m512i *)lhs); + __m512i zmm_rhs = _mm512_loadu_si512((const __m512i *)rhs); + __m512 zmm_lhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_lhs)); + __m512 zmm_lhs_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_lhs, 1)); + __m512 zmm_rhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_rhs)); + __m512 zmm_rhs_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_rhs, 1)); + FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) + FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1) + FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) + FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2) + } + if (last >= last_aligned + 16) { + __m512 zmm_lhs_0 = + _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)lhs)); + __m512 zmm_rhs_0 = + _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)rhs)); + FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) + FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) + lhs += 16; + rhs += 16; + } + } + + __m256 ymm_sum_0 = + HorizontalAdd_FP32_V512_TO_V256(_mm512_add_ps(zmm_sum_0, zmm_sum_1)); + __m256 ymm_sum_norm1 = HorizontalAdd_FP32_V512_TO_V256(zmm_sum_norm1); + __m256 ymm_sum_norm2 = HorizontalAdd_FP32_V512_TO_V256(zmm_sum_norm2); + if (last >= lhs + 8) { + __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)lhs)); + __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)rhs)); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + lhs += 8; + rhs += 8; + } + + float result = HorizontalAdd_FP32_V256(ymm_sum_0); + float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1); + float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2); + switch (last - lhs) { + case 7: + FMA_FP16_GENERAL(lhs[6], rhs[6], result, norm1, norm2); + /* FALLTHRU */ + case 6: + FMA_FP16_GENERAL(lhs[5], rhs[5], result, norm1, norm2); + /* FALLTHRU */ + case 5: + FMA_FP16_GENERAL(lhs[4], rhs[4], result, norm1, norm2); + /* FALLTHRU */ + case 4: + FMA_FP16_GENERAL(lhs[3], rhs[3], result, norm1, norm2); + /* FALLTHRU */ + case 3: + FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2); + /* FALLTHRU */ + case 2: + FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2); + /* FALLTHRU */ + case 1: + FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2); + } + + *sql = norm1; + *sqr = norm2; + return result; +} +#endif // __AVX512F__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc new file mode 100644 index 00000000..b99ab45e --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc @@ -0,0 +1,96 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +float InnerProductAndSquaredNormNEON(const Float16 *lhs, const Float16 *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__AVX512F__) +float InnerProductAndSquaredNormAVX512(const Float16 *lhs, const Float16 *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__AVX__) +float InnerProductAndSquaredNormAVX(const Float16 *lhs, const Float16 *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if (defined(__F16C__) && defined(__AVX__)) || \ + (defined(__ARM_NEON) && defined(__aarch64__)) +//! Compute the distance between matrix and query by SphericalInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + +#if defined(__ARM_NEON) + sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); +#else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); + } else +#endif //__AVX512F__ + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); + } +#endif //__ARM_NEON + + *out = ComputeSphericalInjection(sum, u2, v2, e2); +} + +//! Compute the distance between matrix and query by RepeatedQuadraticInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, + float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + +#if defined(__ARM_NEON) + sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); +#else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); + } else +#endif //__AVX512F__ + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); + } +#endif //__ARM_NEON + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + *out = sum; +} + +#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc new file mode 100644 index 00000000..22493b4e --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc @@ -0,0 +1,126 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp16.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) && defined(__aarch64__) +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormNEON(const Float16 *lhs, const Float16 *rhs, + size_t size, float *sql, float *sqr) { + const Float16 *last = lhs + size; + const Float16 *last_aligned = lhs + ((size >> 3) << 3); + float16x8_t v_sum = vdupq_n_f16(0); + float16x8_t v_sum_norm1 = vdupq_n_f16(0); + float16x8_t v_sum_norm2 = vdupq_n_f16(0); + + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + float16x8_t v_lhs = vld1q_f16((const float16_t *)lhs); + float16x8_t v_rhs = vld1q_f16((const float16_t *)rhs); + v_sum = vfmaq_f16(v_sum, v_lhs, v_rhs); + v_sum_norm1 = vfmaq_f16(v_sum_norm1, v_lhs, v_lhs); + v_sum_norm2 = vfmaq_f16(v_sum_norm2, v_rhs, v_rhs); + } + if (last >= last_aligned + 4) { + float16x8_t v_lhs = vcombine_f16(vld1_f16((const float16_t *)lhs), + vreinterpret_f16_u64(vdup_n_u64(0ul))); + float16x8_t v_rhs = vcombine_f16(vld1_f16((const float16_t *)rhs), + vreinterpret_f16_u64(vdup_n_u64(0ul))); + v_sum = vfmaq_f16(v_sum, v_lhs, v_rhs); + v_sum_norm1 = vfmaq_f16(v_sum_norm1, v_lhs, v_lhs); + v_sum_norm2 = vfmaq_f16(v_sum_norm2, v_rhs, v_rhs); + lhs += 4; + rhs += 4; + } + + float result = HorizontalAdd_FP16_NEON(v_sum); + float norm1 = HorizontalAdd_FP16_NEON(v_sum_norm1); + float norm2 = HorizontalAdd_FP16_NEON(v_sum_norm2); + + switch (last - lhs) { + case 3: + FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2); + /* FALLTHRU */ + case 2: + FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2); + /* FALLTHRU */ + case 1: + FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2); + } + *sql = norm1; + *sqr = norm2; + return result; +} +#else +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormNEON(const Float16 *lhs, const Float16 *rhs, + size_t size, float *sql, float *sqr) { + const Float16 *last = lhs + size; + const Float16 *last_aligned = lhs + ((size >> 3) << 3); + float32x4_t v_sum_0 = vdupq_n_f32(0); + float32x4_t v_sum_1 = vdupq_n_f32(0); + float32x4_t v_sum_norm1 = vdupq_n_f32(0); + float32x4_t v_sum_norm2 = vdupq_n_f32(0); + + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + float16x8_t v_lhs = vld1q_f16((const float16_t *)lhs); + float16x8_t v_rhs = vld1q_f16((const float16_t *)rhs); + float32x4_t v_lhs_0 = vcvt_f32_f16(vget_low_f16(v_lhs)); + float32x4_t v_rhs_0 = vcvt_f32_f16(vget_low_f16(v_rhs)); + float32x4_t v_lhs_1 = vcvt_high_f32_f16(v_lhs); + float32x4_t v_rhs_1 = vcvt_high_f32_f16(v_rhs); + v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); + v_sum_1 = vfmaq_f32(v_sum_1, v_lhs_1, v_rhs_1); + v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); + v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_1, v_lhs_1); + v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); + v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_1, v_rhs_1); + } + if (last >= last_aligned + 4) { + float32x4_t v_lhs_0 = vcvt_f32_f16(vld1_f16((const float16_t *)lhs)); + float32x4_t v_rhs_0 = vcvt_f32_f16(vld1_f16((const float16_t *)rhs)); + v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); + v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); + v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); + lhs += 4; + rhs += 4; + } + + float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); + float norm1 = vaddvq_f32(v_sum_norm1); + float norm2 = vaddvq_f32(v_sum_norm2); + switch (last - lhs) { + case 3: + FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2); + /* FALLTHRU */ + case 2: + FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2); + /* FALLTHRU */ + case 1: + FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2); + } + *sql = norm1; + *sqr = norm2; + return result; +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __ARM_NEON && __aarch64__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32.cc deleted file mode 100644 index f14117a5..00000000 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp32.cc +++ /dev/null @@ -1,684 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "distance_matrix_accum_fp32.i" -#include "mips_euclidean_distance_matrix.h" - -namespace zvec { -namespace ailego { - -//! Calculate Fused-Multiply-Add (GENERAL) -#define FMA_FP32_GENERAL(lhs, rhs, sum, norm1, norm2) \ - { \ - sum += (lhs) * (rhs); \ - norm1 += (lhs) * (lhs); \ - norm2 += (rhs) * (rhs); \ - } - -//! Calculate Fused-Multiply-Add (AVX512) -#define FMA_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \ - zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_q, zmm_sum); -#define FMA_MASK_FP32_AVX512(zmm_m, zmm_q, zmm_sum, mask) \ - zmm_sum = _mm512_mask3_fmadd_ps(zmm_m, zmm_q, zmm_sum, mask); - -#if defined(__ARM_NEON) -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormNEON(const float *lhs, - const float *rhs, - size_t size, float *sql, - float *sqr) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 3) << 3); - - float32x4_t v_sum_0 = vdupq_n_f32(0); - float32x4_t v_sum_1 = vdupq_n_f32(0); - float32x4_t v_sum_norm1 = vdupq_n_f32(0); - float32x4_t v_sum_norm2 = vdupq_n_f32(0); - - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - float32x4_t v_lhs_0 = vld1q_f32(lhs + 0); - float32x4_t v_lhs_1 = vld1q_f32(lhs + 4); - float32x4_t v_rhs_0 = vld1q_f32(rhs + 0); - float32x4_t v_rhs_1 = vld1q_f32(rhs + 4); - v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); - v_sum_1 = vfmaq_f32(v_sum_1, v_lhs_1, v_rhs_1); - v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); - v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_1, v_lhs_1); - v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); - v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_1, v_rhs_1); - } - if (last >= last_aligned + 4) { - float32x4_t v_lhs_0 = vld1q_f32(lhs); - float32x4_t v_rhs_0 = vld1q_f32(rhs); - v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); - v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); - v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); - lhs += 4; - rhs += 4; - } - - float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); - float norm1 = vaddvq_f32(v_sum_norm1); - float norm2 = vaddvq_f32(v_sum_norm2); - switch (last - lhs) { - case 3: - FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2) - /* FALLTHRU */ - case 2: - FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2) - /* FALLTHRU */ - case 1: - FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2) - } - *sql = norm1; - *sqr = norm2; - return result; -} -#endif // __ARM_NEON - -#if defined(__SSE__) -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormSSE(const float *lhs, - const float *rhs, size_t size, - float *sql, float *sqr) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 3) << 3); - - __m128 xmm_sum = _mm_setzero_ps(); - __m128 xmm_sum_norm1 = _mm_setzero_ps(); - __m128 xmm_sum_norm2 = _mm_setzero_ps(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - __m128 xmm_lhs_0 = _mm_load_ps(lhs + 0); - __m128 xmm_lhs_1 = _mm_load_ps(lhs + 4); - __m128 xmm_rhs_0 = _mm_load_ps(rhs + 0); - __m128 xmm_rhs_1 = _mm_load_ps(rhs + 4); - xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); - xmm_sum = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum); - xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); - xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); - xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); - xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); - } - - if (last >= last_aligned + 4) { - __m128 xmm_lhs_0 = _mm_load_ps(lhs); - __m128 xmm_rhs_0 = _mm_load_ps(rhs); - xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); - xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); - xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); - lhs += 4; - rhs += 4; - } - } else { - for (; lhs != last_aligned; lhs += 8, rhs += 8) { - __m128 xmm_lhs_0 = _mm_loadu_ps(lhs + 0); - __m128 xmm_lhs_1 = _mm_loadu_ps(lhs + 4); - __m128 xmm_rhs_0 = _mm_loadu_ps(rhs + 0); - __m128 xmm_rhs_1 = _mm_loadu_ps(rhs + 4); - xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); - xmm_sum = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum); - xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); - xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); - xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); - xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); - } - - if (last >= last_aligned + 4) { - __m128 xmm_lhs_0 = _mm_loadu_ps(lhs); - __m128 xmm_rhs_0 = _mm_loadu_ps(rhs); - xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); - xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); - xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); - lhs += 4; - rhs += 4; - } - } - float result = HorizontalAdd_FP32_V128(xmm_sum); - float norm1 = HorizontalAdd_FP32_V128(xmm_sum_norm1); - float norm2 = HorizontalAdd_FP32_V128(xmm_sum_norm2); - - switch (last - lhs) { - case 3: - FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2) - /* FALLTHRU */ - case 2: - FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2) - /* FALLTHRU */ - case 1: - FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2) - } - *sql = norm1; - *sqr = norm2; - return result; -} -#endif // __SSE__ - -#if defined(__AVX__) -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormAVX(const float *lhs, - const float *rhs, size_t size, - float *sql, float *sqr) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 4) << 4); - - __m256 ymm_sum_0 = _mm256_setzero_ps(); - __m256 ymm_sum_1 = _mm256_setzero_ps(); - __m256 ymm_sum_norm1 = _mm256_setzero_ps(); - __m256 ymm_sum_norm2 = _mm256_setzero_ps(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m256 ymm_lhs_0 = _mm256_load_ps(lhs + 0); - __m256 ymm_lhs_1 = _mm256_load_ps(lhs + 8); - __m256 ymm_rhs_0 = _mm256_load_ps(rhs + 0); - __m256 ymm_rhs_1 = _mm256_load_ps(rhs + 8); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); - } - - if (last >= last_aligned + 8) { - __m256 ymm_lhs_0 = _mm256_load_ps(lhs); - __m256 ymm_rhs_0 = _mm256_load_ps(rhs); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - lhs += 8; - rhs += 8; - } - } else { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs + 0); - __m256 ymm_lhs_1 = _mm256_loadu_ps(lhs + 8); - __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs + 0); - __m256 ymm_rhs_1 = _mm256_loadu_ps(rhs + 8); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); - } - - if (last >= last_aligned + 8) { - __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs); - __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs); - ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - lhs += 8; - rhs += 8; - } - } - float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); - float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1); - float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2); - - switch (last - lhs) { - case 7: - FMA_FP32_GENERAL(lhs[6], rhs[6], result, norm1, norm2) - /* FALLTHRU */ - case 6: - FMA_FP32_GENERAL(lhs[5], rhs[5], result, norm1, norm2) - /* FALLTHRU */ - case 5: - FMA_FP32_GENERAL(lhs[4], rhs[4], result, norm1, norm2) - /* FALLTHRU */ - case 4: - FMA_FP32_GENERAL(lhs[3], rhs[3], result, norm1, norm2) - /* FALLTHRU */ - case 3: - FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2) - /* FALLTHRU */ - case 2: - FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2) - /* FALLTHRU */ - case 1: - FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2) - } - *sql = norm1; - *sqr = norm2; - return result; -} -#endif // __AVX__ - -#if defined(__AVX512F__) -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormAVX512(const float *lhs, - const float *rhs, - size_t size, float *sql, - float *sqr) { - const float *last = lhs + size; - const float *last_aligned = lhs + ((size >> 5) << 5); - - __m512 zmm_sum_0 = _mm512_setzero_ps(); - __m512 zmm_sum_1 = _mm512_setzero_ps(); - __m512 zmm_sum_norm1 = _mm512_setzero_ps(); - __m512 zmm_sum_norm2 = _mm512_setzero_ps(); - - if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m512 zmm_lhs_0 = _mm512_load_ps(lhs + 0); - __m512 zmm_lhs_1 = _mm512_load_ps(lhs + 16); - __m512 zmm_rhs_0 = _mm512_load_ps(rhs + 0); - __m512 zmm_rhs_1 = _mm512_load_ps(rhs + 16); - FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) - FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1) - FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) - FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2) - } - - if (last >= last_aligned + 16) { - __m512 zmm_lhs_0 = _mm512_load_ps(lhs); - __m512 zmm_rhs_0 = _mm512_load_ps(rhs); - FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) - FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m512 zmm_lhs_0 = _mm512_loadu_ps(lhs + 0); - __m512 zmm_lhs_1 = _mm512_loadu_ps(lhs + 16); - __m512 zmm_rhs_0 = _mm512_loadu_ps(rhs + 0); - __m512 zmm_rhs_1 = _mm512_loadu_ps(rhs + 16); - FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) - FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1) - FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) - FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2) - } - - if (last >= last_aligned + 16) { - __m512 zmm_lhs_0 = _mm512_loadu_ps(lhs); - __m512 zmm_rhs_0 = _mm512_loadu_ps(rhs); - FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) - FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) - FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) - lhs += 16; - rhs += 16; - } - } - - zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1); - if (lhs != last) { - __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1); - __m512 zmm_undefined = _mm512_undefined_ps(); - __m512 zmm_lhs_0 = _mm512_mask_loadu_ps(zmm_undefined, mask, lhs); - __m512 zmm_rhs_0 = _mm512_mask_loadu_ps(zmm_undefined, mask, rhs); - FMA_MASK_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0, mask); - FMA_MASK_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1, mask); - FMA_MASK_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2, mask); - } - - *sql = HorizontalAdd_FP32_V512(zmm_sum_norm1); - *sqr = HorizontalAdd_FP32_V512(zmm_sum_norm2); - return HorizontalAdd_FP32_V512(zmm_sum_0); -} -#endif // __AVX512F__ - -#if defined(__SSE__) -//! Compute the distance between matrix and query by SphericalInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2; - float v2; - float sum; - -#if defined(__AVX512F__) - if (dim > 15) { - sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); - } else -#endif // __AVX512F__ -#if defined(__AVX__) - if (dim > 7) { - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); - } else -#endif // __AVX__ - { - sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); - } - - *out = ComputeSphericalInjection(sum, u2, v2, e2); -} - -//! Compute the distance between matrix and query by RepeatedQuadraticInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, - float *out) { - float u2; - float v2; - float sum; - -#if defined(__AVX512F__) - if (dim > 15) { - sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); - } else -#endif // __AVX512F__ -#if defined(__AVX__) - if (dim > 7) { - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); - } else -#endif // __AVX__ - { - sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); - } - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; - } - *out = sum; -} -#endif // __SSE__ - -// #if 1 -#if defined(__SSE4_1__) -const static __m128i SHUFFLE_MASK16[16] = { - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, -127, -127, -127, -127), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, - 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, - -127, -127, 15, 14, 13, 12), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 3, 2, 1, 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 7, 6, 5, 4), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, - 11, 10, 9, 8), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, - 0), - _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, - 4), - _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), -}; - -constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; - -float MipsInnerProductSparseInSegmentSSE(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const float *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const float *q_sparse_value) { - float sum = 0.0f; - - // size_t alloc_size = 0; - - size_t i1 = 0, i2 = 0; - size_t end1 = m_sparse_count / 8 * 8; - size_t end2 = q_sparse_count / 8 * 8; - - // std::vector mem1; - // std::vector mem2; - - float fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; - float fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; - - float *val_start_1 = fixed_buffer_1; - float *val_start_2 = fixed_buffer_2; - - // uint32_t max_count = std::max(m_sparse_count, q_sparse_count); - - // if (MAX_SPARSE_BUFFER_LENGTH < max_count) { - // mem1.reserve(max_count); - // mem2.reserve(max_count); - - // val_start_1 = mem1.data(); - // val_start_2 = mem2.data(); - // } - - float *val_1 = val_start_1; - float *val_2 = val_start_2; - - if (i1 < end1 && i2 < end2) { - while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { - i1 += 8; - if (i1 >= end1) goto do_scalar; - } - - while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { - i2 += 8; - if (i2 >= end2) goto do_scalar; - } - - __m128i mm_index_m = - _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); - __m128i mm_index_q = - _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); - - while (true) { -#ifdef DEBUG_PRINT - std::cout << "index 1: " << std::endl; - print_data16(&mm_index_m); - - std::cout << "index 2: " << std::endl; - print_data16(&mm_index_q); -#endif - - __m128i mm_cmp_res = - _mm_cmpistrm(mm_index_q, mm_index_m, - _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); - -#ifdef DEBUG_PRINT - std::cout << "cmp res: " << std::endl; - print_data16(&mm_cmp_res); -#endif - - int r = _mm_extract_epi32(mm_cmp_res, 0); - - if (r) { - int r1 = r & 15; - - __m128i v = _mm_loadu_si128( - reinterpret_cast(&m_sparse_value[i1])); - __m128 vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); - - _mm_storeu_ps(val_1, vs); - val_1 += _mm_popcnt_u32(r1); - - int r2 = (r >> 4) & 15; - v = _mm_loadu_si128( - reinterpret_cast(&m_sparse_value[i1 + 4])); - vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); - _mm_storeu_ps(val_1, vs); - val_1 += _mm_popcnt_u32(r2); - - mm_cmp_res = _mm_cmpistrm( - mm_index_m, mm_index_q, - _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); - r = _mm_extract_epi32(mm_cmp_res, 0); - - r1 = r & 15; - - v = _mm_loadu_si128( - reinterpret_cast(&q_sparse_value[i2])); - vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); - _mm_storeu_ps(val_2, vs); - val_2 += _mm_popcnt_u32(r1); - - r2 = (r >> 4) & 15; - v = _mm_loadu_si128( - reinterpret_cast(&q_sparse_value[i2 + 4])); - vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); - _mm_storeu_ps(val_2, vs); - val_2 += _mm_popcnt_u32(r2); - } - - const uint16_t id1_max = m_sparse_index[i1 + 7]; - - if (id1_max <= q_sparse_index[i2 + 7]) { - i1 += 8; - if (i1 >= end1) goto do_scalar; - mm_index_m = _mm_loadu_si128( - reinterpret_cast(&m_sparse_index[i1])); - } - - if (id1_max >= q_sparse_index[i2 + 7]) { - i2 += 8; - if (i2 >= end2) goto do_scalar; - mm_index_q = _mm_loadu_si128( - reinterpret_cast(&q_sparse_index[i2])); - } - } - } - -do_scalar: - while (i1 < m_sparse_count && i2 < q_sparse_count) { - if (m_sparse_index[i1] == q_sparse_index[i2]) { - *val_1++ = m_sparse_value[i1]; - *val_2++ = q_sparse_value[i2]; - - ++i1; - ++i2; - } else if (m_sparse_index[i1] < q_sparse_index[i2]) { - ++i1; - } else { - ++i2; - } - } - - size_t res_num = val_1 - val_start_1; - - // if (res_num != val_2 - val_start_2) { - // std::cerr << "size mismatch!" << std::endl; - // } - - size_t res_num4 = res_num / 4 * 4; - - if (res_num4) { - __m128 sum128 = _mm_set1_ps(0); - - for (size_t k = 0; k < res_num4; k += 4) { - sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(val_start_1 + k), - _mm_loadu_ps(val_start_2 + k))); - } - - float __attribute__((aligned(16))) tmp_res[4]; - _mm_store_ps(tmp_res, sum128); - sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3]); - } - - for (size_t k = res_num4; k < res_num; ++k) - sum += val_start_1[k] * val_start_2[k]; - - return sum; -} -#else -float MipsInnerProductSparseInSegment(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const float *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const float *q_sparse_value) { - float sum = 0.0f; - - size_t m_i = 0; - size_t q_i = 0; - while (m_i < m_sparse_count && q_i < q_sparse_count) { - if (m_sparse_index[m_i] == q_sparse_index[q_i]) { - sum += m_sparse_value[m_i] * q_sparse_value[q_i]; - - ++m_i; - ++q_i; - } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { - ++m_i; - } else { - ++q_i; - } - } - - return sum; -} -#endif // __SSE4_1__ - -template <> -float MipsSquaredEuclideanSparseDistanceMatrix:: - ComputeInnerProductSparseInSegment(uint32_t m_sparse_count, - const uint16_t *m_sparse_index, - const ValueType *m_sparse_value, - uint32_t q_sparse_count, - const uint16_t *q_sparse_index, - const ValueType *q_sparse_value) { -#if defined(__SSE4_1__) - return MipsInnerProductSparseInSegmentSSE(m_sparse_count, m_sparse_index, - m_sparse_value, q_sparse_count, - q_sparse_index, q_sparse_value); -#else - return MipsInnerProductSparseInSegment(m_sparse_count, m_sparse_index, - m_sparse_value, q_sparse_count, - q_sparse_index, q_sparse_value); -#endif -} - -#if defined(__ARM_NEON) -//! Compute the distance between matrix and query by SphericalInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2; - float v2; - float sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); - - *out = ComputeSphericalInjection(sum, u2, v2, e2); -} - -//! Compute the distance between matrix and query by RepeatedQuadraticInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, - float *out) { - float u2; - float v2; - float sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; - } - *out = sum; -} -#endif //__ARM_NEON - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc new file mode 100644 index 00000000..cff60e8f --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc @@ -0,0 +1,114 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormAVX(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 4) << 4); + + __m256 ymm_sum_0 = _mm256_setzero_ps(); + __m256 ymm_sum_1 = _mm256_setzero_ps(); + __m256 ymm_sum_norm1 = _mm256_setzero_ps(); + __m256 ymm_sum_norm2 = _mm256_setzero_ps(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m256 ymm_lhs_0 = _mm256_load_ps(lhs + 0); + __m256 ymm_lhs_1 = _mm256_load_ps(lhs + 8); + __m256 ymm_rhs_0 = _mm256_load_ps(rhs + 0); + __m256 ymm_rhs_1 = _mm256_load_ps(rhs + 8); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); + } + + if (last >= last_aligned + 8) { + __m256 ymm_lhs_0 = _mm256_load_ps(lhs); + __m256 ymm_rhs_0 = _mm256_load_ps(rhs); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + lhs += 8; + rhs += 8; + } + } else { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs + 0); + __m256 ymm_lhs_1 = _mm256_loadu_ps(lhs + 8); + __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs + 0); + __m256 ymm_rhs_1 = _mm256_loadu_ps(rhs + 8); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); + } + + if (last >= last_aligned + 8) { + __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs); + __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs); + ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + lhs += 8; + rhs += 8; + } + } + float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); + float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1); + float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2); + + switch (last - lhs) { + case 7: + FMA_FP32_GENERAL(lhs[6], rhs[6], result, norm1, norm2) + /* FALLTHRU */ + case 6: + FMA_FP32_GENERAL(lhs[5], rhs[5], result, norm1, norm2) + /* FALLTHRU */ + case 5: + FMA_FP32_GENERAL(lhs[4], rhs[4], result, norm1, norm2) + /* FALLTHRU */ + case 4: + FMA_FP32_GENERAL(lhs[3], rhs[3], result, norm1, norm2) + /* FALLTHRU */ + case 3: + FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2) + /* FALLTHRU */ + case 2: + FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2) + /* FALLTHRU */ + case 1: + FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2) + } + *sql = norm1; + *sqr = norm2; + return result; +} +#endif // __AVX__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc new file mode 100644 index 00000000..1ac56a20 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc @@ -0,0 +1,100 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX512F__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormAVX512(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 5) << 5); + + __m512 zmm_sum_0 = _mm512_setzero_ps(); + __m512 zmm_sum_1 = _mm512_setzero_ps(); + __m512 zmm_sum_norm1 = _mm512_setzero_ps(); + __m512 zmm_sum_norm2 = _mm512_setzero_ps(); + + if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m512 zmm_lhs_0 = _mm512_load_ps(lhs + 0); + __m512 zmm_lhs_1 = _mm512_load_ps(lhs + 16); + __m512 zmm_rhs_0 = _mm512_load_ps(rhs + 0); + __m512 zmm_rhs_1 = _mm512_load_ps(rhs + 16); + FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) + FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1) + FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) + FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2) + } + + if (last >= last_aligned + 16) { + __m512 zmm_lhs_0 = _mm512_load_ps(lhs); + __m512 zmm_rhs_0 = _mm512_load_ps(rhs); + FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) + FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m512 zmm_lhs_0 = _mm512_loadu_ps(lhs + 0); + __m512 zmm_lhs_1 = _mm512_loadu_ps(lhs + 16); + __m512 zmm_rhs_0 = _mm512_loadu_ps(rhs + 0); + __m512 zmm_rhs_1 = _mm512_loadu_ps(rhs + 16); + FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) + FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1) + FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) + FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2) + } + + if (last >= last_aligned + 16) { + __m512 zmm_lhs_0 = _mm512_loadu_ps(lhs); + __m512 zmm_rhs_0 = _mm512_loadu_ps(rhs); + FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0) + FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1) + FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2) + lhs += 16; + rhs += 16; + } + } + + zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1); + if (lhs != last) { + __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1); + __m512 zmm_undefined = _mm512_undefined_ps(); + __m512 zmm_lhs_0 = _mm512_mask_loadu_ps(zmm_undefined, mask, lhs); + __m512 zmm_rhs_0 = _mm512_mask_loadu_ps(zmm_undefined, mask, rhs); + FMA_MASK_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0, mask); + FMA_MASK_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1, mask); + FMA_MASK_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2, mask); + } + + *sql = HorizontalAdd_FP32_V512(zmm_sum_norm1); + *sqr = HorizontalAdd_FP32_V512(zmm_sum_norm2); + return HorizontalAdd_FP32_V512(zmm_sum_0); +} +#endif // __AVX512F__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc new file mode 100644 index 00000000..992da0d1 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc @@ -0,0 +1,136 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +float InnerProductAndSquaredNormNEON(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__AVX512F__) +float InnerProductAndSquaredNormAVX512(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__AVX__) +float InnerProductAndSquaredNormAVX(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__SSE__) +float InnerProductAndSquaredNormSSE(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__SSE4_1__) +float MipsInnerProductSparseInSegmentSSE(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const float *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const float *q_sparse_value); +#endif + +float MipsInnerProductSparseInSegment(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const float *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const float *q_sparse_value); + +#if defined(__SSE__) +//! Compute the distance between matrix and query by SphericalInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F && dim > 15) { + sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); + } else +#endif // __AVX512F__ +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX && dim > 7) { + sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); + } else +#endif // __AVX__ + { + sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); + } + + *out = ComputeSphericalInjection(sum, u2, v2, e2); +} + +//! Compute the distance between matrix and query by RepeatedQuadraticInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, + float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F && dim > 15) { + sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); + } else +#endif // __AVX512F__ +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX && dim > 7) { + sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); + } else +#endif // __AVX__ + { + sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); + } + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + *out = sum; +} +#endif // __SSE__ + +template <> +float MipsSquaredEuclideanSparseDistanceMatrix:: + ComputeInnerProductSparseInSegment(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const ValueType *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const ValueType *q_sparse_value) { +#if defined(__SSE4_1__) + return MipsInnerProductSparseInSegmentSSE(m_sparse_count, m_sparse_index, + m_sparse_value, q_sparse_count, + q_sparse_index, q_sparse_value); +#else + return MipsInnerProductSparseInSegment(m_sparse_count, m_sparse_index, + m_sparse_value, q_sparse_count, + q_sparse_index, q_sparse_value); +#endif +} + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc new file mode 100644 index 00000000..8e98922c --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc @@ -0,0 +1,105 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__ARM_NEON) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormNEON(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 3) << 3); + + float32x4_t v_sum_0 = vdupq_n_f32(0); + float32x4_t v_sum_1 = vdupq_n_f32(0); + float32x4_t v_sum_norm1 = vdupq_n_f32(0); + float32x4_t v_sum_norm2 = vdupq_n_f32(0); + + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + float32x4_t v_lhs_0 = vld1q_f32(lhs + 0); + float32x4_t v_lhs_1 = vld1q_f32(lhs + 4); + float32x4_t v_rhs_0 = vld1q_f32(rhs + 0); + float32x4_t v_rhs_1 = vld1q_f32(rhs + 4); + v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); + v_sum_1 = vfmaq_f32(v_sum_1, v_lhs_1, v_rhs_1); + v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); + v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_1, v_lhs_1); + v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); + v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_1, v_rhs_1); + } + if (last >= last_aligned + 4) { + float32x4_t v_lhs_0 = vld1q_f32(lhs); + float32x4_t v_rhs_0 = vld1q_f32(rhs); + v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); + v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); + v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); + lhs += 4; + rhs += 4; + } + + float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); + float norm1 = vaddvq_f32(v_sum_norm1); + float norm2 = vaddvq_f32(v_sum_norm2); + switch (last - lhs) { + case 3: + FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2) + /* FALLTHRU */ + case 2: + FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2) + /* FALLTHRU */ + case 1: + FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2) + } + *sql = norm1; + *sqr = norm2; + return result; +} + +//! Compute the distance between matrix and query by SphericalInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { + float u2; + float v2; + float sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); + + *out = ComputeSphericalInjection(sum, u2, v2, e2); +} + +//! Compute the distance between matrix and query by RepeatedQuadraticInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, + float *out) { + float u2; + float v2; + float sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + *out = sum; +} +#endif //__ARM_NEON + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc new file mode 100644 index 00000000..43d8f9b7 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc @@ -0,0 +1,336 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_fp32.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormSSE(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr) { + const float *last = lhs + size; + const float *last_aligned = lhs + ((size >> 3) << 3); + + __m128 xmm_sum = _mm_setzero_ps(); + __m128 xmm_sum_norm1 = _mm_setzero_ps(); + __m128 xmm_sum_norm2 = _mm_setzero_ps(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + __m128 xmm_lhs_0 = _mm_load_ps(lhs + 0); + __m128 xmm_lhs_1 = _mm_load_ps(lhs + 4); + __m128 xmm_rhs_0 = _mm_load_ps(rhs + 0); + __m128 xmm_rhs_1 = _mm_load_ps(rhs + 4); + xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); + xmm_sum = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum); + xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); + xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); + xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); + xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); + } + + if (last >= last_aligned + 4) { + __m128 xmm_lhs_0 = _mm_load_ps(lhs); + __m128 xmm_rhs_0 = _mm_load_ps(rhs); + xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); + xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); + xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); + lhs += 4; + rhs += 4; + } + } else { + for (; lhs != last_aligned; lhs += 8, rhs += 8) { + __m128 xmm_lhs_0 = _mm_loadu_ps(lhs + 0); + __m128 xmm_lhs_1 = _mm_loadu_ps(lhs + 4); + __m128 xmm_rhs_0 = _mm_loadu_ps(rhs + 0); + __m128 xmm_rhs_1 = _mm_loadu_ps(rhs + 4); + xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); + xmm_sum = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum); + xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); + xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); + xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); + xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); + } + + if (last >= last_aligned + 4) { + __m128 xmm_lhs_0 = _mm_loadu_ps(lhs); + __m128 xmm_rhs_0 = _mm_loadu_ps(rhs); + xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); + xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); + xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); + lhs += 4; + rhs += 4; + } + } + float result = HorizontalAdd_FP32_V128(xmm_sum); + float norm1 = HorizontalAdd_FP32_V128(xmm_sum_norm1); + float norm2 = HorizontalAdd_FP32_V128(xmm_sum_norm2); + + switch (last - lhs) { + case 3: + FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2) + /* FALLTHRU */ + case 2: + FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2) + /* FALLTHRU */ + case 1: + FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2) + } + *sql = norm1; + *sqr = norm2; + return result; +} + +#endif // __SSE__ + +// #if 1 +#if defined(__SSE4_1__) +const static __m128i SHUFFLE_MASK16[16] = { + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, -127, -127, -127), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, + 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, 15, 14, 13, 12), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 3, 2, 1, 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 7, 6, 5, 4), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, + 11, 10, 9, 8), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, + 0), + _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, + 4), + _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), +}; + +constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; + +float MipsInnerProductSparseInSegmentSSE(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const float *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const float *q_sparse_value) { + float sum = 0.0f; + + // size_t alloc_size = 0; + + size_t i1 = 0, i2 = 0; + size_t end1 = m_sparse_count / 8 * 8; + size_t end2 = q_sparse_count / 8 * 8; + + // std::vector mem1; + // std::vector mem2; + + float fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; + float fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; + + float *val_start_1 = fixed_buffer_1; + float *val_start_2 = fixed_buffer_2; + + // uint32_t max_count = std::max(m_sparse_count, q_sparse_count); + + // if (MAX_SPARSE_BUFFER_LENGTH < max_count) { + // mem1.reserve(max_count); + // mem2.reserve(max_count); + + // val_start_1 = mem1.data(); + // val_start_2 = mem2.data(); + // } + + float *val_1 = val_start_1; + float *val_2 = val_start_2; + + if (i1 < end1 && i2 < end2) { + while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { + i1 += 8; + if (i1 >= end1) goto do_scalar; + } + + while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { + i2 += 8; + if (i2 >= end2) goto do_scalar; + } + + __m128i mm_index_m = + _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); + __m128i mm_index_q = + _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); + + while (true) { +#ifdef DEBUG_PRINT + std::cout << "index 1: " << std::endl; + print_data16(&mm_index_m); + + std::cout << "index 2: " << std::endl; + print_data16(&mm_index_q); +#endif + + __m128i mm_cmp_res = + _mm_cmpistrm(mm_index_q, mm_index_m, + _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); + +#ifdef DEBUG_PRINT + std::cout << "cmp res: " << std::endl; + print_data16(&mm_cmp_res); +#endif + + int r = _mm_extract_epi32(mm_cmp_res, 0); + + if (r) { + int r1 = r & 15; + + __m128i v = _mm_loadu_si128( + reinterpret_cast(&m_sparse_value[i1])); + __m128 vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); + + _mm_storeu_ps(val_1, vs); + val_1 += _mm_popcnt_u32(r1); + + int r2 = (r >> 4) & 15; + v = _mm_loadu_si128( + reinterpret_cast(&m_sparse_value[i1 + 4])); + vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); + _mm_storeu_ps(val_1, vs); + val_1 += _mm_popcnt_u32(r2); + + mm_cmp_res = _mm_cmpistrm( + mm_index_m, mm_index_q, + _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); + r = _mm_extract_epi32(mm_cmp_res, 0); + + r1 = r & 15; + + v = _mm_loadu_si128( + reinterpret_cast(&q_sparse_value[i2])); + vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); + _mm_storeu_ps(val_2, vs); + val_2 += _mm_popcnt_u32(r1); + + r2 = (r >> 4) & 15; + v = _mm_loadu_si128( + reinterpret_cast(&q_sparse_value[i2 + 4])); + vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); + _mm_storeu_ps(val_2, vs); + val_2 += _mm_popcnt_u32(r2); + } + + const uint16_t id1_max = m_sparse_index[i1 + 7]; + + if (id1_max <= q_sparse_index[i2 + 7]) { + i1 += 8; + if (i1 >= end1) goto do_scalar; + mm_index_m = _mm_loadu_si128( + reinterpret_cast(&m_sparse_index[i1])); + } + + if (id1_max >= q_sparse_index[i2 + 7]) { + i2 += 8; + if (i2 >= end2) goto do_scalar; + mm_index_q = _mm_loadu_si128( + reinterpret_cast(&q_sparse_index[i2])); + } + } + } + +do_scalar: + while (i1 < m_sparse_count && i2 < q_sparse_count) { + if (m_sparse_index[i1] == q_sparse_index[i2]) { + *val_1++ = m_sparse_value[i1]; + *val_2++ = q_sparse_value[i2]; + + ++i1; + ++i2; + } else if (m_sparse_index[i1] < q_sparse_index[i2]) { + ++i1; + } else { + ++i2; + } + } + + size_t res_num = val_1 - val_start_1; + + // if (res_num != val_2 - val_start_2) { + // std::cerr << "size mismatch!" << std::endl; + // } + + size_t res_num4 = res_num / 4 * 4; + + if (res_num4) { + __m128 sum128 = _mm_set1_ps(0); + + for (size_t k = 0; k < res_num4; k += 4) { + sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(val_start_1 + k), + _mm_loadu_ps(val_start_2 + k))); + } + + float __attribute__((aligned(16))) tmp_res[4]; + _mm_store_ps(tmp_res, sum128); + sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3]); + } + + for (size_t k = res_num4; k < res_num; ++k) + sum += val_start_1[k] * val_start_2[k]; + + return sum; +} +#else +float MipsInnerProductSparseInSegment(uint32_t m_sparse_count, + const uint16_t *m_sparse_index, + const float *m_sparse_value, + uint32_t q_sparse_count, + const uint16_t *q_sparse_index, + const float *q_sparse_value) { + float sum = 0.0f; + + size_t m_i = 0; + size_t q_i = 0; + while (m_i < m_sparse_count && q_i < q_sparse_count) { + if (m_sparse_index[m_i] == q_sparse_index[q_i]) { + sum += m_sparse_value[m_i] * q_sparse_value[q_i]; + + ++m_i; + ++q_i; + } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { + ++m_i; + } else { + ++q_i; + } + } + + return sum; +} +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int4.cc b/src/ailego/math/mips_euclidean_distance_matrix_int4.cc deleted file mode 100644 index dc74b2d0..00000000 --- a/src/ailego/math/mips_euclidean_distance_matrix_int4.cc +++ /dev/null @@ -1,358 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "distance_matrix_accum_int8.i" -#include "inner_product_matrix.h" -#include "mips_euclidean_distance_matrix.h" -#include "norm_matrix.h" - -namespace zvec { -namespace ailego { - -#if defined(__SSE4_1__) -//! Four-bits Convert Table -static const AILEGO_ALIGNED(32) int8_t Int4ConvertTable[32] = { - 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1, - 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1}; -#endif // __SSE4_1__ - -#if defined(__SSE4_1__) -static const __m128i MASK_INT4_SSE = _mm_set1_epi32(0x0f0f0f0f); -static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); -static const __m128i INT4_LOOKUP_SSE = - _mm_load_si128((const __m128i *)Int4ConvertTable); -#endif // __SSE4_1__ - -#if defined(__AVX2__) -static const __m256i MASK_INT4_AVX = _mm256_set1_epi32(0x0f0f0f0f); -static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); -static const __m256i INT4_LOOKUP_AVX = - _mm256_load_si256((const __m256i *)Int4ConvertTable); -#endif // __AVX2__ - -//! Calculate Fused-Multiply-Add (GENERAL) -#define FMA_INT4_GENERAL(lhs, rhs, sum, norm1, norm2) \ - { \ - sum += Int4MulTable[(((lhs) << 4) & 0xf0) | (((rhs) >> 0) & 0xf)] + \ - Int4MulTable[(((lhs) >> 0) & 0xf0) | (((rhs) >> 4) & 0xf)]; \ - norm1 += static_cast( \ - ((int8_t)((lhs) << 4) >> 4) * ((int8_t)((lhs) << 4) >> 4) + \ - ((int8_t)((lhs) & 0xf0) >> 4) * ((int8_t)((lhs) & 0xf0) >> 4)); \ - norm2 += static_cast( \ - ((int8_t)((rhs) << 4) >> 4) * ((int8_t)((rhs) << 4) >> 4) + \ - ((int8_t)((rhs) & 0xf0) >> 4) * ((int8_t)((rhs) & 0xf0) >> 4)); \ - } - -//! Calculate Fused-Multiply-Add (SSE) -#define FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum) \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ - _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ - ONES_INT16_SSE), \ - xmm_sum) - -//! Calculate Fused-Multiply-Add (AVX) -#define FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum) \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16( \ - _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_rhs), \ - _mm256_sign_epi8(ymm_lhs, ymm_rhs)), \ - ONES_INT16_AVX), \ - ymm_sum) - -//! Compute the distance between matrix and query (SSE) -#define FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum_0, xmm_sum_norm1, \ - xmm_sum_norm2) \ - { \ - __m128i xmm_lhs_0 = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, _mm_and_si128((xmm_lhs), MASK_INT4_SSE)); \ - __m128i xmm_rhs_0 = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, _mm_and_si128((xmm_rhs), MASK_INT4_SSE)); \ - __m128i xmm_lhs_1 = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, \ - _mm_and_si128(_mm_srli_epi32((xmm_lhs), 4), MASK_INT4_SSE)); \ - __m128i xmm_rhs_1 = _mm_shuffle_epi8( \ - INT4_LOOKUP_SSE, \ - _mm_and_si128(_mm_srli_epi32((xmm_rhs), 4), MASK_INT4_SSE)); \ - FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); \ - FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); \ - FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); \ - FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum_0); \ - FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); \ - FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); \ - } - -//! Compute the distance between matrix and query (AVX) -#define FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum1, \ - ymm_sum_norm1, ymm_sum_norm2) \ - { \ - __m256i ymm_lhs_0 = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, _mm256_and_si256((ymm_lhs), MASK_INT4_AVX)); \ - __m256i ymm_rhs_0 = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, _mm256_and_si256((ymm_rhs), MASK_INT4_AVX)); \ - __m256i ymm_lhs_1 = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, \ - _mm256_and_si256(_mm256_srli_epi32((ymm_lhs), 4), MASK_INT4_AVX)); \ - __m256i ymm_rhs_1 = _mm256_shuffle_epi8( \ - INT4_LOOKUP_AVX, \ - _mm256_and_si256(_mm256_srli_epi32((ymm_rhs), 4), MASK_INT4_AVX)); \ - FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); \ - FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); \ - FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); \ - FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); \ - FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); \ - FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); \ - } - -#if defined(__SSE4_1__) -#if defined(__AVX2__) -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormAVX(const uint8_t *lhs, - const uint8_t *rhs, - size_t size, float *sql, - float *sqr) { - const uint8_t *last = lhs + size; - const uint8_t *last_aligned = lhs + ((size >> 5) << 5); - __m256i ymm_sum_0 = _mm256_setzero_si256(); - __m256i ymm_sum_1 = _mm256_setzero_si256(); - __m256i ymm_sum_norm1 = _mm256_setzero_si256(); - __m256i ymm_sum_norm2 = _mm256_setzero_si256(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs)); - __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs)); - FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum1, ymm_sum_norm1, - ymm_sum_norm2) - } - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); - __m128i xmm_sum = _mm_setzero_si128(); - __m128i xmm_sum_norm1 = _mm_setzero_si128(); - __m128i xmm_sum_norm2 = _mm_setzero_si128(); - FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2) - ymm_sum_0 = _mm256_add_epi32( - _mm256_set_m128i(_mm_setzero_si128(), xmm_sum), ymm_sum_0); - ymm_sum_norm1 = _mm256_add_epi32( - _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm1), ymm_sum_norm1); - ymm_sum_norm2 = _mm256_add_epi32( - _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm2), ymm_sum_norm2); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs)); - __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs)); - FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum1, ymm_sum_norm1, - ymm_sum_norm2) - } - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - __m128i xmm_sum = _mm_setzero_si128(); - __m128i xmm_sum_norm1 = _mm_setzero_si128(); - __m128i xmm_sum_norm2 = _mm_setzero_si128(); - FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2) - ymm_sum_0 = _mm256_add_epi32( - _mm256_set_m128i(_mm_setzero_si128(), xmm_sum), ymm_sum_0); - ymm_sum_norm1 = _mm256_add_epi32( - _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm1), ymm_sum_norm1); - ymm_sum_norm2 = _mm256_add_epi32( - _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm2), ymm_sum_norm2); - lhs += 16; - rhs += 16; - } - } - float result = static_cast( - HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); - float norm1 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm1)); - float norm2 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm2)); - - switch (last - lhs) { - case 15: - FMA_INT4_GENERAL(lhs[14], rhs[14], result, norm1, norm2) - /* FALLTHRU */ - case 14: - FMA_INT4_GENERAL(lhs[13], rhs[13], result, norm1, norm2) - /* FALLTHRU */ - case 13: - FMA_INT4_GENERAL(lhs[12], rhs[12], result, norm1, norm2) - /* FALLTHRU */ - case 12: - FMA_INT4_GENERAL(lhs[11], rhs[11], result, norm1, norm2) - /* FALLTHRU */ - case 11: - FMA_INT4_GENERAL(lhs[10], rhs[10], result, norm1, norm2) - /* FALLTHRU */ - case 10: - FMA_INT4_GENERAL(lhs[9], rhs[9], result, norm1, norm2) - /* FALLTHRU */ - case 9: - FMA_INT4_GENERAL(lhs[8], rhs[8], result, norm1, norm2) - /* FALLTHRU */ - case 8: - FMA_INT4_GENERAL(lhs[7], rhs[7], result, norm1, norm2) - /* FALLTHRU */ - case 7: - FMA_INT4_GENERAL(lhs[6], rhs[6], result, norm1, norm2) - /* FALLTHRU */ - case 6: - FMA_INT4_GENERAL(lhs[5], rhs[5], result, norm1, norm2) - /* FALLTHRU */ - case 5: - FMA_INT4_GENERAL(lhs[4], rhs[4], result, norm1, norm2) - /* FALLTHRU */ - case 4: - FMA_INT4_GENERAL(lhs[3], rhs[3], result, norm1, norm2) - /* FALLTHRU */ - case 3: - FMA_INT4_GENERAL(lhs[2], rhs[2], result, norm1, norm2) - /* FALLTHRU */ - case 2: - FMA_INT4_GENERAL(lhs[1], rhs[1], result, norm1, norm2) - /* FALLTHRU */ - case 1: - FMA_INT4_GENERAL(lhs[0], rhs[0], result, norm1, norm2) - } - *sql = norm1; - *sqr = norm2; - return result; -} -#else -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormSSE(const uint8_t *lhs, - const uint8_t *rhs, - size_t size, float *sql, - float *sqr) { - const uint8_t *last = lhs + size; - const uint8_t *last_aligned = lhs + ((size >> 4) << 4); - __m128i xmm_sum = _mm_setzero_si128(); - __m128i xmm_sum_norm1 = _mm_setzero_si128(); - __m128i xmm_sum_norm2 = _mm_setzero_si128(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)(lhs)); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)(rhs)); - FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2) - } - } else { - for (; lhs != last_aligned; lhs += 16, rhs += 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)(lhs)); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)(rhs)); - FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2) - } - } - float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum)); - float norm1 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm1)); - float norm2 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm2)); - - switch (last - lhs) { - case 15: - FMA_INT4_GENERAL(lhs[14], rhs[14], result, norm1, norm2) - /* FALLTHRU */ - case 14: - FMA_INT4_GENERAL(lhs[13], rhs[13], result, norm1, norm2) - /* FALLTHRU */ - case 13: - FMA_INT4_GENERAL(lhs[12], rhs[12], result, norm1, norm2) - /* FALLTHRU */ - case 12: - FMA_INT4_GENERAL(lhs[11], rhs[11], result, norm1, norm2) - /* FALLTHRU */ - case 11: - FMA_INT4_GENERAL(lhs[10], rhs[10], result, norm1, norm2) - /* FALLTHRU */ - case 10: - FMA_INT4_GENERAL(lhs[9], rhs[9], result, norm1, norm2) - /* FALLTHRU */ - case 9: - FMA_INT4_GENERAL(lhs[8], rhs[8], result, norm1, norm2) - /* FALLTHRU */ - case 8: - FMA_INT4_GENERAL(lhs[7], rhs[7], result, norm1, norm2) - /* FALLTHRU */ - case 7: - FMA_INT4_GENERAL(lhs[6], rhs[6], result, norm1, norm2) - /* FALLTHRU */ - case 6: - FMA_INT4_GENERAL(lhs[5], rhs[5], result, norm1, norm2) - /* FALLTHRU */ - case 5: - FMA_INT4_GENERAL(lhs[4], rhs[4], result, norm1, norm2) - /* FALLTHRU */ - case 4: - FMA_INT4_GENERAL(lhs[3], rhs[3], result, norm1, norm2) - /* FALLTHRU */ - case 3: - FMA_INT4_GENERAL(lhs[2], rhs[2], result, norm1, norm2) - /* FALLTHRU */ - case 2: - FMA_INT4_GENERAL(lhs[1], rhs[1], result, norm1, norm2) - /* FALLTHRU */ - case 1: - FMA_INT4_GENERAL(lhs[0], rhs[0], result, norm1, norm2) - } - *sql = norm1; - *sqr = norm2; - return result; -} -#endif // __AVX2__ - -//! Compute the distance between matrix and query by SphericalInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2; - float v2; - float sum; - -#if defined(__AVX2__) - sum = InnerProductAndSquaredNormAVX(p, q, dim >> 1, &u2, &v2); -#else - sum = InnerProductAndSquaredNormSSE(p, q, dim >> 1, &u2, &v2); -#endif - - *out = ComputeSphericalInjection(sum, u2, v2, e2); -} - -//! Compute the distance between matrix and query by RepeatedQuadraticInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, - float *out) { - float u2; - float v2; - float sum; - -#if defined(__AVX2__) - sum = InnerProductAndSquaredNormAVX(p, q, dim >> 1, &u2, &v2); -#else - sum = InnerProductAndSquaredNormSSE(p, q, dim >> 1, &u2, &v2); -#endif - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; - } - *out = sum; -} -#endif // __SSE4_1__ - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc b/src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc new file mode 100644 index 00000000..95a3f007 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc @@ -0,0 +1,140 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int8.i" +#include "distance_matrix_mips_utility.i" +#include "inner_product_matrix.h" +#include "mips_euclidean_distance_matrix.h" +#include "norm_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormAVX(const uint8_t *lhs, const uint8_t *rhs, + size_t size, float *sql, float *sqr) { + const uint8_t *last = lhs + size; + const uint8_t *last_aligned = lhs + ((size >> 5) << 5); + __m256i ymm_sum_0 = _mm256_setzero_si256(); + __m256i ymm_sum_1 = _mm256_setzero_si256(); + __m256i ymm_sum_norm1 = _mm256_setzero_si256(); + __m256i ymm_sum_norm2 = _mm256_setzero_si256(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs)); + __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs)); + FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum1, ymm_sum_norm1, + ymm_sum_norm2) + } + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + __m128i xmm_sum = _mm_setzero_si128(); + __m128i xmm_sum_norm1 = _mm_setzero_si128(); + __m128i xmm_sum_norm2 = _mm_setzero_si128(); + FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2) + ymm_sum_0 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), xmm_sum), ymm_sum_0); + ymm_sum_norm1 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm1), ymm_sum_norm1); + ymm_sum_norm2 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm2), ymm_sum_norm2); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs)); + __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs)); + FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum1, ymm_sum_norm1, + ymm_sum_norm2) + } + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + __m128i xmm_sum = _mm_setzero_si128(); + __m128i xmm_sum_norm1 = _mm_setzero_si128(); + __m128i xmm_sum_norm2 = _mm_setzero_si128(); + FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2) + ymm_sum_0 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), xmm_sum), ymm_sum_0); + ymm_sum_norm1 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm1), ymm_sum_norm1); + ymm_sum_norm2 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm2), ymm_sum_norm2); + lhs += 16; + rhs += 16; + } + } + float result = static_cast( + HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); + float norm1 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm1)); + float norm2 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm2)); + + switch (last - lhs) { + case 15: + FMA_INT4_GENERAL(lhs[14], rhs[14], result, norm1, norm2) + /* FALLTHRU */ + case 14: + FMA_INT4_GENERAL(lhs[13], rhs[13], result, norm1, norm2) + /* FALLTHRU */ + case 13: + FMA_INT4_GENERAL(lhs[12], rhs[12], result, norm1, norm2) + /* FALLTHRU */ + case 12: + FMA_INT4_GENERAL(lhs[11], rhs[11], result, norm1, norm2) + /* FALLTHRU */ + case 11: + FMA_INT4_GENERAL(lhs[10], rhs[10], result, norm1, norm2) + /* FALLTHRU */ + case 10: + FMA_INT4_GENERAL(lhs[9], rhs[9], result, norm1, norm2) + /* FALLTHRU */ + case 9: + FMA_INT4_GENERAL(lhs[8], rhs[8], result, norm1, norm2) + /* FALLTHRU */ + case 8: + FMA_INT4_GENERAL(lhs[7], rhs[7], result, norm1, norm2) + /* FALLTHRU */ + case 7: + FMA_INT4_GENERAL(lhs[6], rhs[6], result, norm1, norm2) + /* FALLTHRU */ + case 6: + FMA_INT4_GENERAL(lhs[5], rhs[5], result, norm1, norm2) + /* FALLTHRU */ + case 5: + FMA_INT4_GENERAL(lhs[4], rhs[4], result, norm1, norm2) + /* FALLTHRU */ + case 4: + FMA_INT4_GENERAL(lhs[3], rhs[3], result, norm1, norm2) + /* FALLTHRU */ + case 3: + FMA_INT4_GENERAL(lhs[2], rhs[2], result, norm1, norm2) + /* FALLTHRU */ + case 2: + FMA_INT4_GENERAL(lhs[1], rhs[1], result, norm1, norm2) + /* FALLTHRU */ + case 1: + FMA_INT4_GENERAL(lhs[0], rhs[0], result, norm1, norm2) + } + *sql = norm1; + *sqr = norm2; + return result; +} +#endif // __AVX2__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc b/src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc new file mode 100644 index 00000000..c967f832 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc @@ -0,0 +1,83 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "inner_product_matrix.h" +#include "mips_euclidean_distance_matrix.h" +#include "norm_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX__) +float InnerProductAndSquaredNormAVX(const uint8_t *lhs, const uint8_t *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__SSE__) +float InnerProductAndSquaredNormSSE(const uint8_t *lhs, const uint8_t *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__SSE4_1__) +//! Compute the distance between matrix and query by SphericalInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + sum = InnerProductAndSquaredNormAVX(p, q, dim >> 1, &u2, &v2); + } else +#endif + { + sum = InnerProductAndSquaredNormSSE(p, q, dim >> 1, &u2, &v2); + } + + *out = ComputeSphericalInjection(sum, u2, v2, e2); +} + +//! Compute the distance between matrix and query by RepeatedQuadraticInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, + float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + sum = InnerProductAndSquaredNormAVX(p, q, dim >> 1, &u2, &v2); + } else +#endif + { + sum = InnerProductAndSquaredNormSSE(p, q, dim >> 1, &u2, &v2); + } + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + *out = sum; +} +#endif + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc b/src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc new file mode 100644 index 00000000..139b14c9 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc @@ -0,0 +1,104 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int8.i" +#include "distance_matrix_mips_utility.i" +#include "inner_product_matrix.h" +#include "mips_euclidean_distance_matrix.h" +#include "norm_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE4_1__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormSSE(const uint8_t *lhs, const uint8_t *rhs, + size_t size, float *sql, float *sqr) { + const uint8_t *last = lhs + size; + const uint8_t *last_aligned = lhs + ((size >> 4) << 4); + __m128i xmm_sum = _mm_setzero_si128(); + __m128i xmm_sum_norm1 = _mm_setzero_si128(); + __m128i xmm_sum_norm2 = _mm_setzero_si128(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)(lhs)); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)(rhs)); + FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2) + } + } else { + for (; lhs != last_aligned; lhs += 16, rhs += 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)(lhs)); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)(rhs)); + FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2) + } + } + float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum)); + float norm1 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm1)); + float norm2 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm2)); + + switch (last - lhs) { + case 15: + FMA_INT4_GENERAL(lhs[14], rhs[14], result, norm1, norm2) + /* FALLTHRU */ + case 14: + FMA_INT4_GENERAL(lhs[13], rhs[13], result, norm1, norm2) + /* FALLTHRU */ + case 13: + FMA_INT4_GENERAL(lhs[12], rhs[12], result, norm1, norm2) + /* FALLTHRU */ + case 12: + FMA_INT4_GENERAL(lhs[11], rhs[11], result, norm1, norm2) + /* FALLTHRU */ + case 11: + FMA_INT4_GENERAL(lhs[10], rhs[10], result, norm1, norm2) + /* FALLTHRU */ + case 10: + FMA_INT4_GENERAL(lhs[9], rhs[9], result, norm1, norm2) + /* FALLTHRU */ + case 9: + FMA_INT4_GENERAL(lhs[8], rhs[8], result, norm1, norm2) + /* FALLTHRU */ + case 8: + FMA_INT4_GENERAL(lhs[7], rhs[7], result, norm1, norm2) + /* FALLTHRU */ + case 7: + FMA_INT4_GENERAL(lhs[6], rhs[6], result, norm1, norm2) + /* FALLTHRU */ + case 6: + FMA_INT4_GENERAL(lhs[5], rhs[5], result, norm1, norm2) + /* FALLTHRU */ + case 5: + FMA_INT4_GENERAL(lhs[4], rhs[4], result, norm1, norm2) + /* FALLTHRU */ + case 4: + FMA_INT4_GENERAL(lhs[3], rhs[3], result, norm1, norm2) + /* FALLTHRU */ + case 3: + FMA_INT4_GENERAL(lhs[2], rhs[2], result, norm1, norm2) + /* FALLTHRU */ + case 2: + FMA_INT4_GENERAL(lhs[1], rhs[1], result, norm1, norm2) + /* FALLTHRU */ + case 1: + FMA_INT4_GENERAL(lhs[0], rhs[0], result, norm1, norm2) + } + *sql = norm1; + *sqr = norm2; + return result; +} +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int8.cc b/src/ailego/math/mips_euclidean_distance_matrix_int8.cc deleted file mode 100644 index accf3d1c..00000000 --- a/src/ailego/math/mips_euclidean_distance_matrix_int8.cc +++ /dev/null @@ -1,358 +0,0 @@ -// Copyright 2025-present the zvec project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "distance_matrix_accum_int8.i" -#include "mips_euclidean_distance_matrix.h" - -namespace zvec { -namespace ailego { - -#if defined(__SSE4_1__) -static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); -#endif // __SSE4_1__ - -#if defined(__AVX2__) -static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); -#endif // __AVX2__ - -//! Calculate Fused-Multiply-Add (GENERAL) -#define FMA_INT8_GENERAL(lhs, rhs, sum, norm1, norm2) \ - { \ - sum += static_cast(lhs * rhs); \ - norm1 += static_cast(lhs * lhs); \ - norm2 += static_cast(rhs * rhs); \ - } - -//! Calculate Fused-Multiply-Add (AVX) -#define FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum) \ - ymm_sum = _mm256_add_epi32( \ - _mm256_madd_epi16( \ - _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_rhs), \ - _mm256_sign_epi8(ymm_lhs, ymm_rhs)), \ - ONES_INT16_AVX), \ - ymm_sum) -#define FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum) \ - ymm_sum = _mm256_add_epi32( \ - _mm256_set_m128i( \ - _mm_setzero_si128(), \ - _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ - _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ - ONES_INT16_SSE)), \ - ymm_sum) - -//! Calculate Fused-Multiply-Add (SSE) -#define FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum) \ - xmm_sum = _mm_add_epi32( \ - _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \ - _mm_sign_epi8(xmm_lhs, xmm_rhs)), \ - ONES_INT16_SSE), \ - xmm_sum) - -#if defined(__SSE4_1__) -#if defined(__AVX2__) -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormAVX(const int8_t *lhs, - const int8_t *rhs, - size_t size, float *sql, - float *sqr) { - const int8_t *last = lhs + size; - const int8_t *last_aligned = lhs + ((size >> 6) << 6); - - __m256i ymm_sum_0 = _mm256_setzero_si256(); - __m256i ymm_sum_1 = _mm256_setzero_si256(); - __m256i ymm_sum_norm1 = _mm256_setzero_si256(); - __m256i ymm_sum_norm2 = _mm256_setzero_si256(); - - if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0)); - __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32)); - __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0)); - __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32)); - FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); - FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); - FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); - } - - if (last >= last_aligned + 32) { - __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); - __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); - FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum_0); - FMA_INT8_AVX(ymm_lhs, ymm_lhs, ymm_sum_norm1); - FMA_INT8_AVX(ymm_rhs, ymm_rhs, ymm_sum_norm2); - lhs += 32; - rhs += 32; - } - - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); - FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum_0); - FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_lhs, ymm_sum_norm1); - FMA_INT8_AVX_SSE_HYBRID(xmm_rhs, xmm_rhs, ymm_sum_norm2); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 64, rhs += 64) { - __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0)); - __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32)); - __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0)); - __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32)); - FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); - FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); - FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); - FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); - FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); - FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); - } - - if (last >= last_aligned + 32) { - __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); - __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); - FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum_0); - FMA_INT8_AVX(ymm_lhs, ymm_lhs, ymm_sum_norm1); - FMA_INT8_AVX(ymm_rhs, ymm_rhs, ymm_sum_norm2); - lhs += 32; - rhs += 32; - } - - if (last >= lhs + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum_0); - FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_lhs, ymm_sum_norm1); - FMA_INT8_AVX_SSE_HYBRID(xmm_rhs, xmm_rhs, ymm_sum_norm2); - lhs += 16; - rhs += 16; - } - } - float result = static_cast( - HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); - float norm1 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm1)); - float norm2 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm2)); - - switch (last - lhs) { - case 15: - FMA_INT8_GENERAL(lhs[14], rhs[14], result, norm1, norm2) - /* FALLTHRU */ - case 14: - FMA_INT8_GENERAL(lhs[13], rhs[13], result, norm1, norm2) - /* FALLTHRU */ - case 13: - FMA_INT8_GENERAL(lhs[12], rhs[12], result, norm1, norm2) - /* FALLTHRU */ - case 12: - FMA_INT8_GENERAL(lhs[11], rhs[11], result, norm1, norm2) - /* FALLTHRU */ - case 11: - FMA_INT8_GENERAL(lhs[10], rhs[10], result, norm1, norm2) - /* FALLTHRU */ - case 10: - FMA_INT8_GENERAL(lhs[9], rhs[9], result, norm1, norm2) - /* FALLTHRU */ - case 9: - FMA_INT8_GENERAL(lhs[8], rhs[8], result, norm1, norm2) - /* FALLTHRU */ - case 8: - FMA_INT8_GENERAL(lhs[7], rhs[7], result, norm1, norm2) - /* FALLTHRU */ - case 7: - FMA_INT8_GENERAL(lhs[6], rhs[6], result, norm1, norm2) - /* FALLTHRU */ - case 6: - FMA_INT8_GENERAL(lhs[5], rhs[5], result, norm1, norm2) - /* FALLTHRU */ - case 5: - FMA_INT8_GENERAL(lhs[4], rhs[4], result, norm1, norm2) - /* FALLTHRU */ - case 4: - FMA_INT8_GENERAL(lhs[3], rhs[3], result, norm1, norm2) - /* FALLTHRU */ - case 3: - FMA_INT8_GENERAL(lhs[2], rhs[2], result, norm1, norm2) - /* FALLTHRU */ - case 2: - FMA_INT8_GENERAL(lhs[1], rhs[1], result, norm1, norm2) - /* FALLTHRU */ - case 1: - FMA_INT8_GENERAL(lhs[0], rhs[0], result, norm1, norm2) - } - *sql = norm1; - *sqr = norm2; - return result; -} -#else -//! Compute the Inner Product between p and q, and each Squared L2-Norm value -static inline float InnerProductAndSquaredNormSSE(const int8_t *lhs, - const int8_t *rhs, - size_t size, float *sql, - float *sqr) { - const int8_t *last = lhs + size; - const int8_t *last_aligned = lhs + ((size >> 5) << 5); - - __m128i xmm_sum = _mm_setzero_si128(); - __m128i xmm_sum_norm1 = _mm_setzero_si128(); - __m128i xmm_sum_norm2 = _mm_setzero_si128(); - - if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0)); - __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16)); - __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0)); - __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16)); - FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum); - FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); - FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); - FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum); - FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); - FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); - } - - if (last >= last_aligned + 16) { - __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); - FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum); - FMA_INT8_SSE(xmm_lhs, xmm_lhs, xmm_sum_norm1); - FMA_INT8_SSE(xmm_rhs, xmm_rhs, xmm_sum_norm2); - lhs += 16; - rhs += 16; - } - } else { - for (; lhs != last_aligned; lhs += 32, rhs += 32) { - __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0)); - __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16)); - __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0)); - __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16)); - FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum); - FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); - FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); - FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum); - FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); - FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); - } - - if (last >= last_aligned + 16) { - __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); - __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); - FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum); - FMA_INT8_SSE(xmm_lhs, xmm_lhs, xmm_sum_norm1); - FMA_INT8_SSE(xmm_rhs, xmm_rhs, xmm_sum_norm2); - lhs += 16; - rhs += 16; - } - } - float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum)); - float norm1 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm1)); - float norm2 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm2)); - - switch (last - lhs) { - case 15: - FMA_INT8_GENERAL(lhs[14], rhs[14], result, norm1, norm2) - /* FALLTHRU */ - case 14: - FMA_INT8_GENERAL(lhs[13], rhs[13], result, norm1, norm2) - /* FALLTHRU */ - case 13: - FMA_INT8_GENERAL(lhs[12], rhs[12], result, norm1, norm2) - /* FALLTHRU */ - case 12: - FMA_INT8_GENERAL(lhs[11], rhs[11], result, norm1, norm2) - /* FALLTHRU */ - case 11: - FMA_INT8_GENERAL(lhs[10], rhs[10], result, norm1, norm2) - /* FALLTHRU */ - case 10: - FMA_INT8_GENERAL(lhs[9], rhs[9], result, norm1, norm2) - /* FALLTHRU */ - case 9: - FMA_INT8_GENERAL(lhs[8], rhs[8], result, norm1, norm2) - /* FALLTHRU */ - case 8: - FMA_INT8_GENERAL(lhs[7], rhs[7], result, norm1, norm2) - /* FALLTHRU */ - case 7: - FMA_INT8_GENERAL(lhs[6], rhs[6], result, norm1, norm2) - /* FALLTHRU */ - case 6: - FMA_INT8_GENERAL(lhs[5], rhs[5], result, norm1, norm2) - /* FALLTHRU */ - case 5: - FMA_INT8_GENERAL(lhs[4], rhs[4], result, norm1, norm2) - /* FALLTHRU */ - case 4: - FMA_INT8_GENERAL(lhs[3], rhs[3], result, norm1, norm2) - /* FALLTHRU */ - case 3: - FMA_INT8_GENERAL(lhs[2], rhs[2], result, norm1, norm2) - /* FALLTHRU */ - case 2: - FMA_INT8_GENERAL(lhs[1], rhs[1], result, norm1, norm2) - /* FALLTHRU */ - case 1: - FMA_INT8_GENERAL(lhs[0], rhs[0], result, norm1, norm2) - } - *sql = norm1; - *sqr = norm2; - return result; -} -#endif // __AVX2__ - -//! Compute the distance between matrix and query by SphericalInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2; - float v2; - float sum; - -#if defined(__AVX2__) - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); -#else - sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); -#endif - - *out = ComputeSphericalInjection(sum, u2, v2, e2); -} - -//! Compute the distance between matrix and query by RepeatedQuadraticInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, - float *out) { - float u2; - float v2; - float sum; - -#if defined(__AVX2__) - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); -#else - sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); -#endif - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; - } - *out = sum; -} -#endif // __SSE4_1__ - -} // namespace ailego -} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc b/src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc new file mode 100644 index 00000000..0b969537 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc @@ -0,0 +1,159 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int8.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormAVX2(const int8_t *lhs, const int8_t *rhs, + size_t size, float *sql, float *sqr) { + const int8_t *last = lhs + size; + const int8_t *last_aligned = lhs + ((size >> 6) << 6); + + __m256i ymm_sum_0 = _mm256_setzero_si256(); + __m256i ymm_sum_1 = _mm256_setzero_si256(); + __m256i ymm_sum_norm1 = _mm256_setzero_si256(); + __m256i ymm_sum_norm2 = _mm256_setzero_si256(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0)); + __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32)); + __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0)); + __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32)); + FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); + FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); + FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); + } + + if (last >= last_aligned + 32) { + __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); + FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum_0); + FMA_INT8_AVX(ymm_lhs, ymm_lhs, ymm_sum_norm1); + FMA_INT8_AVX(ymm_rhs, ymm_rhs, ymm_sum_norm2); + lhs += 32; + rhs += 32; + } + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum_0); + FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_lhs, ymm_sum_norm1); + FMA_INT8_AVX_SSE_HYBRID(xmm_rhs, xmm_rhs, ymm_sum_norm2); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0)); + __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32)); + __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0)); + __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32)); + FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); + FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); + FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); + FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); + FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); + FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); + } + + if (last >= last_aligned + 32) { + __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); + FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum_0); + FMA_INT8_AVX(ymm_lhs, ymm_lhs, ymm_sum_norm1); + FMA_INT8_AVX(ymm_rhs, ymm_rhs, ymm_sum_norm2); + lhs += 32; + rhs += 32; + } + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum_0); + FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_lhs, ymm_sum_norm1); + FMA_INT8_AVX_SSE_HYBRID(xmm_rhs, xmm_rhs, ymm_sum_norm2); + lhs += 16; + rhs += 16; + } + } + float result = static_cast( + HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); + float norm1 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm1)); + float norm2 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm2)); + + switch (last - lhs) { + case 15: + FMA_INT8_GENERAL(lhs[14], rhs[14], result, norm1, norm2) + /* FALLTHRU */ + case 14: + FMA_INT8_GENERAL(lhs[13], rhs[13], result, norm1, norm2) + /* FALLTHRU */ + case 13: + FMA_INT8_GENERAL(lhs[12], rhs[12], result, norm1, norm2) + /* FALLTHRU */ + case 12: + FMA_INT8_GENERAL(lhs[11], rhs[11], result, norm1, norm2) + /* FALLTHRU */ + case 11: + FMA_INT8_GENERAL(lhs[10], rhs[10], result, norm1, norm2) + /* FALLTHRU */ + case 10: + FMA_INT8_GENERAL(lhs[9], rhs[9], result, norm1, norm2) + /* FALLTHRU */ + case 9: + FMA_INT8_GENERAL(lhs[8], rhs[8], result, norm1, norm2) + /* FALLTHRU */ + case 8: + FMA_INT8_GENERAL(lhs[7], rhs[7], result, norm1, norm2) + /* FALLTHRU */ + case 7: + FMA_INT8_GENERAL(lhs[6], rhs[6], result, norm1, norm2) + /* FALLTHRU */ + case 6: + FMA_INT8_GENERAL(lhs[5], rhs[5], result, norm1, norm2) + /* FALLTHRU */ + case 5: + FMA_INT8_GENERAL(lhs[4], rhs[4], result, norm1, norm2) + /* FALLTHRU */ + case 4: + FMA_INT8_GENERAL(lhs[3], rhs[3], result, norm1, norm2) + /* FALLTHRU */ + case 3: + FMA_INT8_GENERAL(lhs[2], rhs[2], result, norm1, norm2) + /* FALLTHRU */ + case 2: + FMA_INT8_GENERAL(lhs[1], rhs[1], result, norm1, norm2) + /* FALLTHRU */ + case 1: + FMA_INT8_GENERAL(lhs[0], rhs[0], result, norm1, norm2) + } + *sql = norm1; + *sqr = norm2; + return result; +} +#endif // __AVX2__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc b/src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc new file mode 100644 index 00000000..35b4a8c8 --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc @@ -0,0 +1,81 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__AVX2__) +float InnerProductAndSquaredNormAVX2(const int8_t *lhs, const int8_t *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__SSE__) +float InnerProductAndSquaredNormSSE(const int8_t *lhs, const int8_t *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__SSE4_1__) +//! Compute the distance between matrix and query by SphericalInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + sum = InnerProductAndSquaredNormAVX2(p, q, dim, &u2, &v2); + } else +#endif + { + sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); + } + + *out = ComputeSphericalInjection(sum, u2, v2, e2); +} + +//! Compute the distance between matrix and query by RepeatedQuadraticInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, + float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + sum = InnerProductAndSquaredNormAVX2(p, q, dim, &u2, &v2); + } else +#endif + { + sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); + } + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + *out = sum; +} +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc b/src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc new file mode 100644 index 00000000..a0d6192c --- /dev/null +++ b/src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc @@ -0,0 +1,137 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "distance_matrix_accum_int8.i" +#include "distance_matrix_mips_utility.i" +#include "mips_euclidean_distance_matrix.h" + +namespace zvec { +namespace ailego { + +#if defined(__SSE4_1__) +//! Compute the Inner Product between p and q, and each Squared L2-Norm value +float InnerProductAndSquaredNormSSE(const int8_t *lhs, const int8_t *rhs, + size_t size, float *sql, float *sqr) { + const int8_t *last = lhs + size; + const int8_t *last_aligned = lhs + ((size >> 5) << 5); + + __m128i xmm_sum = _mm_setzero_si128(); + __m128i xmm_sum_norm1 = _mm_setzero_si128(); + __m128i xmm_sum_norm2 = _mm_setzero_si128(); + + if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0)); + __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16)); + __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0)); + __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16)); + FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum); + FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); + FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); + FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum); + FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); + FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); + } + + if (last >= last_aligned + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum); + FMA_INT8_SSE(xmm_lhs, xmm_lhs, xmm_sum_norm1); + FMA_INT8_SSE(xmm_rhs, xmm_rhs, xmm_sum_norm2); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 32, rhs += 32) { + __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0)); + __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16)); + __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0)); + __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16)); + FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum); + FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); + FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); + FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum); + FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); + FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); + } + + if (last >= last_aligned + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum); + FMA_INT8_SSE(xmm_lhs, xmm_lhs, xmm_sum_norm1); + FMA_INT8_SSE(xmm_rhs, xmm_rhs, xmm_sum_norm2); + lhs += 16; + rhs += 16; + } + } + float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum)); + float norm1 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm1)); + float norm2 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm2)); + + switch (last - lhs) { + case 15: + FMA_INT8_GENERAL(lhs[14], rhs[14], result, norm1, norm2) + /* FALLTHRU */ + case 14: + FMA_INT8_GENERAL(lhs[13], rhs[13], result, norm1, norm2) + /* FALLTHRU */ + case 13: + FMA_INT8_GENERAL(lhs[12], rhs[12], result, norm1, norm2) + /* FALLTHRU */ + case 12: + FMA_INT8_GENERAL(lhs[11], rhs[11], result, norm1, norm2) + /* FALLTHRU */ + case 11: + FMA_INT8_GENERAL(lhs[10], rhs[10], result, norm1, norm2) + /* FALLTHRU */ + case 10: + FMA_INT8_GENERAL(lhs[9], rhs[9], result, norm1, norm2) + /* FALLTHRU */ + case 9: + FMA_INT8_GENERAL(lhs[8], rhs[8], result, norm1, norm2) + /* FALLTHRU */ + case 8: + FMA_INT8_GENERAL(lhs[7], rhs[7], result, norm1, norm2) + /* FALLTHRU */ + case 7: + FMA_INT8_GENERAL(lhs[6], rhs[6], result, norm1, norm2) + /* FALLTHRU */ + case 6: + FMA_INT8_GENERAL(lhs[5], rhs[5], result, norm1, norm2) + /* FALLTHRU */ + case 5: + FMA_INT8_GENERAL(lhs[4], rhs[4], result, norm1, norm2) + /* FALLTHRU */ + case 4: + FMA_INT8_GENERAL(lhs[3], rhs[3], result, norm1, norm2) + /* FALLTHRU */ + case 3: + FMA_INT8_GENERAL(lhs[2], rhs[2], result, norm1, norm2) + /* FALLTHRU */ + case 2: + FMA_INT8_GENERAL(lhs[1], rhs[1], result, norm1, norm2) + /* FALLTHRU */ + case 1: + FMA_INT8_GENERAL(lhs[0], rhs[0], result, norm1, norm2) + } + *sql = norm1; + *sqr = norm2; + return result; +} +#endif // __SSE4_1__ + +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math/norm1_matrix_fp16.cc b/src/ailego/math/norm1_matrix_fp16.cc index ce36de7e..9bb86201 100644 --- a/src/ailego/math/norm1_matrix_fp16.cc +++ b/src/ailego/math/norm1_matrix_fp16.cc @@ -73,13 +73,17 @@ void Norm1Matrix::Compute(const ValueType *m, size_t dim, float *out) { #if defined(__ARM_NEON) NORM_FP16_1_NEON(m, dim, out, ) -#elif defined(__AVX512F__) - NORM_FP16_1_AVX512(m, dim, out, ) #else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + NORM_FP16_1_AVX512(m, dim, out, ) + return; + } +#endif NORM_FP16_1_AVX(m, dim, out, ) #endif } #endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/norm1_matrix_fp32.cc b/src/ailego/math/norm1_matrix_fp32.cc index 15bd3f73..c1ec3668 100644 --- a/src/ailego/math/norm1_matrix_fp32.cc +++ b/src/ailego/math/norm1_matrix_fp32.cc @@ -26,18 +26,15 @@ namespace ailego { #define NORM_FP32_STEP_NEON SA_FP32_NEON #if defined(__SSE__) -static const __m128 ABS_MASK_FP32_SSE = - _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffu)); +#define ABS_MASK_FP32_SSE _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffu)) #endif // __SSE__ #if defined(__AVX__) -static const __m256 ABS_MASK_FP32_AVX = - _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffu)); +#define ABS_MASK_FP32_AVX _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffu)) #endif // __AVX__ #if defined(__AVX512F__) -static const __m512 ABS_MASK_FP32_AVX512 = - _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)); +#define ABS_MASK_FP32_AVX512 _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)) #endif // __AVX512F__ //! Calculate sum of absolute (GENERAL) @@ -64,15 +61,23 @@ void Norm1Matrix::Compute(const ValueType *m, size_t dim, float *out) { #if defined(__ARM_NEON) NORM_FP32_1_NEON(m, dim, out, ) -#elif defined(__AVX512F__) - NORM_FP32_1_AVX512(m, dim, out, ) -#elif defined(__AVX__) - NORM_FP32_1_AVX(m, dim, out, ) #else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + NORM_FP32_1_AVX512(m, dim, out, ) + return; + } +#endif +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + NORM_FP32_1_AVX(m, dim, out, ) + return; + } +#endif NORM_FP32_1_SSE(m, dim, out, ) #endif } #endif // __SSE__ || (__ARM_NEON && __aarch64__) } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/norm2_matrix_fp16.cc b/src/ailego/math/norm2_matrix_fp16.cc index 28dbadb9..37c3313c 100644 --- a/src/ailego/math/norm2_matrix_fp16.cc +++ b/src/ailego/math/norm2_matrix_fp16.cc @@ -58,9 +58,13 @@ void Norm2Matrix::Compute(const ValueType *m, size_t dim, float *out) { #if defined(__ARM_NEON) NORM_FP16_1_NEON(m, dim, out, std::sqrt) -#elif defined(__AVX512F__) - NORM_FP16_1_AVX512(m, dim, out, std::sqrt) #else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + NORM_FP16_1_AVX512(m, dim, out, std::sqrt) + return; + } +#endif NORM_FP16_1_AVX(m, dim, out, std::sqrt) #endif } @@ -70,9 +74,13 @@ void SquaredNorm2Matrix::Compute(const ValueType *m, size_t dim, float *out) { #if defined(__ARM_NEON) NORM_FP16_1_NEON(m, dim, out, ) -#elif defined(__AVX512F__) - NORM_FP16_1_AVX512(m, dim, out, ) #else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + NORM_FP16_1_AVX512(m, dim, out, ) + return; + } +#endif NORM_FP16_1_AVX(m, dim, out, ) #endif } diff --git a/src/ailego/math/norm2_matrix_fp32.cc b/src/ailego/math/norm2_matrix_fp32.cc index b5dcf2f6..8cc76c1f 100644 --- a/src/ailego/math/norm2_matrix_fp32.cc +++ b/src/ailego/math/norm2_matrix_fp32.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "norm2_matrix.h" #include "norm_matrix_fp32.i" @@ -48,11 +49,19 @@ void Norm2Matrix::Compute(const ValueType *m, size_t dim, float *out) { #if defined(__ARM_NEON) NORM_FP32_1_NEON(m, dim, out, std::sqrt) -#elif defined(__AVX512F__) - NORM_FP32_1_AVX512(m, dim, out, std::sqrt) -#elif defined(__AVX__) - NORM_FP32_1_AVX(m, dim, out, std::sqrt) #else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + NORM_FP32_1_AVX512(m, dim, out, std::sqrt) + return; + } +#endif +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + NORM_FP32_1_AVX(m, dim, out, std::sqrt) + return; + } +#endif NORM_FP32_1_SSE(m, dim, out, std::sqrt) #endif } @@ -62,15 +71,23 @@ void SquaredNorm2Matrix::Compute(const ValueType *m, size_t dim, float *out) { #if defined(__ARM_NEON) NORM_FP32_1_NEON(m, dim, out, ) -#elif defined(__AVX512F__) - NORM_FP32_1_AVX512(m, dim, out, ) -#elif defined(__AVX__) - NORM_FP32_1_AVX(m, dim, out, ) #else +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + NORM_FP32_1_AVX512(m, dim, out, ) + return; + } +#endif +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + NORM_FP32_1_AVX(m, dim, out, ) + return; + } +#endif NORM_FP32_1_SSE(m, dim, out, ) #endif } #endif // __SSE__ || (__ARM_NEON && __aarch64__) } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math_batch/distance_batch.h b/src/ailego/math_batch/distance_batch.h index c762a258..17a21907 100644 --- a/src/ailego/math_batch/distance_batch.h +++ b/src/ailego/math_batch/distance_batch.h @@ -19,10 +19,8 @@ #include "cosine_distance_batch.h" #include "inner_product_distance_batch.h" - namespace zvec::ailego { - template < template class DistanceType, typename ValueType, size_t BatchSize, size_t PrefetchStep, typename = void> diff --git a/src/ailego/math_batch/inner_product_distance_batch.h b/src/ailego/math_batch/inner_product_distance_batch.h index f5799497..f42345bf 100644 --- a/src/ailego/math_batch/inner_product_distance_batch.h +++ b/src/ailego/math_batch/inner_product_distance_batch.h @@ -15,20 +15,28 @@ #pragma once #include -#include #include #include #include #include -#include "inner_product_distance_batch_impl.h" -#include "inner_product_distance_batch_impl_fp16.h" -#include "inner_product_distance_batch_impl_int8.h" namespace zvec::ailego::DistanceBatch { template struct InnerProductDistanceBatch; +template +static void compute_one_to_many_inner_product_fallback( + const ValueType *query, const ValueType **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { + for (size_t j = 0; j < BatchSize; ++j) { + sums[j] = 0.0; + InnerProductMatrix::Compute(ptrs[j], query, dim, sums + j); + ailego_prefetch(&prefetch_ptrs[j]); + } +} + // Function template partial specialization is not allowed, // therefore the wrapper struct is required. template @@ -38,98 +46,14 @@ struct InnerProductDistanceBatchImpl { const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { - return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); - } - static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { - return nullptr; - } -}; - -template -struct InnerProductDistanceBatchImpl { - using ValueType = float; - static void compute_one_to_many( - const ValueType *query, const ValueType **ptrs, - std::array &prefetch_ptrs, size_t dim, - float *sums) { -#if defined(__AVX2__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_avx2_fp32( - query, ptrs, prefetch_ptrs, dim, sums); - } -#endif - return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, + prefetch_ptrs, dim, sums); } - static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { return nullptr; } }; -template -struct InnerProductDistanceBatchImpl { - using ValueType = int8_t; - static void compute_one_to_many( - const int8_t *query, const int8_t **ptrs, - std::array &prefetch_ptrs, size_t dim, - float *sums) { -// #if defined(__AVX512BW__) // TODO: this version is problematic -// return compute_one_to_many_avx512_int8( -// query, ptrs, prefetch_ptrs, dim, sums); -#if defined(__AVX512VNNI__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { - return compute_one_to_many_avx512_vnni_int8( - query, ptrs, prefetch_ptrs, dim, sums); - } -#endif -#if defined(__AVX2__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_avx2_int8( - query, ptrs, prefetch_ptrs, dim, sums); - } -#endif - return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); - } - - static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { -#if defined(__AVX512VNNI__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { - return compute_one_to_many_avx512_vnni_int8_query_preprocess; - } -#endif - return nullptr; - } -}; - -template -struct InnerProductDistanceBatchImpl { - using ValueType = ailego::Float16; - static void compute_one_to_many( - const ailego::Float16 *query, const ailego::Float16 **ptrs, - std::array &prefetch_ptrs, size_t dim, - float *sums) { -#if defined(__AVX512FP16__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { - return compute_one_to_many_avx512fp16_fp16( - query, ptrs, prefetch_ptrs, dim, sums); - } -#endif -#if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - return compute_one_to_many_avx512f_fp16( - query, ptrs, prefetch_ptrs, dim, sums); - } -#endif -#if defined(__AVX2__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_avx2_fp16( - query, ptrs, prefetch_ptrs, dim, sums); - } -#endif - return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); - } -}; - template struct InnerProductDistanceBatch { using ValueType = typename std::remove_cv::type; @@ -163,4 +87,56 @@ struct InnerProductDistanceBatch { } }; +template <> +struct InnerProductDistanceBatchImpl { + using ValueType = ailego::Float16; + static void compute_one_to_many( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums); +}; + +template <> +struct InnerProductDistanceBatchImpl { + using ValueType = float; + static void compute_one_to_many(const float *query, const float **ptrs, + std::array &prefetch_ptrs, + size_t dim, float *sums); +}; + +template <> +struct InnerProductDistanceBatchImpl { + using ValueType = int8_t; + static void compute_one_to_many(const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, + size_t dim, float *sums); + + static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc(); +}; + +template <> +struct InnerProductDistanceBatchImpl { + using ValueType = ailego::Float16; + static void compute_one_to_many( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums); +}; + +template <> +struct InnerProductDistanceBatchImpl { + using ValueType = float; + static void compute_one_to_many(const float *query, const float **ptrs, + std::array &prefetch_ptrs, + size_t dim, float *sums); +}; + +template <> +struct InnerProductDistanceBatchImpl { + using ValueType = int8_t; + static void compute_one_to_many(const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, + size_t dim, float *sums); +}; + } // namespace zvec::ailego::DistanceBatch diff --git a/src/ailego/math_batch/inner_product_distance_batch_dispatch.cc b/src/ailego/math_batch/inner_product_distance_batch_dispatch.cc new file mode 100644 index 00000000..78376626 --- /dev/null +++ b/src/ailego/math_batch/inner_product_distance_batch_dispatch.cc @@ -0,0 +1,228 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include "inner_product_distance_batch.h" + +namespace zvec::ailego::DistanceBatch { + +#if defined(__AVX512VNNI__) +void compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess( + void *query, size_t dim); + +void compute_one_to_many_inner_product_avx512_vnni_int8_1( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results); + +void compute_one_to_many_inner_product_avx512_vnni_int8_12( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results); +#endif + +#if defined(__AVX512FP16__) +void compute_one_to_many_inner_product_avx512fp16_fp16_1( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results); + +void compute_one_to_many_inner_product_avx512fp16_fp16_12( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results); +#endif //__AVX512FP16__ + +#if defined(__AVX512F__) +void compute_one_to_many_inner_product_avx512f_fp16_1( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results); + +void compute_one_to_many_inner_product_avx512f_fp16_12( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results); +#endif //__AVX512F__ + +#if defined(__AVX2__) +void compute_one_to_many_inner_product_avx2_fp32_1( + const float *query, const float **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results); + +void compute_one_to_many_inner_product_avx2_fp16_1( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results); + +void compute_one_to_many_inner_product_avx2_int8_1( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results); + +void compute_one_to_many_inner_product_avx2_fp32_12( + const float *query, const float **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results); + +void compute_one_to_many_inner_product_avx2_fp16_12( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results); + +void compute_one_to_many_inner_product_avx2_int8_12( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results); +#endif + +void InnerProductDistanceBatchImpl::compute_one_to_many( + const ValueType *query, const ValueType **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_inner_product_avx2_fp32_1( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, + dim, sums); +} + +void InnerProductDistanceBatchImpl::compute_one_to_many( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { +#if defined(__AVX512FP16__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { + return compute_one_to_many_inner_product_avx512fp16_fp16_1( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + return compute_one_to_many_inner_product_avx512f_fp16_1( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_inner_product_avx2_fp16_1( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, + dim, sums); +} + +void InnerProductDistanceBatchImpl::compute_one_to_many( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { +// #if defined(__AVX512BW__) // TODO: this version is problematic +// return compute_one_to_many_avx512_int8( +// query, ptrs, prefetch_ptrs, dim, sums); +#if defined(__AVX512VNNI__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { + return compute_one_to_many_inner_product_avx512_vnni_int8_1( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_inner_product_avx2_int8_1( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, + dim, sums); +} + +DistanceBatchQueryPreprocessFunc +InnerProductDistanceBatchImpl::GetQueryPreprocessFunc() { +#if defined(__AVX512VNNI__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { + return compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess; + } +#endif + return nullptr; +} + +void InnerProductDistanceBatchImpl::compute_one_to_many( + const ValueType *query, const ValueType **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_inner_product_avx2_fp32_12( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, + dim, sums); +} + +void InnerProductDistanceBatchImpl::compute_one_to_many( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { +#if defined(__AVX512FP16__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { + return compute_one_to_many_inner_product_avx512fp16_fp16_12( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + return compute_one_to_many_inner_product_avx512f_fp16_12( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_inner_product_avx2_fp16_12( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, + dim, sums); +} + +void InnerProductDistanceBatchImpl::compute_one_to_many( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { +// #if defined(__AVX512BW__) // TODO: this version is problematic +// return compute_one_to_many_avx512_int8( +// query, ptrs, prefetch_ptrs, dim, sums); +#if defined(__AVX512VNNI__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { + return compute_one_to_many_inner_product_avx512_vnni_int8_12( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_inner_product_avx2_int8_12( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, + dim, sums); +} + +} // namespace zvec::ailego::DistanceBatch diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc new file mode 100644 index 00000000..59320de0 --- /dev/null +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc @@ -0,0 +1,110 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +namespace zvec::ailego::DistanceBatch { + +#if defined(__AVX2__) + +template +static std::enable_if_t, void> +compute_one_to_many_inner_product_avx2_fp16( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results) { + std::array<__m256, dp_batch> accs; + + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm256_setzero_ps(); + } + + size_t dim = 0; + for (; dim + 16 <= dimensionality; dim += 16) { + __m256i q = + _mm256_loadu_si256(reinterpret_cast(query + dim)); + + __m256 q1 = _mm256_cvtph_ps(_mm256_castsi256_si128(q)); + __m256 q2 = _mm256_cvtph_ps(_mm256_extractf128_si256(q, 1)); + + std::array<__m256, dp_batch> data_regs_1; + std::array<__m256, dp_batch> data_regs_2; + for (size_t i = 0; i < dp_batch; ++i) { + __m256i m = + _mm256_loadu_si256(reinterpret_cast(ptrs[i] + dim)); + + data_regs_1[i] = _mm256_cvtph_ps(_mm256_castsi256_si128(m)); + data_regs_2[i] = _mm256_cvtph_ps(_mm256_extractf128_si256(m, 1)); + } + + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm256_fmadd_ps(q1, data_regs_1[i], accs[i]); + accs[i] = _mm256_fmadd_ps(q2, data_regs_2[i], accs[i]); + } + } + + if (dim + 8 <= dimensionality) { + __m256 q = _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(query + dim))); + + std::array<__m256, dp_batch> data_regs; + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ptrs[i] + dim))); + accs[i] = _mm256_fmadd_ps(q, data_regs[i], accs[i]); + } + + dim += 8; + } + + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = HorizontalAdd_FP32_V256(accs[i]); + } + + for (; dim < dimensionality; ++dim) { + for (size_t i = 0; i < dp_batch; ++i) { + results[i] += (*(query + dim)) * (*(ptrs[i] + dim)); + } + } +} + +void compute_one_to_many_inner_product_avx2_fp16_1( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { + return compute_one_to_many_inner_product_avx2_fp16( + query, ptrs, prefetch_ptrs, dim, sums); +} + +void compute_one_to_many_inner_product_avx2_fp16_12( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { + return compute_one_to_many_inner_product_avx2_fp16( + query, ptrs, prefetch_ptrs, dim, sums); +} + +#endif + +} // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16.h b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512.cc similarity index 70% rename from src/ailego/math_batch/inner_product_distance_batch_impl_fp16.h rename to src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512.cc index 183369d7..1fbe5b24 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16.h +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include #include #include @@ -23,10 +21,9 @@ namespace zvec::ailego::DistanceBatch { #if defined(__AVX512FP16__) - template static std::enable_if_t, void> -compute_one_to_many_avx512fp16_fp16( +compute_one_to_many_inner_product_avx512fp16_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { @@ -82,7 +79,7 @@ compute_one_to_many_avx512fp16_fp16( template static std::enable_if_t, void> -compute_one_to_many_avx512f_fp16( +compute_one_to_many_inner_product_avx512f_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { @@ -122,7 +119,7 @@ compute_one_to_many_avx512f_fp16( } } - if (dim + 16 < dimensionality) { + if (dim + 16 <= dimensionality) { __m512 q = _mm512_cvtph_ps( _mm256_loadu_si256(reinterpret_cast(query + dim))); @@ -143,7 +140,7 @@ compute_one_to_many_avx512f_fp16( _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(accs[i]), 1))); } - if (dim + 8 < dimensionality) { + if (dim + 8 <= dimensionality) { __m256 q = _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(query + dim))); @@ -166,78 +163,43 @@ compute_one_to_many_avx512f_fp16( } } } -#endif -#if defined(__AVX2__) +#endif -template -static std::enable_if_t, void> -compute_one_to_many_avx2_fp16( +#if defined(__AVX512FP16__) +void compute_one_to_many_inner_product_avx512fp16_fp16_1( const ailego::Float16 *query, const ailego::Float16 **ptrs, - std::array &prefetch_ptrs, - size_t dimensionality, float *results) { - std::array<__m256, dp_batch> accs; - - for (size_t i = 0; i < dp_batch; ++i) { - accs[i] = _mm256_setzero_ps(); - } - - size_t dim = 0; - for (; dim + 16 <= dimensionality; dim += 16) { - __m256i q = - _mm256_loadu_si256(reinterpret_cast(query + dim)); - - __m256 q1 = _mm256_cvtph_ps(_mm256_castsi256_si128(q)); - __m256 q2 = _mm256_cvtph_ps(_mm256_extractf128_si256(q, 1)); - - std::array<__m256, dp_batch> data_regs_1; - std::array<__m256, dp_batch> data_regs_2; - for (size_t i = 0; i < dp_batch; ++i) { - __m256i m = - _mm256_loadu_si256(reinterpret_cast(ptrs[i] + dim)); - - data_regs_1[i] = _mm256_cvtph_ps(_mm256_castsi256_si128(m)); - data_regs_2[i] = _mm256_cvtph_ps(_mm256_extractf128_si256(m, 1)); - } - - if (prefetch_ptrs[0]) { - for (size_t i = 0; i < dp_batch; ++i) { - ailego_prefetch(prefetch_ptrs[i] + dim); - } - } - - for (size_t i = 0; i < dp_batch; ++i) { - accs[i] = _mm256_fmadd_ps(q1, data_regs_1[i], accs[i]); - accs[i] = _mm256_fmadd_ps(q2, data_regs_2[i], accs[i]); - } - } - - if (dim + 8 < dimensionality) { - __m256 q = _mm256_cvtph_ps( - _mm_loadu_si128(reinterpret_cast(query + dim))); - - std::array<__m256, dp_batch> data_regs; - for (size_t i = 0; i < dp_batch; ++i) { - data_regs[i] = _mm256_cvtph_ps( - _mm_loadu_si128(reinterpret_cast(ptrs[i] + dim))); - accs[i] = _mm256_fmadd_ps(q, data_regs[i], accs[i]); - } - - dim += 8; - } + std::array &prefetch_ptrs, size_t dim, + float *sums) { + return compute_one_to_many_inner_product_avx512fp16_fp16( + query, ptrs, prefetch_ptrs, dim, sums); +} - for (size_t i = 0; i < dp_batch; ++i) { - results[i] = HorizontalAdd_FP32_V256(accs[i]); - } +void compute_one_to_many_inner_product_avx512fp16_fp16_12( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { + return compute_one_to_many_inner_product_avx512fp16_fp16( + query, ptrs, prefetch_ptrs, dim, sums); +} +#endif - for (; dim < dimensionality; ++dim) { - for (size_t i = 0; i < dp_batch; ++i) { - results[i] += (*(query + dim)) * (*(ptrs[i] + dim)); - } - } +#if defined(__AVX512F__) +void compute_one_to_many_inner_product_avx512f_fp16_1( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { + return compute_one_to_many_inner_product_avx512f_fp16( + query, ptrs, prefetch_ptrs, dim, sums); } +void compute_one_to_many_inner_product_avx512f_fp16_12( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { + return compute_one_to_many_inner_product_avx512f_fp16( + query, ptrs, prefetch_ptrs, dim, sums); +} #endif - } // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl.h b/src/ailego/math_batch/inner_product_distance_batch_impl_fp32_avx2.cc similarity index 85% rename from src/ailego/math_batch/inner_product_distance_batch_impl.h rename to src/ailego/math_batch/inner_product_distance_batch_impl_fp32_avx2.cc index d15a747e..0e54064b 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl.h +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_fp32_avx2.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include #include #include @@ -22,18 +20,6 @@ namespace zvec::ailego::DistanceBatch { -template -static void compute_one_to_many_fallback( - const ValueType *query, const ValueType **ptrs, - std::array &prefetch_ptrs, size_t dim, - float *sums) { - for (size_t j = 0; j < BatchSize; ++j) { - sums[j] = 0.0; - InnerProductMatrix::Compute(ptrs[j], query, dim, sums + j); - ailego_prefetch(&prefetch_ptrs[j]); - } -} - #if defined(__AVX2__) inline float sum4(__m128 v) { @@ -49,7 +35,7 @@ inline __m128 sum_top_bottom_avx(__m256 v) { template static std::enable_if_t, void> -compute_one_to_many_avx2_fp32( +compute_one_to_many_inner_product_avx2_fp32( const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { @@ -123,7 +109,21 @@ compute_one_to_many_avx2_fp32( results[i] = -res[i]; } } -#endif +void compute_one_to_many_inner_product_avx2_fp32_1( + const float *query, const float **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { + return compute_one_to_many_inner_product_avx2_fp32( + query, ptrs, prefetch_ptrs, dim, sums); +} + +void compute_one_to_many_inner_product_avx2_fp32_12( + const float *query, const float **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { + return compute_one_to_many_inner_product_avx2_fp32( + query, ptrs, prefetch_ptrs, dim, sums); +} + +#endif } // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc b/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc new file mode 100644 index 00000000..23d3566a --- /dev/null +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc @@ -0,0 +1,102 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +namespace zvec::ailego::DistanceBatch { + +#if defined(__AVX2__) + +template +static std::enable_if_t, void> +compute_one_to_many_inner_product_avx2_int8( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results) { + std::vector<__m256i> accs(dp_batch); + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm256_setzero_si256(); + } + size_t dim = 0; + for (; dim + 32 <= dimensionality; dim += 32) { + __m256i q = _mm256_loadu_si256((const __m256i *)(query + dim)); + std::vector<__m256i> data_regs(dp_batch); + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm256_loadu_si256((const __m256i *)(ptrs[i] + dim)); + } + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + __m256i q_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q)); + __m256i q_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q, 1)); + __m256i data_lo[dp_batch]; + __m256i data_hi[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_lo[i] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(data_regs[i])); + data_hi[i] = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(data_regs[i], 1)); + } + __m256i prod_lo[dp_batch]; + __m256i prod_hi[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + prod_lo[i] = _mm256_madd_epi16(q_lo, data_lo[i]); + prod_hi[i] = _mm256_madd_epi16(q_hi, data_hi[i]); + } + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = + _mm256_add_epi32(accs[i], _mm256_add_epi32(prod_lo[i], prod_hi[i])); + } + } + std::array temp_results; + for (size_t i = 0; i < dp_batch; ++i) { + __m128i lo = _mm256_castsi256_si128(accs[i]); + __m128i hi = _mm256_extracti128_si256(accs[i], 1); + __m128i sum128 = _mm_add_epi32(lo, hi); + sum128 = _mm_hadd_epi32(sum128, sum128); + sum128 = _mm_hadd_epi32(sum128, sum128); + temp_results[i] = _mm_cvtsi128_si32(sum128); + } + for (; dim < dimensionality; ++dim) { + int8_t q = query[dim]; + for (size_t i = 0; i < dp_batch; ++i) { + temp_results[i] += q * static_cast(ptrs[i][dim]); + } + } + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = static_cast(temp_results[i]); + } +} + +void compute_one_to_many_inner_product_avx2_int8_1( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { + return compute_one_to_many_inner_product_avx2_int8( + query, ptrs, prefetch_ptrs, dim, sums); +} + +void compute_one_to_many_inner_product_avx2_int8_12( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { + return compute_one_to_many_inner_product_avx2_int8( + query, ptrs, prefetch_ptrs, dim, sums); +} + +#endif + +} // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h b/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx512.cc similarity index 68% rename from src/ailego/math_batch/inner_product_distance_batch_impl_int8.h rename to src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx512.cc index 0e236641..1e105832 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx512.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include #include #include @@ -23,8 +21,8 @@ namespace zvec::ailego::DistanceBatch { #if defined(__AVX512VNNI__) -static void compute_one_to_many_avx512_vnni_int8_query_preprocess(void *query, - size_t dim) { +void compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess( + void *query, size_t dim) { const int8_t *input = reinterpret_cast(query); uint8_t *output = reinterpret_cast(query); @@ -48,10 +46,9 @@ static void compute_one_to_many_avx512_vnni_int8_query_preprocess(void *query, } } - // query is unsigned template -static void compute_one_to_many_avx512_vnni_int8( +static void compute_one_to_many_inner_product_avx512_vnni_int8( const int8_t *query, const int8_t **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { @@ -153,73 +150,22 @@ static void compute_one_to_many_avx512_vnni_int8( // results[i] = static_cast(temp_results[i]); // } // } -#endif -#if defined(__AVX2__) +void compute_one_to_many_inner_product_avx512_vnni_int8_1( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dim, float *sums) { + return compute_one_to_many_inner_product_avx512_vnni_int8<1>( + query, ptrs, prefetch_ptrs, dim, sums); +} -template -static std::enable_if_t, void> -compute_one_to_many_avx2_int8( +void compute_one_to_many_inner_product_avx512_vnni_int8_12( const int8_t *query, const int8_t **ptrs, - std::array &prefetch_ptrs, size_t dimensionality, - float *results) { - std::array<__m256i, dp_batch> accs; - for (size_t i = 0; i < dp_batch; ++i) { - accs[i] = _mm256_setzero_si256(); - } - size_t dim = 0; - for (; dim + 32 <= dimensionality; dim += 32) { - __m256i q = _mm256_loadu_si256((const __m256i *)(query + dim)); - std::array<__m256i, dp_batch> data_regs; - for (size_t i = 0; i < dp_batch; ++i) { - data_regs[i] = _mm256_loadu_si256((const __m256i *)(ptrs[i] + dim)); - } - if (prefetch_ptrs[0]) { - for (size_t i = 0; i < dp_batch; ++i) { - ailego_prefetch(prefetch_ptrs[i] + dim); - } - } - __m256i q_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q)); - __m256i q_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q, 1)); - std::array<__m256i, dp_batch> data_lo; - std::array<__m256i, dp_batch> data_hi; - for (size_t i = 0; i < dp_batch; ++i) { - data_lo[i] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(data_regs[i])); - data_hi[i] = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(data_regs[i], 1)); - } - std::array<__m256i, dp_batch> prod_lo; - std::array<__m256i, dp_batch> prod_hi; - for (size_t i = 0; i < dp_batch; ++i) { - prod_lo[i] = _mm256_madd_epi16(q_lo, data_lo[i]); - prod_hi[i] = _mm256_madd_epi16(q_hi, data_hi[i]); - } - for (size_t i = 0; i < dp_batch; ++i) { - accs[i] = - _mm256_add_epi32(accs[i], _mm256_add_epi32(prod_lo[i], prod_hi[i])); - } - } - std::array temp_results; - for (size_t i = 0; i < dp_batch; ++i) { - __m128i lo = _mm256_castsi256_si128(accs[i]); - __m128i hi = _mm256_extracti128_si256(accs[i], 1); - __m128i sum128 = _mm_add_epi32(lo, hi); - sum128 = _mm_hadd_epi32(sum128, sum128); - sum128 = _mm_hadd_epi32(sum128, sum128); - temp_results[i] = _mm_cvtsi128_si32(sum128); - } - for (; dim < dimensionality; ++dim) { - int8_t q = query[dim]; - for (size_t i = 0; i < dp_batch; ++i) { - temp_results[i] += q * static_cast(ptrs[i][dim]); - } - } - for (size_t i = 0; i < dp_batch; ++i) { - results[i] = static_cast(temp_results[i]); - } + std::array &prefetch_ptrs, size_t dim, float *sums) { + return compute_one_to_many_inner_product_avx512_vnni_int8<12>( + query, ptrs, prefetch_ptrs, dim, sums); } -#endif +#endif } // namespace zvec::ailego::DistanceBatch \ No newline at end of file From 2366be91a387cab5a3be8788f12dda4fba760cec Mon Sep 17 00:00:00 2001 From: rayx Date: Thu, 12 Mar 2026 12:22:37 +0800 Subject: [PATCH 16/34] fix: fix uint8 conversion overflow (#215) --- src/include/zvec/ailego/internal/platform.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/zvec/ailego/internal/platform.h b/src/include/zvec/ailego/internal/platform.h index ecfe673f..598ca97a 100644 --- a/src/include/zvec/ailego/internal/platform.h +++ b/src/include/zvec/ailego/internal/platform.h @@ -476,11 +476,11 @@ static inline void ailego_assert_report(const char *file, const char *func, // is undefined (on arm, result will be zero), it's necessary to convert it // to signed integer firstly static inline uint8_t static_cast_from_float_to_uint8(float data) { - return static_cast(static_cast(data)); + return static_cast(static_cast(data)); } static inline uint16_t static_cast_from_float_to_uint16(float data) { - return static_cast(static_cast(data)); + return static_cast(static_cast(data)); } #ifdef __cplusplus From f86d4c2a64eff6ca0b52ee6714e55278a892423a Mon Sep 17 00:00:00 2001 From: egolearner Date: Thu, 12 Mar 2026 14:06:04 +0800 Subject: [PATCH 17/34] fix compile (#218) * fix compile * fix missing include * Revert "fix missing include" This reverts commit 295b52b0a5057a1c979b0316faf80e864031a0cb. --------- Co-authored-by: rayx --- src/ailego/math/norm1_matrix_fp16.cc | 1 + src/ailego/math/norm1_matrix_fp32.cc | 1 + src/ailego/math/norm2_matrix_fp16.cc | 1 + 3 files changed, 3 insertions(+) diff --git a/src/ailego/math/norm1_matrix_fp16.cc b/src/ailego/math/norm1_matrix_fp16.cc index 9bb86201..e75b3e0a 100644 --- a/src/ailego/math/norm1_matrix_fp16.cc +++ b/src/ailego/math/norm1_matrix_fp16.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include "ailego/internal/cpu_features.h" #include "norm1_matrix.h" #include "norm_matrix_fp16.i" diff --git a/src/ailego/math/norm1_matrix_fp32.cc b/src/ailego/math/norm1_matrix_fp32.cc index c1ec3668..2e727911 100644 --- a/src/ailego/math/norm1_matrix_fp32.cc +++ b/src/ailego/math/norm1_matrix_fp32.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include "ailego/internal/cpu_features.h" #include "norm1_matrix.h" #include "norm_matrix_fp32.i" diff --git a/src/ailego/math/norm2_matrix_fp16.cc b/src/ailego/math/norm2_matrix_fp16.cc index 37c3313c..6bb8dd06 100644 --- a/src/ailego/math/norm2_matrix_fp16.cc +++ b/src/ailego/math/norm2_matrix_fp16.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include "ailego/internal/cpu_features.h" #include "norm2_matrix.h" #include "norm_matrix_fp16.i" From b7a79b31d95bbd0b9639322cbc0a24cc01ac6522 Mon Sep 17 00:00:00 2001 From: rayx Date: Thu, 12 Mar 2026 16:12:28 +0800 Subject: [PATCH 18/34] fix: turn off math march if not auto detected (#220) --- src/ailego/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ailego/CMakeLists.txt b/src/ailego/CMakeLists.txt index c9867c36..bdabe413 100644 --- a/src/ailego/CMakeLists.txt +++ b/src/ailego/CMakeLists.txt @@ -18,7 +18,7 @@ if(UNIX AND NOT APPLE) list(APPEND EXTRA_LIBS ${LIB_RT}) endif() -if(NOT ANDROID) +if(NOT ANDROID AND AUTO_DETECT_ARCH) if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386|x64") setup_compiler_march_for_x86(MATH_MARCH_FLAG_SSE MATH_MARCH_FLAG_AVX2 MATH_MARCH_FLAG_AVX512) message(STATUS "best compiler march, sse: " ${MATH_MARCH_FLAG_SSE} ", avx2: " ${MATH_MARCH_FLAG_AVX2} ", avx512: " ${MATH_MARCH_FLAG_AVX512}) From 0b03af29b5200dfba4cf9145162796a98958160c Mon Sep 17 00:00:00 2001 From: Qinren Zhou Date: Thu, 12 Mar 2026 17:35:26 +0800 Subject: [PATCH 19/34] fix: compilation warning in ut (#217) --- tests/db/crash_recovery/data_generator.cc | 6 +++--- tests/db/crash_recovery/utility.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/db/crash_recovery/data_generator.cc b/tests/db/crash_recovery/data_generator.cc index 57542471..d0f5798f 100644 --- a/tests/db/crash_recovery/data_generator.cc +++ b/tests/db/crash_recovery/data_generator.cc @@ -136,8 +136,8 @@ int main(int argc, char **argv) { std::cout << " BatchDelay: " << kBatchDelayMs << "ms" << std::endl; std::cout << std::endl; - auto result = - zvec::Collection::Open(config.path, zvec::CollectionOptions{false, true}); + auto result = zvec::Collection::Open( + config.path, zvec::CollectionOptions{false, true, 4 * 1024 * 1024}); if (!result) { LOG_ERROR("Failed to open collection[%s]: %s", config.path.c_str(), result.error().c_str()); @@ -160,7 +160,7 @@ int main(int argc, char **argv) { std::vector docs; docs.reserve(batch_count); - for (uint64_t i = config.start_id; i < batch_end; i++) { + for (int i = config.start_id; i < batch_end; i++) { docs.push_back(zvec::CreateTestDoc(i, config.version)); } diff --git a/tests/db/crash_recovery/utility.h b/tests/db/crash_recovery/utility.h index 36768b24..063108dd 100644 --- a/tests/db/crash_recovery/utility.h +++ b/tests/db/crash_recovery/utility.h @@ -31,7 +31,7 @@ namespace zvec { inline CollectionSchema::Ptr CreateTestSchema( const std::string &name = "crash_recovery_test") { auto schema = std::make_shared(name); - schema->set_max_doc_count_per_segment(2000); + schema->set_max_doc_count_per_segment(10000); schema->add_field( std::make_shared("int32_field", DataType::INT32, false)); From b29bc8ce7d32ec935672442885da3798aa9b200f Mon Sep 17 00:00:00 2001 From: rayx Date: Fri, 13 Mar 2026 13:34:28 +0800 Subject: [PATCH 20/34] fix: use static array (#222) * fix: use static array * fix: clang-format * fix: clang-format * fix: clang-format --- ...ner_product_distance_batch_impl_fp16_avx2.cc | 9 ++++----- ...r_product_distance_batch_impl_fp16_avx512.cc | 16 +++++++--------- ...ner_product_distance_batch_impl_fp32_avx2.cc | 17 +++++++++++------ ...ner_product_distance_batch_impl_int8_avx2.cc | 10 ++++++---- ...r_product_distance_batch_impl_int8_avx512.cc | 9 ++++++--- 5 files changed, 34 insertions(+), 27 deletions(-) diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc index 59320de0..d6fe475f 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc @@ -28,8 +28,7 @@ compute_one_to_many_inner_product_avx2_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { - std::array<__m256, dp_batch> accs; - + __m256 accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm256_setzero_ps(); } @@ -42,8 +41,8 @@ compute_one_to_many_inner_product_avx2_fp16( __m256 q1 = _mm256_cvtph_ps(_mm256_castsi256_si128(q)); __m256 q2 = _mm256_cvtph_ps(_mm256_extractf128_si256(q, 1)); - std::array<__m256, dp_batch> data_regs_1; - std::array<__m256, dp_batch> data_regs_2; + __m256 data_regs_1[dp_batch]; + __m256 data_regs_2[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { __m256i m = _mm256_loadu_si256(reinterpret_cast(ptrs[i] + dim)); @@ -68,7 +67,7 @@ compute_one_to_many_inner_product_avx2_fp16( __m256 q = _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(query + dim))); - std::array<__m256, dp_batch> data_regs; + __m256 data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ptrs[i] + dim))); diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512.cc b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512.cc index 1fbe5b24..e06820e9 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512.cc +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512.cc @@ -27,8 +27,7 @@ compute_one_to_many_inner_product_avx512fp16_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { - std::array<__m512h, dp_batch> accs; - + __m512h accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm512_setzero_ph(); } @@ -37,7 +36,7 @@ compute_one_to_many_inner_product_avx512fp16_fp16( for (; dim + 32 <= dimensionality; dim += 32) { __m512h q = _mm512_loadu_ph(query + dim); - std::array<__m512h, dp_batch> data_regs; + __m512h data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm512_loadu_ph(ptrs[i] + dim); } @@ -83,8 +82,7 @@ compute_one_to_many_inner_product_avx512f_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { - std::array<__m512, dp_batch> accs; - + __m512 accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm512_setzero_ps(); } @@ -97,8 +95,8 @@ compute_one_to_many_inner_product_avx512f_fp16( __m512 q1 = _mm512_cvtph_ps(_mm512_castsi512_si256(q)); __m512 q2 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(q, 1)); - std::array<__m512, dp_batch> data_regs_1; - std::array<__m512, dp_batch> data_regs_2; + __m512 data_regs_1[dp_batch]; + __m512 data_regs_2[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { __m512i m = _mm512_loadu_si512(reinterpret_cast(ptrs[i] + dim)); @@ -123,7 +121,7 @@ compute_one_to_many_inner_product_avx512f_fp16( __m512 q = _mm512_cvtph_ps( _mm256_loadu_si256(reinterpret_cast(query + dim))); - std::array<__m512, dp_batch> data_regs; + __m512 data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm512_cvtph_ps( _mm256_loadu_si256(reinterpret_cast(ptrs[i] + dim))); @@ -133,7 +131,7 @@ compute_one_to_many_inner_product_avx512f_fp16( dim += 16; } - std::array<__m256, dp_batch> acc_new; + __m256 acc_new[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { acc_new[i] = _mm256_add_ps( _mm512_castps512_ps256(accs[i]), diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_fp32_avx2.cc b/src/ailego/math_batch/inner_product_distance_batch_impl_fp32_avx2.cc index 0e54064b..ffda66e9 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_fp32_avx2.cc +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_fp32_avx2.cc @@ -39,14 +39,15 @@ compute_one_to_many_inner_product_avx2_fp32( const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { - std::array<__m256, dp_batch> accs; + __m256 accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm256_setzero_ps(); } size_t dim = 0; for (; dim + 8 <= dimensionality; dim += 8) { __m256 q = _mm256_loadu_ps(query + dim); - std::array<__m256, dp_batch> data_regs; + + __m256 data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm256_loadu_ps(ptrs[i] + dim); } @@ -59,13 +60,15 @@ compute_one_to_many_inner_product_avx2_fp32( accs[i] = _mm256_fnmadd_ps(q, data_regs[i], accs[i]); } } - std::array<__m128, dp_batch> sum128_regs; + + __m128 sum128_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { sum128_regs[i] = sum_top_bottom_avx(accs[i]); } if (dim + 4 <= dimensionality) { __m128 q = _mm_loadu_ps(query + dim); - std::array<__m128, dp_batch> data_regs; + + __m128 data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm_loadu_ps(ptrs[i] + dim); } @@ -81,7 +84,8 @@ compute_one_to_many_inner_product_avx2_fp32( } if (dim + 2 <= dimensionality) { __m128 q = _mm_setzero_ps(); - std::array<__m128, dp_batch> data_regs; + + __m128 data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm_setzero_ps(); } @@ -95,7 +99,8 @@ compute_one_to_many_inner_product_avx2_fp32( } dim += 2; } - std::array res; + + float res[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { res[i] = sum4(sum128_regs[i]); } diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc b/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc index 23d3566a..66d7e154 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include #include #include #include @@ -27,14 +27,15 @@ compute_one_to_many_inner_product_avx2_int8( const int8_t *query, const int8_t **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { - std::vector<__m256i> accs(dp_batch); + __m256i accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm256_setzero_si256(); } size_t dim = 0; for (; dim + 32 <= dimensionality; dim += 32) { __m256i q = _mm256_loadu_si256((const __m256i *)(query + dim)); - std::vector<__m256i> data_regs(dp_batch); + + __m256i data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm256_loadu_si256((const __m256i *)(ptrs[i] + dim)); } @@ -63,7 +64,8 @@ compute_one_to_many_inner_product_avx2_int8( _mm256_add_epi32(accs[i], _mm256_add_epi32(prod_lo[i], prod_hi[i])); } } - std::array temp_results; + + int temp_results[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { __m128i lo = _mm256_castsi256_si128(accs[i]); __m128i hi = _mm256_extracti128_si256(accs[i], 1); diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx512.cc b/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx512.cc index 1e105832..2caf83c6 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx512.cc +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx512.cc @@ -52,7 +52,7 @@ static void compute_one_to_many_inner_product_avx512_vnni_int8( const int8_t *query, const int8_t **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { - std::array<__m512i, dp_batch> accs; + __m512i accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm512_setzero_si512(); } @@ -60,7 +60,8 @@ static void compute_one_to_many_inner_product_avx512_vnni_int8( for (; dim + 64 <= dimensionality; dim += 64) { __m512i q = _mm512_loadu_si512(reinterpret_cast(query + dim)); - std::array<__m512i, dp_batch> data_regs; + + __m512i data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm512_loadu_si512(reinterpret_cast(ptrs[i] + dim)); @@ -74,7 +75,8 @@ static void compute_one_to_many_inner_product_avx512_vnni_int8( accs[i] = _mm512_dpbusd_epi32(accs[i], q, data_regs[i]); } } - std::array temp_results{}; + + int temp_results[dp_batch]{}; for (size_t i = 0; i < dp_batch; ++i) { temp_results[i] = _mm512_reduce_add_epi32(accs[i]); } @@ -88,6 +90,7 @@ static void compute_one_to_many_inner_product_avx512_vnni_int8( results[i] = static_cast(temp_results[i]); } } + // // #elif defined(__AVX512BW__) // // TODO: this version is problematic From c0022d9795d968da0b390528290624d28c6ba089 Mon Sep 17 00:00:00 2001 From: luoxiaojian Date: Fri, 13 Mar 2026 17:26:22 +0800 Subject: [PATCH 21/34] feat: enable icelake and l2 batch distance for int8 quantization. (#213) --- cmake/option.cmake | 12 +++-- src/core/metric/quantized_integer_metric.cc | 6 ++- .../metric/quantized_integer_metric_batch.h | 47 ++++++++++++------- .../metric/quantized_integer_metric_matrix.h | 10 ++-- src/core/quantizer/cosine_converter.cc | 6 +-- src/core/quantizer/cosine_reformer.cc | 2 +- .../quantizer/integer_quantizer_converter.cc | 5 +- .../quantizer/integer_quantizer_reformer.cc | 2 +- src/core/quantizer/record_quantizer.h | 12 +++-- .../metric/quantized_integer_metric_test.cc | 8 ++-- 10 files changed, 71 insertions(+), 39 deletions(-) diff --git a/cmake/option.cmake b/cmake/option.cmake index 71e45784..3c042422 100644 --- a/cmake/option.cmake +++ b/cmake/option.cmake @@ -9,6 +9,7 @@ option(ENABLE_HASWELL "Enable Intel Haswell CPU microarchitecture" OFF) option(ENABLE_BROADWELL "Enable Intel Broadwell CPU microarchitecture" OFF) option(ENABLE_SKYLAKE "Enable Intel Skylake CPU microarchitecture" OFF) option(ENABLE_SKYLAKE_AVX512 "Enable Intel Skylake Server CPU microarchitecture" OFF) +option(ENABLE_ICELAKE "Enable Intel Icelake CPU microarchitecture" OFF) option(ENABLE_SAPPHIRERAPIDS "Enable Intel Sapphire Rapids Server CPU microarchitecture" OFF) option(ENABLE_EMERALDRAPIDS "Enable Intel Emerald Rapids Server CPU microarchitecture" OFF) option(ENABLE_GRANITERAPIDS "Enable Intel Granite Rapids Server CPU microarchitecture" OFF) @@ -34,8 +35,8 @@ option(ENABLE_OPENMP "Enable OpenMP support" OFF) set(ARCH_OPTIONS ENABLE_NEHALEM ENABLE_SANDYBRIDGE ENABLE_HASWELL ENABLE_BROADWELL ENABLE_SKYLAKE - ENABLE_SKYLAKE_AVX512 ENABLE_SAPPHIRERAPIDS ENABLE_EMERALDRAPIDS ENABLE_GRANITERAPIDS - ENABLE_ZEN1 ENABLE_ZEN2 ENABLE_ZEN3 + ENABLE_SKYLAKE_AVX512 ENABLE_ICELAKE ENABLE_SAPPHIRERAPIDS ENABLE_EMERALDRAPIDS + ENABLE_GRANITERAPIDS ENABLE_ZEN1 ENABLE_ZEN2 ENABLE_ZEN3 ENABLE_ARMV8A ENABLE_ARMV8.1A ENABLE_ARMV8.2A ENABLE_ARMV8.3A ENABLE_ARMV8.4A ENABLE_ARMV8.5A ENABLE_ARMV8.6A ENABLE_NATIVE @@ -111,7 +112,8 @@ function(setup_compiler_march_for_x86 VAR_NAME_SSE VAR_NAME_AVX2 VAR_NAME_AVX512 #avx512 set(_x86_flags - "graniterapids" "emeraldrapids" "sapphirerapids" "skylake-avx512" + "graniterapids" "emeraldrapids" "sapphirerapids" + "icelake-server" "skylake-avx512" ) foreach(_arch IN LISTS _x86_flags) check_c_compiler_flag("-march=${_arch}" _COMP_SUPP_${_arch}) @@ -170,6 +172,10 @@ if(NOT AUTO_DETECT_ARCH) add_arch_flag("-march=sapphirerapids" SAPPHIRERAPIDS ENABLE_SAPPHIRERAPIDS) endif() + if(ENABLE_ICELAKE) + add_arch_flag("-march=icelake-server" ICELAKE ENABLE_ICELAKE) + endif() + if(ENABLE_SKYLAKE_AVX512) add_arch_flag("-march=skylake-avx512" SKYLAKE_AVX512 ENABLE_SKYLAKE_AVX512) endif() diff --git a/src/core/metric/quantized_integer_metric.cc b/src/core/metric/quantized_integer_metric.cc index 56e95634..2b4e757a 100644 --- a/src/core/metric/quantized_integer_metric.cc +++ b/src/core/metric/quantized_integer_metric.cc @@ -266,6 +266,10 @@ class QuantizedIntegerMetric : public IndexMetric { meta_.data_type() == IndexMeta::DataType::DT_INT8) { return CosineMinusInnerProductDistanceBatchWithScoreUnquantized< int8_t, 1, 1>::GetQueryPreprocessFunc(); + } else if (origin_metric_type_ == MetricType::kSquaredEuclidean && + meta_.data_type() == IndexMeta::DataType::DT_INT8) { + return SquaredEuclideanDistanceBatchWithScoreUnquantized< + int8_t, 1, 1>::GetQueryPreprocessFunc(); } return nullptr; @@ -320,4 +324,4 @@ class QuantizedIntegerMetric : public IndexMetric { INDEX_FACTORY_REGISTER_METRIC_ALIAS(QuantizedInteger, QuantizedIntegerMetric); } // namespace core -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/core/metric/quantized_integer_metric_batch.h b/src/core/metric/quantized_integer_metric_batch.h index e9e63cef..33bbfa92 100644 --- a/src/core/metric/quantized_integer_metric_batch.h +++ b/src/core/metric/quantized_integer_metric_batch.h @@ -55,6 +55,11 @@ struct BaseDistanceBatchWithScoreUnquantized { return CosineMinusInnerProductDistanceBatchWithScoreUnquantized< ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim, out); + } else if constexpr (std::is_same_v, + SquaredEuclidean>) { + return SquaredEuclideanDistanceBatchWithScoreUnquantized< + ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim, + out); } _ComputeBatch(m, q, num, dim, out); @@ -75,7 +80,7 @@ struct CosineMinusInnerProductDistanceBatchWithScoreUnquantized< static inline void ComputeBatch(const int8_t **vecs, const int8_t *query, size_t num_vecs, size_t dim, float *results) { - size_t original_dim = dim - 20; + size_t original_dim = dim - 24; ImplType::ComputeBatch(vecs, query, num_vecs, original_dim, results); } @@ -87,7 +92,7 @@ struct CosineMinusInnerProductDistanceBatchWithScoreUnquantized< static void QueryPreprocess(void *query, size_t dim) { if (auto func = ImplType::GetQueryPreprocessFunc(); func != nullptr) { - return func(query, dim - 20); + return func(query, dim - 24); } } }; @@ -134,7 +139,7 @@ struct MinusInnerProductDistanceBatchWithScoreUnquantized(m_tail)[3]; + int int_sum = reinterpret_cast(m_tail)[4]; result -= 128 * int_sum; } result = -(ma * qa * result + mb * qa * qs + qb * ma * ms + @@ -192,7 +197,7 @@ struct SquaredEuclideanDistanceBatchWithScoreUnquantized; static void ComputeBatch(const int8_t **vecs, const int8_t *query, size_t num_vecs, size_t dim, float *results) { - const size_t original_dim = dim - 16; + const size_t original_dim = dim - 20; ailego::DistanceBatch::InnerProductDistanceBatch< int8_t, BatchSize, PrefetchStep>::ComputeBatch(vecs, query, num_vecs, original_dim, results); @@ -206,17 +211,21 @@ struct SquaredEuclideanDistanceBatchWithScoreUnquantized( reinterpret_cast(vecs[i]) + original_dim); float ma = m_tail[0]; float mb = m_tail[1]; float ms = m_tail[2]; float ms2 = m_tail[3]; - *results = ma * ma * ms2 + sum2 - 2 * ma * qa * *results + - (mb - qb) * (mb - qb) * original_dim + - 2 * (mb - qb) * (ms * ma - sum); - ++results; + float &result = results[i]; + if (ImplType::GetQueryPreprocessFunc() != nullptr) { + int int8_sum = reinterpret_cast(m_tail)[4]; + result -= 128 * int8_sum; + } + result = ma * ma * ms2 + sum2 - 2 * ma * qa * result + + (mb - qb) * (mb - qb) * original_dim + + 2 * (mb - qb) * (ms * ma - sum); } } @@ -226,7 +235,9 @@ struct SquaredEuclideanDistanceBatchWithScoreUnquantized struct SquaredEuclideanDistanceBatchWithScoreUnquantized { - static void ComputeBatch(const int8_t **vecs, const int8_t *query, + static void ComputeBatch(const uint8_t **vecs, const uint8_t *query, size_t num_vecs, size_t dim, float *results) { const size_t original_dim = dim - 32; const size_t original_dim_in_uint8_array = original_dim >> 1; @@ -251,7 +262,7 @@ struct SquaredEuclideanDistanceBatchWithScoreUnquantized( reinterpret_cast(vecs[i]) + original_dim_in_uint8_array); @@ -281,7 +292,7 @@ struct MipsSquaredEuclideanDistanceBatchWithScoreUnquantized; static void ComputeBatch(const int8_t **vecs, const int8_t *query, size_t num_vecs, size_t dim, float *results) { - const size_t original_dim = dim - 16; + const size_t original_dim = dim - 20; ailego::DistanceBatch::InnerProductDistanceBatch< int8_t, BatchSize, PrefetchStep>::ComputeBatch(vecs, query, num_vecs, original_dim, results); @@ -295,7 +306,7 @@ struct MipsSquaredEuclideanDistanceBatchWithScoreUnquantized( reinterpret_cast(vecs[i]) + original_dim); float ma = m_tail[0]; @@ -310,7 +321,9 @@ struct MipsSquaredEuclideanDistanceBatchWithScoreUnquantized( reinterpret_cast(vecs[i]) + original_dim_in_uint8_array); @@ -351,4 +364,4 @@ struct MipsSquaredEuclideanDistanceBatchWithScoreUnquantized struct SquaredEuclidean { static void Compute(const int8_t *m, const int8_t *q, size_t dim, float *out) { - const size_t d = dim - 16; + const size_t d = dim - 20; ailego::InnerProductMatrix::Compute(m, q, d, out); for (size_t i = 0; i < N; ++i) { @@ -141,7 +141,7 @@ template struct MinusInnerProduct { static void Compute(const int8_t *m, const int8_t *q, size_t dim, float *out) { - const size_t origin_dim = dim - 16; + const size_t origin_dim = dim - 20; MinusInnerProductImplInt8(m, q, origin_dim, out); } }; @@ -168,7 +168,7 @@ template struct CosineMinusInnerProduct { static void Compute(const int8_t *m, const int8_t *q, size_t dim, float *out) { - const size_t origin_dim = dim - 20; + const size_t origin_dim = dim - 24; MinusInnerProductImplInt8(m, q, origin_dim, out); } }; @@ -195,7 +195,7 @@ template struct MipsSquaredEuclidean { static void Compute(const int8_t *m, const int8_t *q, size_t dim, float *out) { - const size_t d = dim - 16; + const size_t d = dim - 20; ailego::InnerProductMatrix::Compute(m, q, d, out); for (size_t i = 0; i < N; ++i) { @@ -251,4 +251,4 @@ struct MipsSquaredEuclidean { } }; -} // namespace zvec::core \ No newline at end of file +} // namespace zvec::core diff --git a/src/core/quantizer/cosine_converter.cc b/src/core/quantizer/cosine_converter.cc index dda76b01..dd5cbbd0 100644 --- a/src/core/quantizer/cosine_converter.cc +++ b/src/core/quantizer/cosine_converter.cc @@ -206,7 +206,7 @@ class CosineConverterHolder : public IndexHolder { if (type == IndexMeta::DataType::DT_INT4) return 40; // 5 * sizeof(float) / sizeof(FT_INT4) else if (type == IndexMeta::DataType::DT_INT8) - return 20; // 5 * sizeof(float) / sizeof(FT_INT8) + return 24; // (5 * sizeof(float) + sizeof(int)) / sizeof(FT_INT8) else if (type == IndexMeta::DataType::DT_FP16) return 2; // 2* sizeof(float) / sizeof(FT_FP16) else if (type == IndexMeta::DataType::DT_FP32) { @@ -362,7 +362,7 @@ class CosineConverter : public IndexConverter { if (type == IndexMeta::DataType::DT_INT4) return 40; // 5 * sizeof(float) / sizeof(FT_INT4) else if (type == IndexMeta::DataType::DT_INT8) - return 20; // 5 * sizeof(float) / sizeof(FT_INT8) + return 24; // (5 * sizeof(float) + sizeof(int)) / sizeof(FT_INT8) else if (type == IndexMeta::DataType::DT_FP16) return 2; // sizeof(float) / sizeof(FT_FP16) else if (type == IndexMeta::DataType::DT_FP32) { @@ -402,4 +402,4 @@ INDEX_FACTORY_REGISTER_CONVERTER_ALIAS(CosineHalfFloatConverter, IndexMeta::DataType::DT_FP16); } // namespace core -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/core/quantizer/cosine_reformer.cc b/src/core/quantizer/cosine_reformer.cc index 5823728d..d6080b8d 100644 --- a/src/core/quantizer/cosine_reformer.cc +++ b/src/core/quantizer/cosine_reformer.cc @@ -249,7 +249,7 @@ class CosineReformer : public IndexReformer { if (type == IndexMeta::DataType::DT_INT4) return 40; // 5 * sizeof(float) / sizeof(FT_INT4) else if (type == IndexMeta::DataType::DT_INT8) - return 20; // 5 * sizeof(float) / sizeof(FT_INT8) + return 24; // (5 * sizeof(float) + sizeof(int)) / sizeof(FT_INT8) else if (type == IndexMeta::DataType::DT_FP16) return 2; // sizeof(float) / sizeof(FT_FP16) else if (type == IndexMeta::DataType::DT_FP32) { diff --git a/src/core/quantizer/integer_quantizer_converter.cc b/src/core/quantizer/integer_quantizer_converter.cc index 91757a5d..1cd88843 100644 --- a/src/core/quantizer/integer_quantizer_converter.cc +++ b/src/core/quantizer/integer_quantizer_converter.cc @@ -581,7 +581,10 @@ class IntegerStreamingConverter : public IndexConverter { static size_t ExtraDimension(IndexMeta::DataType type) { // The extra quantized params storage size to save for each vector constexpr size_t kExtraSize = 4 * sizeof(float); - return type == IndexMeta::DataType::DT_INT8 ? kExtraSize : kExtraSize * 2; + constexpr size_t kAdditionalInt32 = sizeof(int32_t); + return type == IndexMeta::DataType::DT_INT8 + ? (kExtraSize + kAdditionalInt32) + : (kExtraSize * 2); } //! Members diff --git a/src/core/quantizer/integer_quantizer_reformer.cc b/src/core/quantizer/integer_quantizer_reformer.cc index 9c741036..4228d0fd 100644 --- a/src/core/quantizer/integer_quantizer_reformer.cc +++ b/src/core/quantizer/integer_quantizer_reformer.cc @@ -279,7 +279,7 @@ class IntegerStreamingReformer : public IndexReformer { //! Constructor IntegerStreamingReformer(IndexMeta::DataType dst_type) : data_type_(dst_type), - extra_dimension_(data_type_ == IndexMeta::DataType::DT_INT8 ? 16 : 32) { + extra_dimension_(data_type_ == IndexMeta::DataType::DT_INT8 ? 20 : 32) { } //! Initialize Reformer diff --git a/src/core/quantizer/record_quantizer.h b/src/core/quantizer/record_quantizer.h index 06744f69..b1095a2a 100644 --- a/src/core/quantizer/record_quantizer.h +++ b/src/core/quantizer/record_quantizer.h @@ -74,10 +74,16 @@ class RecordQuantizer { extras[0] = 1.0f / scale; extras[1] = -bias / scale; extras[2] = sum; - if (is_euclidean) { + + if (type == IndexMeta::DataType::DT_INT8) { extras[3] = squared_sum; + reinterpret_cast(extras + 4)[0] = int8_sum; } else { - reinterpret_cast(extras)[3] = int8_sum; + if (is_euclidean) { + extras[3] = squared_sum; + } else { + reinterpret_cast(extras)[3] = int8_sum; + } } } } @@ -128,4 +134,4 @@ class RecordQuantizer { }; } // namespace core -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/tests/core/metric/quantized_integer_metric_test.cc b/tests/core/metric/quantized_integer_metric_test.cc index 30e8c256..501d8c7b 100644 --- a/tests/core/metric/quantized_integer_metric_test.cc +++ b/tests/core/metric/quantized_integer_metric_test.cc @@ -251,7 +251,7 @@ void TestDistanceMatrixInt8(const std::string &metric_name) { const size_t batch_size = M; const size_t query_size = N; - size_t dimension = (std::uniform_int_distribution(1, 65))(gen) * 4; + size_t dimension = (std::uniform_int_distribution(1, 65))(gen)*4; auto holder = GetHolder(dimension, batch_size, dist); IndexMeta meta(IndexMeta::DT_FP32, dimension); meta.set_metric(metric_name, 0, Params()); @@ -261,7 +261,7 @@ void TestDistanceMatrixInt8(const std::string &metric_name) { ASSERT_EQ(0u, IndexConverter::TrainAndTransform(converter, holder)); auto holder2 = converter->result(); auto &meta2 = converter->meta(); - ASSERT_EQ(dimension + 16, holder2->dimension()); + ASSERT_EQ(dimension + 20, holder2->dimension()); size_t matrix_size = batch_size * holder2->dimension(); std::vector matrix1(matrix_size); std::vector matrix2(matrix_size); @@ -277,7 +277,7 @@ void TestDistanceMatrixInt8(const std::string &metric_name) { auto query_holder = GetHolder(dimension, query_size, dist); ASSERT_EQ(0u, IndexConverter::TrainAndTransform(converter, query_holder)); auto query_holder2 = converter->result(); - ASSERT_EQ(dimension + 16, query_holder2->dimension()); + ASSERT_EQ(dimension + 20, query_holder2->dimension()); size_t query_matrix_size = query_size * query_holder2->dimension(); std::vector query1(query_matrix_size); std::vector query2(query_matrix_size); @@ -453,7 +453,7 @@ void TestDistanceMatrixInt4(const std::string &metric_name) { const size_t batch_size = M; const size_t query_size = N; - size_t dimension = (std::uniform_int_distribution(1, 65))(gen) * 8; + size_t dimension = (std::uniform_int_distribution(1, 65))(gen)*8; auto holder = GetHolder(dimension, batch_size, dist); IndexMeta meta(IndexMeta::DT_FP32, dimension); meta.set_metric(metric_name, 0, Params()); From 0eb486d0e616242eda0a80a9f6506e782d64b556 Mon Sep 17 00:00:00 2001 From: Qinren Zhou Date: Mon, 16 Mar 2026 10:55:51 +0800 Subject: [PATCH 22/34] fix: use per-block filter instead of per-segment filter during query (#223) * fix: id mismatch * fix: remove debug string * fix: use protected * chore(comment): add comments * feat: more ut * Update vector_recall_test.cc * fix: resolve comments --- .../combined_vector_column_indexer.cc | 13 +++- .../combined_vector_column_indexer.h | 27 +++++++- src/db/index/segment/segment.cc | 8 +-- tests/db/sqlengine/recall_base.h | 1 + tests/db/sqlengine/vector_recall_test.cc | 69 ++++++++++++++++++- 5 files changed, 110 insertions(+), 8 deletions(-) diff --git a/src/db/index/column/vector_column/combined_vector_column_indexer.cc b/src/db/index/column/vector_column/combined_vector_column_indexer.cc index 70c71d07..14fd2193 100644 --- a/src/db/index/column/vector_column/combined_vector_column_indexer.cc +++ b/src/db/index/column/vector_column/combined_vector_column_indexer.cc @@ -104,11 +104,22 @@ Result CombinedVectorColumnIndexer::Search( need_refine = true; } + const IndexFilter *filter{nullptr}; + auto per_block_filter = + BlockOffsetFilter{query_params.filter, block_offsets_[i]}; + if (query_params.filter) { + if (block_offsets_[i] > 0) { + filter = &per_block_filter; + } else { + filter = query_params.filter; + } + } + vector_column_params::QueryParams modified_query_params{ query_params.data_type, query_params.dimension, query_params.topk, - query_params.filter, + filter, query_params.fetch_vector, query_params.query_params, query_params.group_by diff --git a/src/db/index/column/vector_column/combined_vector_column_indexer.h b/src/db/index/column/vector_column/combined_vector_column_indexer.h index b0b0589f..92357918 100644 --- a/src/db/index/column/vector_column/combined_vector_column_indexer.h +++ b/src/db/index/column/vector_column/combined_vector_column_indexer.h @@ -15,9 +15,9 @@ #include #include +#include "db/index/common/index_filter.h" #include "vector_column_indexer.h" #include "vector_column_params.h" -#include "vector_index_results.h" namespace zvec { @@ -42,8 +42,31 @@ class CombinedVectorColumnIndexer { virtual Result Fetch( uint32_t segment_doc_id) const; - // for ut + protected: + /** + * A filter wrapper that applies an offset to document IDs before + * delegating to an inner filter. + * + * This is used when multiple blocks with different ID offsets are stored. + * Each block has its own local ID space, and this filter translates + * block-level IDs to segment-level IDs before checking the inner filter. + */ + class BlockOffsetFilter : public IndexFilter { + public: + BlockOffsetFilter(const IndexFilter *inner_filter, uint64_t offset) + : inner_filter_(inner_filter), offset_(offset) {} + + bool is_filtered(uint64_t id) const override { + return inner_filter_->is_filtered(id + offset_); + } + + private: + const IndexFilter *inner_filter_; + uint64_t offset_; + }; + + // for ut CombinedVectorColumnIndexer() = default; diff --git a/src/db/index/segment/segment.cc b/src/db/index/segment/segment.cc index 2d03cd78..43928f39 100644 --- a/src/db/index/segment/segment.cc +++ b/src/db/index/segment/segment.cc @@ -302,14 +302,14 @@ class SegmentImpl : public Segment, void fresh_persist_chunked_array(); private: - // scalar forward + // scalar forward (uses segment-local doc ID) MemForwardStore::Ptr memory_store_; std::vector persist_stores_; - // scalar index + // scalar index (uses segment-local doc ID) InvertedIndexer::Ptr invert_indexers_; - // vector index + // vector index (uses block-local doc ID, each indexer starts from 0) std::unordered_map memory_vector_indexers_; @@ -339,7 +339,7 @@ class SegmentImpl : public Segment, IDMap::Ptr id_map_; DeleteStore::Ptr delete_store_; - // local_id(index) -> global_doc_id(value) + // Maps segment-local doc ID (array index) to global doc ID (stored value) std::vector doc_ids_; std::array, diff --git a/tests/db/sqlengine/recall_base.h b/tests/db/sqlengine/recall_base.h index 8c2a88ab..3e457147 100644 --- a/tests/db/sqlengine/recall_base.h +++ b/tests/db/sqlengine/recall_base.h @@ -253,6 +253,7 @@ inline Segment::Ptr RecallTest::create_segment() { SegmentOptions options; options.read_only_ = false; options.enable_mmap_ = true; + options.max_buffer_size_ = 256 * 1024; auto result = Segment::CreateAndOpen(GetPath(), *collection_schema_, 0, 0, id_map, diff --git a/tests/db/sqlengine/vector_recall_test.cc b/tests/db/sqlengine/vector_recall_test.cc index 597b63a6..d3dbccd2 100644 --- a/tests/db/sqlengine/vector_recall_test.cc +++ b/tests/db/sqlengine/vector_recall_test.cc @@ -212,4 +212,71 @@ TEST_F(VectorRecallTest, Sparse) { } } -} // namespace zvec::sqlengine \ No newline at end of file +TEST_F(VectorRecallTest, DeleteFilter) { + // This test uses only one segment and thus we only operate on the first one + for (int i = 0; i < 4000; i++) { + segments_[0]->Delete("pk_" + std::to_string(i)); + } + + VectorQuery query; + query.output_fields_ = {"name", "age"}; + query.topk_ = 100; + std::vector feature(4, 0.0); + query.query_vector_.assign((const char *)feature.data(), + feature.size() * sizeof(float)); + query.field_name_ = "dense"; + + auto engine = SQLEngine::create(std::make_shared()); + auto ret = engine->execute(collection_schema_, query, segments_); + if (!ret) { + LOG_ERROR("execute failed: [%s]", ret.error().c_str()); + } + ASSERT_TRUE(ret.has_value()); + auto docs = ret.value(); + EXPECT_EQ(docs.size(), 100); + for (size_t j = 0; j < docs.size(); j++) { + auto &doc = docs[j]; + int doc_id = j + 4000; + EXPECT_EQ(doc->pk(), "pk_" + std::to_string(doc_id)); + auto age = doc->get("age"); + EXPECT_EQ(age.value(), doc_id % 100); + auto name = doc->get("name"); + ASSERT_TRUE(name); + EXPECT_EQ(name.value(), "user_" + std::to_string(doc_id % 100)); + EXPECT_FLOAT_EQ(doc->score(), (float)doc_id * doc_id * 4); + } +} + +TEST_F(VectorRecallTest, HybridInvertForwardDeleteFilter) { + // In previous test, docs[0-4000) has been deleted + VectorQuery query; + query.output_fields_ = {"name", "age"}; + query.filter_ = "invert_id >= 6000 and id < 6080"; + query.topk_ = 100; + std::vector feature(4, 0.0); + query.query_vector_.assign((const char *)feature.data(), + feature.size() * sizeof(float)); + query.field_name_ = "dense"; + + auto engine = SQLEngine::create(std::make_shared()); + auto ret = engine->execute(collection_schema_, query, segments_); + if (!ret) { + LOG_ERROR("execute failed: [%s]", ret.error().c_str()); + } + ASSERT_TRUE(ret.has_value()); + auto docs = ret.value(); + EXPECT_EQ(docs.size(), 80); + for (size_t j = 0; j < docs.size(); j++) { + auto &doc = docs[j]; + int doc_id = j + 6000; + EXPECT_EQ(doc->pk(), "pk_" + std::to_string(doc_id)); + auto age = doc->get("age"); + EXPECT_EQ(age.value(), doc_id % 100); + auto name = doc->get("name"); + ASSERT_TRUE(name); + EXPECT_EQ(name.value(), "user_" + std::to_string(doc_id % 100)); + EXPECT_FLOAT_EQ(doc->score(), (float)doc_id * doc_id * 4); + } +} + +} // namespace zvec::sqlengine From 5cdb3d5aa59bf471b52da3b84c69c1e188c5382f Mon Sep 17 00:00:00 2001 From: Abdur-Rahmaan Janhangeer Date: Mon, 16 Mar 2026 07:31:44 +0400 Subject: [PATCH 23/34] fix: minor typo (#225) --- src/db/collection.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/db/collection.cc b/src/db/collection.cc index 6a2a70c2..1ff0eecb 100644 --- a/src/db/collection.cc +++ b/src/db/collection.cc @@ -1722,7 +1722,7 @@ Status CollectionImpl::create() { } if (ailego::FileHelper::IsExist(path_.c_str())) { return Status::InvalidArgument("path validate failed: path[", path_, - "] is existed"); + "] exists"); } // check schema @@ -1876,4 +1876,4 @@ std::vector CollectionImpl::get_all_persist_segments() const { return segment_manager_->get_segments(); } -} // namespace zvec \ No newline at end of file +} // namespace zvec From 4f2cce008a7e9d6ad6d6ef8a6f83832a50e99b1a Mon Sep 17 00:00:00 2001 From: rayx Date: Mon, 16 Mar 2026 13:41:58 +0800 Subject: [PATCH 24/34] fix/fix mips euclidean (#226) * fix: fix mips euclidean --- .../math/inner_product_matrix_fp32_avx.cc | 4 + .../math/inner_product_matrix_fp32_avx512.cc | 4 + .../inner_product_matrix_fp32_dispatch.cc | 9 +- .../math/inner_product_matrix_fp32_neon.cc | 5 + .../math/inner_product_matrix_fp32_sse.cc | 5 + ...mips_euclidean_distance_matrix_fp16_avx.cc | 37 +++++- ...s_euclidean_distance_matrix_fp16_avx512.cc | 37 +++++- ...euclidean_distance_matrix_fp16_dispatch.cc | 76 +++++------ ...ips_euclidean_distance_matrix_fp16_neon.cc | 37 +++++- ...mips_euclidean_distance_matrix_fp32_avx.cc | 49 ++++++- ...s_euclidean_distance_matrix_fp32_avx512.cc | 58 ++++++++- ...euclidean_distance_matrix_fp32_dispatch.cc | 121 +++++++++++------- ...ips_euclidean_distance_matrix_fp32_neon.cc | 30 +---- ...mips_euclidean_distance_matrix_fp32_sse.cc | 36 +++++- ...ips_euclidean_distance_matrix_int4_avx2.cc | 40 +++++- ...euclidean_distance_matrix_int4_dispatch.cc | 62 ++++----- ...mips_euclidean_distance_matrix_int4_sse.cc | 37 +++++- ...ips_euclidean_distance_matrix_int8_avx2.cc | 36 +++++- ...euclidean_distance_matrix_int8_dispatch.cc | 60 ++++----- ...mips_euclidean_distance_matrix_int8_sse.cc | 37 +++++- 20 files changed, 574 insertions(+), 206 deletions(-) diff --git a/src/ailego/math/inner_product_matrix_fp32_avx.cc b/src/ailego/math/inner_product_matrix_fp32_avx.cc index 128adfdf..23c1f13f 100644 --- a/src/ailego/math/inner_product_matrix_fp32_avx.cc +++ b/src/ailego/math/inner_product_matrix_fp32_avx.cc @@ -88,6 +88,10 @@ float InnerProductAVX(const float *lhs, const float *rhs, size_t size) { return result; } +float MinusInnerProductAVX(const float *lhs, const float *rhs, size_t size) { + return -1 * InnerProductAVX(lhs, rhs, size); +} + #endif // __AVX__ } // namespace ailego diff --git a/src/ailego/math/inner_product_matrix_fp32_avx512.cc b/src/ailego/math/inner_product_matrix_fp32_avx512.cc index af3bf74c..c888115b 100644 --- a/src/ailego/math/inner_product_matrix_fp32_avx512.cc +++ b/src/ailego/math/inner_product_matrix_fp32_avx512.cc @@ -69,6 +69,10 @@ float InnerProductAVX512(const float *lhs, const float *rhs, size_t size) { return HorizontalAdd_FP32_V512(zmm_sum_0); } +float MinusInnerProductAVX512(const float *lhs, const float *rhs, size_t size) { + return -1 * InnerProductAVX512(lhs, rhs, size); +} + #endif } // namespace ailego diff --git a/src/ailego/math/inner_product_matrix_fp32_dispatch.cc b/src/ailego/math/inner_product_matrix_fp32_dispatch.cc index 57acef21..175dbf96 100644 --- a/src/ailego/math/inner_product_matrix_fp32_dispatch.cc +++ b/src/ailego/math/inner_product_matrix_fp32_dispatch.cc @@ -25,6 +25,7 @@ float MinusInnerProductNEON(const float *lhs, const float *rhs, size_t size); #if defined(__AVX512F__) float InnerProductAVX512(const float *lhs, const float *rhs, size_t size); +float MinusInnerProductAVX512(const float *lhs, const float *rhs, size_t size); #endif #if defined(__AVX__) @@ -70,12 +71,12 @@ void MinusInnerProductMatrix::Compute(const ValueType *m, const ValueType *q, size_t dim, float *out) { #if defined(__ARM_NEON) - *out = -InnerProductNEON(m, q, dim); + *out = MinusInnerProductNEON(m, q, dim); #else #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { if (dim > 15) { - *out = -InnerProductAVX512(m, q, dim); + *out = MinusInnerProductAVX512(m, q, dim); return; } } @@ -83,12 +84,12 @@ void MinusInnerProductMatrix::Compute(const ValueType *m, #if defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { if (dim > 7) { - *out = -InnerProductAVX(m, q, dim); + *out = MinusInnerProductAVX(m, q, dim); return; } } #endif // __AVX__ - *out = -InnerProductSSE(m, q, dim); + *out = MinusInnerProductSSE(m, q, dim); #endif // __ARM_NEON } diff --git a/src/ailego/math/inner_product_matrix_fp32_neon.cc b/src/ailego/math/inner_product_matrix_fp32_neon.cc index e8626a3b..011f908f 100644 --- a/src/ailego/math/inner_product_matrix_fp32_neon.cc +++ b/src/ailego/math/inner_product_matrix_fp32_neon.cc @@ -51,6 +51,11 @@ float InnerProductNEON(const float *lhs, const float *rhs, size_t size) { } return result; } + +float MinusInnerProductNEON(const float *lhs, const float *rhs, size_t size) { + return -1 * InnerProductNEON(lhs, rhs, size); +} + #endif // __ARM_NEON } // namespace ailego diff --git a/src/ailego/math/inner_product_matrix_fp32_sse.cc b/src/ailego/math/inner_product_matrix_fp32_sse.cc index 8a302bf9..f90801ee 100644 --- a/src/ailego/math/inner_product_matrix_fp32_sse.cc +++ b/src/ailego/math/inner_product_matrix_fp32_sse.cc @@ -74,6 +74,11 @@ float InnerProductSSE(const float *lhs, const float *rhs, size_t size) { return result; } + +float MinusInnerProductSSE(const float *lhs, const float *rhs, size_t size) { + return -1 * InnerProductSSE(lhs, rhs, size); +} + #endif // __SSE__ // #if 1 diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc index c93edc1c..bc066efc 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc @@ -110,7 +110,42 @@ float InnerProductAndSquaredNormAVX(const Float16 *lhs, const Float16 *rhs, *sqr = norm2; return result; } + +float MipsEucldeanDistanceSphericalInjectionAVX(const Float16 *lhs, + const Float16 *rhs, size_t size, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormAVX(lhs, rhs, size, &u2, &v2); + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX(const Float16 *lhs, + const Float16 *rhs, + size_t size, size_t m, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormAVX(lhs, rhs, size, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} + #endif // __AVX__ && __F16C__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc index 51ce4fc4..fb87aa6a 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc @@ -128,7 +128,42 @@ float InnerProductAndSquaredNormAVX512(const Float16 *lhs, const Float16 *rhs, *sqr = norm2; return result; } + +float MipsEucldeanDistanceSphericalInjectionAVX512(const Float16 *lhs, + const Float16 *rhs, + size_t size, float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormAVX512(lhs, rhs, size, &u2, &v2); + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX512(const Float16 *lhs, + const Float16 *rhs, + size_t size, + size_t m, float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormAVX512(lhs, rhs, size, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} + #endif // __AVX512F__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc index b99ab45e..be997fb7 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc @@ -19,18 +19,33 @@ namespace zvec { namespace ailego { #if defined(__ARM_NEON) -float InnerProductAndSquaredNormNEON(const Float16 *lhs, const Float16 *rhs, - size_t size, float *sql, float *sqr); +float MipsEucldeanDistanceRepeatedQuadraticInjectionNEON(const Float16 *lhs, + const Float16 *rhs, + size_t size, size_t m, + float e2); +float MipsEucldeanDistanceSphericalInjectionNEON(const Float16 *lhs, + const Float16 *rhs, + size_t size, float e2); #endif #if defined(__AVX512F__) -float InnerProductAndSquaredNormAVX512(const Float16 *lhs, const Float16 *rhs, - size_t size, float *sql, float *sqr); +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX512(const Float16 *lhs, + const Float16 *rhs, + size_t size, + size_t m, float e2); +float MipsEucldeanDistanceSphericalInjectionAVX512(const Float16 *lhs, + const Float16 *rhs, + size_t size, float e2); #endif #if defined(__AVX__) -float InnerProductAndSquaredNormAVX(const Float16 *lhs, const Float16 *rhs, - size_t size, float *sql, float *sqr); +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX(const Float16 *lhs, + const Float16 *rhs, + size_t size, size_t m, + float e2); +float MipsEucldeanDistanceSphericalInjectionAVX(const Float16 *lhs, + const Float16 *rhs, size_t size, + float e2); #endif #if (defined(__F16C__) && defined(__AVX__)) || \ @@ -38,59 +53,38 @@ float InnerProductAndSquaredNormAVX(const Float16 *lhs, const Float16 *rhs, //! Compute the distance between matrix and query by SphericalInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2{0.0f}; - float v2{0.0f}; - float sum{0.0f}; - #if defined(__ARM_NEON) - sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); + *out = MipsEucldeanDistanceSphericalInjectionNEON(p, q, dim, e2); #else #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); - } else -#endif //__AVX512F__ - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); - } + *out = MipsEucldeanDistanceSphericalInjectionAVX512(p, q, dim, e2); + return; + } +#endif + *out = MipsEucldeanDistanceSphericalInjectionAVX(p, q, dim, e2); #endif //__ARM_NEON - - *out = ComputeSphericalInjection(sum, u2, v2, e2); } //! Compute the distance between matrix and query by RepeatedQuadraticInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, float *out) { - float u2{0.0f}; - float v2{0.0f}; - float sum{0.0f}; - #if defined(__ARM_NEON) - sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); + *out = MipsEucldeanDistanceRepeatedQuadraticInjectionNEON(p, q, dim, m, e2); #else #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); - } else -#endif //__AVX512F__ - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); - } -#endif //__ARM_NEON - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; + *out = + MipsEucldeanDistanceRepeatedQuadraticInjectionAVX512(p, q, dim, m, e2); + return; } - *out = sum; +#endif + *out = MipsEucldeanDistanceRepeatedQuadraticInjectionAVX(p, q, dim, m, e2); +#endif //__ARM_NEON } #endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc index 22493b4e..8a1dd0e1 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc @@ -119,8 +119,43 @@ float InnerProductAndSquaredNormNEON(const Float16 *lhs, const Float16 *rhs, *sqr = norm2; return result; } + #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +float MipsEucldeanDistanceSphericalInjectionNEON(const Float16 *lhs, + const Float16 *rhs, + size_t size, float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormNEON(lhs, rhs, size, &u2, &v2); + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionNEON(const Float16 *lhs, + const Float16 *rhs, + size_t size, size_t m, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormNEON(lhs, rhs, size, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} #endif // __ARM_NEON && __aarch64__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc index cff60e8f..ac958e86 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc @@ -19,6 +19,11 @@ namespace zvec { namespace ailego { +#if defined(__SSE__) +float InnerProductAndSquaredNormSSE(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr); +#endif + #if defined(__AVX__) //! Compute the Inner Product between p and q, and each Squared L2-Norm value float InnerProductAndSquaredNormAVX(const float *lhs, const float *rhs, @@ -108,7 +113,49 @@ float InnerProductAndSquaredNormAVX(const float *lhs, const float *rhs, *sqr = norm2; return result; } + +float MipsEucldeanDistanceSphericalInjectionAVX(const float *lhs, + const float *rhs, size_t size, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + if (size > 7) { + sum = InnerProductAndSquaredNormAVX(lhs, rhs, size, &u2, &v2); + } else { + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size, &u2, &v2); + } + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX(const float *lhs, + const float *rhs, + size_t size, size_t m, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + if (size > 7) { + sum = InnerProductAndSquaredNormAVX(lhs, rhs, size, &u2, &v2); + } else { + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size, &u2, &v2); + } + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} #endif // __AVX__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc index 1ac56a20..d48080e7 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc @@ -19,6 +19,16 @@ namespace zvec { namespace ailego { +#if defined(__SSE__) +float InnerProductAndSquaredNormSSE(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr); +#endif + +#if defined(__AVX__) +float InnerProductAndSquaredNormAVX(const float *lhs, const float *rhs, + size_t size, float *sql, float *sqr); +#endif + #if defined(__AVX512F__) //! Compute the Inner Product between p and q, and each Squared L2-Norm value float InnerProductAndSquaredNormAVX512(const float *lhs, const float *rhs, @@ -94,7 +104,53 @@ float InnerProductAndSquaredNormAVX512(const float *lhs, const float *rhs, *sqr = HorizontalAdd_FP32_V512(zmm_sum_norm2); return HorizontalAdd_FP32_V512(zmm_sum_0); } + +float MipsEucldeanDistanceSphericalInjectionAVX512(const float *lhs, + const float *rhs, + size_t size, float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + if (size > 15) { + sum = InnerProductAndSquaredNormAVX512(lhs, rhs, size, &u2, &v2); + } else if (size > 7) { + sum = InnerProductAndSquaredNormAVX(lhs, rhs, size, &u2, &v2); + } else { + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size, &u2, &v2); + } + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX512(const float *lhs, + const float *rhs, + size_t size, + size_t m, float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + if (size > 15) { + sum = InnerProductAndSquaredNormAVX512(lhs, rhs, size, &u2, &v2); + } else if (size > 7) { + sum = InnerProductAndSquaredNormAVX(lhs, rhs, size, &u2, &v2); + } else { + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size, &u2, &v2); + } + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} #endif // __AVX512F__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc index 992da0d1..10cfec9b 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc @@ -24,18 +24,33 @@ float InnerProductAndSquaredNormNEON(const float *lhs, const float *rhs, #endif #if defined(__AVX512F__) -float InnerProductAndSquaredNormAVX512(const float *lhs, const float *rhs, - size_t size, float *sql, float *sqr); +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX512(const float *lhs, + const float *rhs, + size_t size, + size_t m, float e2); +float MipsEucldeanDistanceSphericalInjectionAVX512(const float *lhs, + const float *rhs, + size_t size, float e2); #endif #if defined(__AVX__) -float InnerProductAndSquaredNormAVX(const float *lhs, const float *rhs, - size_t size, float *sql, float *sqr); +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX(const float *lhs, + const float *rhs, + size_t size, size_t m, + float e2); +float MipsEucldeanDistanceSphericalInjectionAVX(const float *lhs, + const float *rhs, size_t size, + float e2); #endif #if defined(__SSE__) -float InnerProductAndSquaredNormSSE(const float *lhs, const float *rhs, - size_t size, float *sql, float *sqr); +float MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(const float *lhs, + const float *rhs, + size_t size, size_t m, + float e2); +float MipsEucldeanDistanceSphericalInjectionSSE(const float *lhs, + const float *rhs, size_t size, + float e2); #endif #if defined(__SSE4_1__) @@ -58,58 +73,39 @@ float MipsInnerProductSparseInSegment(uint32_t m_sparse_count, //! Compute the distance between matrix and query by SphericalInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2{0.0f}; - float v2{0.0f}; - float sum{0.0f}; - #if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F && dim > 15) { - sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); - } else -#endif // __AVX512F__ + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + *out = MipsEucldeanDistanceSphericalInjectionAVX512(p, q, dim, e2); + return; + } +#endif //__AVX512F__ #if defined(__AVX__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX && dim > 7) { - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); - } else -#endif // __AVX__ - { - sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + *out = MipsEucldeanDistanceSphericalInjectionAVX(p, q, dim, e2); + return; } - - *out = ComputeSphericalInjection(sum, u2, v2, e2); +#endif // __AVX__ + *out = MipsEucldeanDistanceSphericalInjectionSSE(p, q, dim, e2); } //! Compute the distance between matrix and query by RepeatedQuadraticInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, float *out) { - float u2{0.0f}; - float v2{0.0f}; - float sum{0.0f}; - #if defined(__AVX512F__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F && dim > 15) { - sum = InnerProductAndSquaredNormAVX512(p, q, dim, &u2, &v2); - } else -#endif // __AVX512F__ -#if defined(__AVX__) - if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX && dim > 7) { - sum = InnerProductAndSquaredNormAVX(p, q, dim, &u2, &v2); - } else -#endif // __AVX__ - { - sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + *out = + MipsEucldeanDistanceRepeatedQuadraticInjectionAVX512(p, q, dim, m, e2); + return; } - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; +#endif //__AVX512F__ +#if defined(__AVX__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { + *out = MipsEucldeanDistanceRepeatedQuadraticInjectionAVX(p, q, dim, m, e2); + return; } - *out = sum; +#endif // __AVX__ + *out = MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(p, q, dim, m, e2); } #endif // __SSE__ @@ -132,5 +128,36 @@ float MipsSquaredEuclideanSparseDistanceMatrix:: #endif } +#if defined(__ARM_NEON) +//! Compute the distance between matrix and query by SphericalInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); + + *out = ComputeSphericalInjection(sum, u2, v2, e2); +} + +//! Compute the distance between matrix and query by RepeatedQuadraticInjection +void MipsSquaredEuclideanDistanceMatrix::Compute( + const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, + float *out) { + float u2{0.0f}; + float v2{0.0f}; + float sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + *out = sum; +} +#endif //__ARM_NEON + } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc index 8e98922c..ca536c32 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc @@ -71,35 +71,7 @@ float InnerProductAndSquaredNormNEON(const float *lhs, const float *rhs, return result; } -//! Compute the distance between matrix and query by SphericalInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2; - float v2; - float sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); - - *out = ComputeSphericalInjection(sum, u2, v2, e2); -} - -//! Compute the distance between matrix and query by RepeatedQuadraticInjection -void MipsSquaredEuclideanDistanceMatrix::Compute( - const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, - float *out) { - float u2; - float v2; - float sum = InnerProductAndSquaredNormNEON(p, q, dim, &u2, &v2); - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; - } - *out = sum; -} #endif //__ARM_NEON } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc b/src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc index 43d8f9b7..357703db 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc @@ -96,6 +96,40 @@ float InnerProductAndSquaredNormSSE(const float *lhs, const float *rhs, return result; } +float MipsEucldeanDistanceSphericalInjectionSSE(const float *lhs, + const float *rhs, size_t size, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size, &u2, &v2); + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(const float *lhs, + const float *rhs, + size_t size, size_t m, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} + #endif // __SSE__ // #if 1 @@ -333,4 +367,4 @@ float MipsInnerProductSparseInSegment(uint32_t m_sparse_count, #endif // __SSE4_1__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc b/src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc index 95a3f007..378fd757 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc @@ -23,8 +23,8 @@ namespace ailego { #if defined(__AVX2__) //! Compute the Inner Product between p and q, and each Squared L2-Norm value -float InnerProductAndSquaredNormAVX(const uint8_t *lhs, const uint8_t *rhs, - size_t size, float *sql, float *sqr) { +float InnerProductAndSquaredNormAVX2(const uint8_t *lhs, const uint8_t *rhs, + size_t size, float *sql, float *sqr) { const uint8_t *last = lhs + size; const uint8_t *last_aligned = lhs + ((size >> 5) << 5); __m256i ymm_sum_0 = _mm256_setzero_si256(); @@ -134,7 +134,41 @@ float InnerProductAndSquaredNormAVX(const uint8_t *lhs, const uint8_t *rhs, *sqr = norm2; return result; } + +float MipsEucldeanDistanceSphericalInjectionAVX2(const uint8_t *lhs, + const uint8_t *rhs, + size_t size, float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormAVX2(lhs, rhs, size >> 1, &u2, &v2); + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX2(const uint8_t *lhs, + const uint8_t *rhs, + size_t size, size_t m, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormAVX2(lhs, rhs, size >> 1, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} #endif // __AVX2__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc b/src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc index c967f832..238eb468 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc @@ -20,64 +20,52 @@ namespace zvec { namespace ailego { -#if defined(__AVX__) -float InnerProductAndSquaredNormAVX(const uint8_t *lhs, const uint8_t *rhs, - size_t size, float *sql, float *sqr); +#if defined(__AVX2__) +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX2(const uint8_t *lhs, + const uint8_t *rhs, + size_t size, size_t m, + float e2); +float MipsEucldeanDistanceSphericalInjectionAVX2(const uint8_t *lhs, + const uint8_t *rhs, + size_t size, float e2); #endif -#if defined(__SSE__) -float InnerProductAndSquaredNormSSE(const uint8_t *lhs, const uint8_t *rhs, - size_t size, float *sql, float *sqr); +#if defined(__SSE4_1__) +float MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(const uint8_t *lhs, + const uint8_t *rhs, + size_t size, size_t m, + float e2); +float MipsEucldeanDistanceSphericalInjectionSSE(const uint8_t *lhs, + const uint8_t *rhs, size_t size, + float e2); #endif #if defined(__SSE4_1__) //! Compute the distance between matrix and query by SphericalInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2{0.0f}; - float v2{0.0f}; - float sum{0.0f}; - #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - sum = InnerProductAndSquaredNormAVX(p, q, dim >> 1, &u2, &v2); - } else -#endif - { - sum = InnerProductAndSquaredNormSSE(p, q, dim >> 1, &u2, &v2); + *out = MipsEucldeanDistanceSphericalInjectionAVX2(p, q, dim, e2); + return; } - - *out = ComputeSphericalInjection(sum, u2, v2, e2); +#endif + *out = MipsEucldeanDistanceSphericalInjectionSSE(p, q, dim, e2); } //! Compute the distance between matrix and query by RepeatedQuadraticInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, float *out) { - float u2{0.0f}; - float v2{0.0f}; - float sum{0.0f}; - #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - sum = InnerProductAndSquaredNormAVX(p, q, dim >> 1, &u2, &v2); - } else -#endif - { - sum = InnerProductAndSquaredNormSSE(p, q, dim >> 1, &u2, &v2); + *out = MipsEucldeanDistanceRepeatedQuadraticInjectionAVX2(p, q, dim, m, e2); + return; } - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; - } - *out = sum; +#endif + *out = MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(p, q, dim, m, e2); } #endif } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc b/src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc index 139b14c9..0537d347 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc @@ -98,7 +98,42 @@ float InnerProductAndSquaredNormSSE(const uint8_t *lhs, const uint8_t *rhs, *sqr = norm2; return result; } + +float MipsEucldeanDistanceSphericalInjectionSSE(const uint8_t *lhs, + const uint8_t *rhs, size_t size, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size >> 1, &u2, &v2); + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(const uint8_t *lhs, + const uint8_t *rhs, + size_t size, size_t m, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size >> 1, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} + #endif // __SSE4_1__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc b/src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc index 0b969537..65a7cc8a 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc @@ -153,7 +153,41 @@ float InnerProductAndSquaredNormAVX2(const int8_t *lhs, const int8_t *rhs, *sqr = norm2; return result; } + +float MipsEucldeanDistanceSphericalInjectionAVX2(const int8_t *lhs, + const int8_t *rhs, size_t size, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormAVX2(lhs, rhs, size, &u2, &v2); + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX2(const int8_t *lhs, + const int8_t *rhs, + size_t size, size_t m, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormAVX2(lhs, rhs, size, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} #endif // __AVX2__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc b/src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc index 35b4a8c8..5512c6c5 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc @@ -19,63 +19,51 @@ namespace zvec { namespace ailego { #if defined(__AVX2__) -float InnerProductAndSquaredNormAVX2(const int8_t *lhs, const int8_t *rhs, - size_t size, float *sql, float *sqr); +float MipsEucldeanDistanceRepeatedQuadraticInjectionAVX2(const int8_t *lhs, + const int8_t *rhs, + size_t size, size_t m, + float e2); +float MipsEucldeanDistanceSphericalInjectionAVX2(const int8_t *lhs, + const int8_t *rhs, size_t size, + float e2); #endif -#if defined(__SSE__) -float InnerProductAndSquaredNormSSE(const int8_t *lhs, const int8_t *rhs, - size_t size, float *sql, float *sqr); +#if defined(__SSE4_1__) +float MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(const int8_t *lhs, + const int8_t *rhs, + size_t size, size_t m, + float e2); +float MipsEucldeanDistanceSphericalInjectionSSE(const int8_t *lhs, + const int8_t *rhs, size_t size, + float e2); #endif #if defined(__SSE4_1__) //! Compute the distance between matrix and query by SphericalInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { - float u2{0.0f}; - float v2{0.0f}; - float sum{0.0f}; - #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - sum = InnerProductAndSquaredNormAVX2(p, q, dim, &u2, &v2); - } else -#endif - { - sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); + *out = MipsEucldeanDistanceSphericalInjectionAVX2(p, q, dim, e2); + return; } - - *out = ComputeSphericalInjection(sum, u2, v2, e2); +#endif + *out = MipsEucldeanDistanceSphericalInjectionSSE(p, q, dim, e2); } //! Compute the distance between matrix and query by RepeatedQuadraticInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, float *out) { - float u2{0.0f}; - float v2{0.0f}; - float sum{0.0f}; - #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - sum = InnerProductAndSquaredNormAVX2(p, q, dim, &u2, &v2); - } else -#endif - { - sum = InnerProductAndSquaredNormSSE(p, q, dim, &u2, &v2); + *out = MipsEucldeanDistanceRepeatedQuadraticInjectionAVX2(p, q, dim, m, e2); + return; } - - sum = e2 * (u2 + v2 - 2 * sum); - u2 *= e2; - v2 *= e2; - for (size_t i = 0; i < m; ++i) { - sum += (u2 - v2) * (u2 - v2); - u2 = u2 * u2; - v2 = v2 * v2; - } - *out = sum; +#endif + *out = MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(p, q, dim, m, e2); } #endif // __SSE4_1__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc b/src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc index a0d6192c..8a92f52c 100644 --- a/src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc +++ b/src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc @@ -131,7 +131,42 @@ float InnerProductAndSquaredNormSSE(const int8_t *lhs, const int8_t *rhs, *sqr = norm2; return result; } + +float MipsEucldeanDistanceSphericalInjectionSSE(const int8_t *lhs, + const int8_t *rhs, size_t size, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size, &u2, &v2); + + return ComputeSphericalInjection(sum, u2, v2, e2); +} + +float MipsEucldeanDistanceRepeatedQuadraticInjectionSSE(const int8_t *lhs, + const int8_t *rhs, + size_t size, size_t m, + float e2) { + float u2{0.0f}; + float v2{0.0f}; + float sum{0.0f}; + + sum = InnerProductAndSquaredNormSSE(lhs, rhs, size, &u2, &v2); + + sum = e2 * (u2 + v2 - 2 * sum); + u2 *= e2; + v2 *= e2; + for (size_t i = 0; i < m; ++i) { + sum += (u2 - v2) * (u2 - v2); + u2 = u2 * u2; + v2 = v2 * v2; + } + + return sum; +} + #endif // __SSE4_1__ } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec From a027f0c100d98758020673ebbb50869a0e2a98b9 Mon Sep 17 00:00:00 2001 From: Cuiys Date: Mon, 16 Mar 2026 14:20:01 +0800 Subject: [PATCH 25/34] feat: buildwheel in ghrunner (#221) --- .github/workflows/_build_wheel_job.yml | 102 ++++++++++++++++++++++ .github/workflows/build_test_wheel.yml | 115 +++++------------------- .github/workflows/build_wheel.yml | 116 +++++-------------------- pyproject.toml | 12 ++- 4 files changed, 153 insertions(+), 192 deletions(-) create mode 100644 .github/workflows/_build_wheel_job.yml diff --git a/.github/workflows/_build_wheel_job.yml b/.github/workflows/_build_wheel_job.yml new file mode 100644 index 00000000..34b46b71 --- /dev/null +++ b/.github/workflows/_build_wheel_job.yml @@ -0,0 +1,102 @@ +name: "(Reusable) Build, Publish and Smoke-test a Wheel" + +on: + workflow_call: + inputs: + runner: + description: "GitHub Actions runner label" + required: true + type: string + pypi_repository_url: + description: "PyPI repository URL (empty string means official PyPI)" + required: false + type: string + default: "" + secrets: + PYPI_API_TOKEN: + required: true + +jobs: + build_publish_test: + name: Build / publish / smoke-test on ${{ inputs.runner }} + runs-on: ${{ inputs.runner }} + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + submodules: recursive + + - name: Set up Python (for cibuildwheel controller) + uses: actions/setup-python@v6 + with: + python-version: '3.11' + + - name: Install cibuildwheel + run: | + pip install --upgrade pip + pip install cibuildwheel==3.4.0 + + - name: Build wheels using cibuildwheel + run: | + python -m cibuildwheel --output-dir wheelhouse + # Save list of built wheels for publishing + ls wheelhouse/*.whl | tee $GITHUB_STEP_SUMMARY + echo "wheels=$(ls wheelhouse/*.whl | tr '\n' ' ')" >> $GITHUB_ENV + + - name: Publish to PyPI + if: success() && github.event_name == 'workflow_dispatch' + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + TWINE_REPOSITORY_URL: ${{ inputs.pypi_repository_url }} + run: | + pip install twine + twine upload --skip-existing --verbose wheelhouse/*.whl + + - name: Smoke test from PyPI + if: success() && github.event_name == 'workflow_dispatch' + shell: bash + env: + PYPI_REPOSITORY_URL: ${{ inputs.pypi_repository_url }} + run: | + # Extract version from wheel filename (e.g. zvec-0.2.1.dev24-cp311-...whl -> 0.2.1.dev24) + WHEEL_FILE=$(ls wheelhouse/zvec-*.whl | head -1) + ZVEC_VERSION=$(basename "$WHEEL_FILE" | sed 's/zvec-\([^-]*\)-.*/\1/') + + # Build index-url flags: use TestPyPI when repository URL is set, otherwise official PyPI + if [ -n "$PYPI_REPOSITORY_URL" ]; then + INDEX_FLAGS="--index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/" + echo "Waiting for zvec==$ZVEC_VERSION to become available on TestPyPI..." + else + INDEX_FLAGS="" + echo "Waiting for zvec==$ZVEC_VERSION to become available on PyPI..." + fi + # Poll until the version is available (max 5 minutes) + FOUND=0 + for i in $(seq 1 30); do + if pip install $INDEX_FLAGS --dry-run "zvec==$ZVEC_VERSION" > /dev/null 2>&1; then + echo "Version $ZVEC_VERSION is available." + FOUND=1 + break + fi + echo "Attempt $i/30: not yet available, retrying in 10s..." + sleep 10 + done + + if [ "$FOUND" -eq 0 ]; then + echo "ERROR: Timed out (5 min) waiting for zvec==$ZVEC_VERSION on PyPI. Aborting smoke test." + exit 1 + fi + + # Create a clean venv and install + python -m venv test_env + source test_env/bin/activate + pip install --upgrade pip + pip install $INDEX_FLAGS "zvec==$ZVEC_VERSION" + pip install --upgrade pip + pip install $INDEX_FLAGS "zvec==$ZVEC_VERSION" + # Run a simple smoke test + python -c "import zvec; print('Import OK:', zvec.__version__)" diff --git a/.github/workflows/build_test_wheel.yml b/.github/workflows/build_test_wheel.yml index 8636d5e2..60f431a3 100644 --- a/.github/workflows/build_test_wheel.yml +++ b/.github/workflows/build_test_wheel.yml @@ -8,97 +8,28 @@ permissions: jobs: build_wheels_linux_x64: - name: Build wheels on self-hosted manylinux_2_28_x64 for TestPyPi - runs-on: linux_x64 - - steps: - - name: Checkout code - uses: actions/checkout@v6 - with: - submodules: recursive - - - name: Set up Python (for cibuildwheel controller) - uses: actions/setup-python@v6 - with: - python-version: '3.11' - - - name: Install cibuildwheel - run: | - pip install --upgrade pip - pip install cibuildwheel==2.17.0 - - name: Build wheels using cibuildwheel - run: | - python -m cibuildwheel --output-dir wheelhouse - # Save list of built wheels for publishing - ls wheelhouse/*.whl | tee $GITHUB_STEP_SUMMARY - echo "wheels=$(ls wheelhouse/*.whl | tr '\n' ' ')" >> $GITHUB_ENV - - name: Publish to TestPyPI - if: success() && github.event_name == 'workflow_dispatch' - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} - TWINE_REPOSITORY_URL: https://test.pypi.org/legacy/ - run: | - pip install twine - twine upload --skip-existing --verbose wheelhouse/*.whl - - name: (Optional) Install and test from TestPyPI - if: success() && github.event_name == 'workflow_dispatch' - run: | - # Create a clean venv - python -m venv test_env - source test_env/bin/activate - pip install --upgrade pip - # Install from TestPyPI (must allow pre-releases if version has dev/alpha) - pip install numpy - pip install --index-url https://test.pypi.org/simple/ zvec - # Run a simple smoke test - python -c "import zvec; print('Import OK:', zvec.__version__)" - shell: bash + name: Build wheels on ubuntu-24.04 (x64) for TestPyPi + uses: ./.github/workflows/_build_wheel_job.yml + with: + runner: ubuntu-24.04 + pypi_repository_url: https://test.pypi.org/legacy/ + secrets: + PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} build_wheels_linux_arm64: - name: Build wheels on self-hosted manylinux_2_28_arm64 for TestPyPi - runs-on: linux_arm64 - - steps: - - name: Checkout code - uses: actions/checkout@v6 - with: - submodules: recursive - - - name: Set up Python (for cibuildwheel controller) - uses: actions/setup-python@v6 - with: - python-version: '3.11' - - - name: Install cibuildwheel - run: | - pip install --upgrade pip - pip install cibuildwheel==2.17.0 - - name: Build wheels using cibuildwheel - run: | - python -m cibuildwheel --output-dir wheelhouse - # Save list of built wheels for publishing - ls wheelhouse/*.whl | tee $GITHUB_STEP_SUMMARY - echo "wheels=$(ls wheelhouse/*.whl | tr '\n' ' ')" >> $GITHUB_ENV - - name: Publish to TestPyPI - if: success() && github.event_name == 'workflow_dispatch' - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} - TWINE_REPOSITORY_URL: https://test.pypi.org/legacy/ - run: | - pip install twine - twine upload --skip-existing --verbose wheelhouse/*.whl - - name: (Optional) Install and test from TestPyPI - if: success() && github.event_name == 'workflow_dispatch' - run: | - # Create a clean venv - python -m venv test_env - source test_env/bin/activate - pip install --upgrade pip - # Install from TestPyPI (must allow pre-releases if version has dev/alpha) - pip install numpy - pip install --index-url https://test.pypi.org/simple/ zvec - # Run a simple smoke test - python -c "import zvec; print('Import OK:', zvec.__version__)" - shell: bash \ No newline at end of file + name: Build wheels on ubuntu-24.04-arm (arm64) for TestPyPi + uses: ./.github/workflows/_build_wheel_job.yml + with: + runner: ubuntu-24.04-arm + pypi_repository_url: https://test.pypi.org/legacy/ + secrets: + PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} + + build_wheels_macos_arm64: + name: Build wheels on macos-15 (arm64) for TestPyPi + uses: ./.github/workflows/_build_wheel_job.yml + with: + runner: macos-15 + pypi_repository_url: https://test.pypi.org/legacy/ + secrets: + PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index 21cf3c40..0b7f2a38 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -8,101 +8,25 @@ permissions: jobs: build_wheels_linux_x64: - name: Build wheels on self-hosted manylinux_2_28_x64 - runs-on: linux_x64 - - steps: - - name: Checkout code - uses: actions/checkout@v6 - with: - submodules: recursive - - - name: Set up Python (for cibuildwheel controller) - uses: actions/setup-python@v6 - with: - python-version: '3.11' - - - name: Install cibuildwheel - run: | - pip install --upgrade pip - pip install cibuildwheel==2.17.0 - - - name: Build wheels using cibuildwheel - run: | - python -m cibuildwheel --output-dir wheelhouse - # Save list of built wheels for publishing - ls wheelhouse/*.whl | tee $GITHUB_STEP_SUMMARY - echo "wheels=$(ls wheelhouse/*.whl | tr '\n' ' ')" >> $GITHUB_ENV - - - name: Publish to PyPI - if: success() && github.event_name == 'workflow_dispatch' - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - TWINE_REPOSITORY: pypi - run: | - pip install twine - twine upload --skip-existing --verbose wheelhouse/*.whl - - - name: (Optional) Install and test from PyPI - if: success() && github.event_name == 'workflow_dispatch' - run: | - # Create a clean venv - python -m venv test_env - source test_env/bin/activate - pip install --upgrade pip - # Install from PyPI - pip install zvec - # Run a simple smoke test - python -c "import zvec; print('Import OK:', zvec.__version__)" - shell: bash + name: Build wheels on ubuntu-24.04 (x64) for PyPi + uses: ./.github/workflows/_build_wheel_job.yml + with: + runner: ubuntu-24.04 + secrets: + PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} build_wheels_linux_arm64: - name: Build wheels on self-hosted manylinux_2_28_arm64 - runs-on: linux_arm64 - - steps: - - name: Checkout code - uses: actions/checkout@v6 - with: - submodules: recursive - - - name: Set up Python (for cibuildwheel controller) - uses: actions/setup-python@v6 - with: - python-version: '3.11' - - - name: Install cibuildwheel - run: | - pip install --upgrade pip - pip install cibuildwheel==2.17.0 - - - name: Build wheels using cibuildwheel - run: | - python -m cibuildwheel --output-dir wheelhouse - # Save list of built wheels for publishing - ls wheelhouse/*.whl | tee $GITHUB_STEP_SUMMARY - echo "wheels=$(ls wheelhouse/*.whl | tr '\n' ' ')" >> $GITHUB_ENV - - - name: Publish to PyPI - if: success() && github.event_name == 'workflow_dispatch' - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - TWINE_REPOSITORY: pypi - run: | - pip install twine - twine upload --skip-existing --verbose wheelhouse/*.whl - - - name: (Optional) Install and test from PyPI - if: success() && github.event_name == 'workflow_dispatch' - run: | - # Create a clean venv - python -m venv test_env - source test_env/bin/activate - pip install --upgrade pip - # Install from PyPI - pip install zvec - # Run a simple smoke test - python -c "import zvec; print('Import OK:', zvec.__version__)" - shell: bash \ No newline at end of file + name: Build wheels on ubuntu-24.04-arm (arm64) for PyPi + uses: ./.github/workflows/_build_wheel_job.yml + with: + runner: ubuntu-24.04-arm + secrets: + PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + + build_wheels_macos_arm64: + name: Build wheels on macos-15 (arm64) for PyPi + uses: ./.github/workflows/_build_wheel_job.yml + with: + runner: macos-15 + secrets: + PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} diff --git a/pyproject.toml b/pyproject.toml index 349a6c2b..486b0b36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ test = [ "pytest >=8.0", "pytest-cov >=4.1", "pytest-mock >=3.12", - "cibuildwheel == 2.17.0", + "cibuildwheel == 3.4.0", ] docs = [ "mkdocs >=1.5", @@ -70,7 +70,7 @@ dev = [ "pytest >=8.0", "pytest-cov >=4.1", "pytest-mock >=3.12", - "cibuildwheel == 2.17.0", + "cibuildwheel == 3.4.0", # Inherit docs deps "mkdocs >=1.5", "mkdocs-material >=9.5", @@ -130,6 +130,7 @@ BUILD_PYTHON_BINDINGS = "ON" [tool.setuptools_scm] local_scheme = "no-local-version" version_scheme = "guess-next-dev" +fallback_version = "0.2.1b1" ###################################################################################################### # TESTING & QUALITY ###################################################################################################### @@ -169,11 +170,12 @@ build = [ ] build-frontend = "build" test-requires = ["pytest", "numpy"] +test-command = "cd {project} && pytest python/tests -v --tb=short" build-verbosity = 1 [tool.cibuildwheel.linux] archs = ["auto"] -test-command = "cd {project} && pytest python/tests -v --tb=short" +environment = { CMAKE_GENERATOR = "Unix Makefiles", CMAKE_BUILD_PARALLEL_LEVEL = "16" } manylinux-x86_64-image = "manylinux_2_28" manylinux-aarch64-image = "manylinux_2_28" # Skip 32-bit builds and musllinux @@ -181,7 +183,9 @@ skip = ["*-manylinux_i686", "*-musllinux*"] [tool.cibuildwheel.macos] archs = ["arm64"] -environment = { MACOSX_DEPLOYMENT_TARGET = "11.0" } +# Inherits CMAKE_GENERATOR and CMAKE_BUILD_PARALLEL_LEVEL from [tool.cibuildwheel] won't work; +# platform-level environment overrides the top-level entirely, so all vars must be listed here +environment = { CMAKE_GENERATOR = "Unix Makefiles", CMAKE_BUILD_PARALLEL_LEVEL = "16", MACOSX_DEPLOYMENT_TARGET = "11.0" } ###################################################################################################### # CODE QUALITY & FORMATTING (Ruff) ###################################################################################################### From 734c45d6d3c810b9921950e5479d07e27eca72ce Mon Sep 17 00:00:00 2001 From: Qinren Zhou Date: Mon, 16 Mar 2026 14:23:42 +0800 Subject: [PATCH 26/34] minor: add deepwiki badges (#228) --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 226d4f15..6632f78b 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,10 @@

Code Coverage Main + License PyPI Release Python Versions - License + npm Release

@@ -22,6 +23,7 @@ 🏠 Home | 📚 Docs | 📊 Benchmarks | + 🔎 DeepWiki | 🎮 Discord

From e4af3cb3e740f3225e6677a0d2ebf36758cae120 Mon Sep 17 00:00:00 2001 From: Qinren Zhou Date: Mon, 16 Mar 2026 17:12:24 +0800 Subject: [PATCH 27/34] feat: enlarge indice size limit for sparse vectors (#229) --- .../flat_sparse/flat_sparse_streamer.cc | 12 ++++++++---- .../flat_sparse/flat_sparse_utility.h | 2 +- .../hnsw_sparse/hnsw_sparse_builder_entity.cc | 8 +++++--- .../algorithm/hnsw_sparse/hnsw_sparse_entity.h | 2 +- .../hnsw_sparse/hnsw_sparse_streamer.cc | 16 ++++++++++------ src/db/common/constants.h | 2 +- src/db/index/common/doc.cc | 18 ++++++++++++++++-- tests/db/index/common/doc_test.cc | 2 +- 8 files changed, 43 insertions(+), 19 deletions(-) diff --git a/src/core/algorithm/flat_sparse/flat_sparse_streamer.cc b/src/core/algorithm/flat_sparse/flat_sparse_streamer.cc index 4df83d8f..bc5a3da1 100644 --- a/src/core/algorithm/flat_sparse/flat_sparse_streamer.cc +++ b/src/core/algorithm/flat_sparse/flat_sparse_streamer.cc @@ -195,8 +195,10 @@ int FlatSparseStreamer::add_impl(uint64_t pkey, const uint32_t sparse_count, } if (ailego_unlikely(sparse_count > PARAM_FLAT_SPARSE_MAX_DIM_SIZE)) { - LOG_ERROR("Add vector failed, dim size too larg, dim_size=%u, key=%zu", - sparse_count, (size_t)pkey); + LOG_ERROR( + "Failed to add sparse vector: number of non-zero elements (%u) exceeds " + "maximum allowed (%u), key=%zu", + sparse_count, PARAM_FLAT_SPARSE_MAX_DIM_SIZE, (size_t)pkey); (*stats_.mutable_discarded_count())++; return IndexError_InvalidValue; } @@ -252,8 +254,10 @@ int FlatSparseStreamer::add_with_id_impl(uint32_t pkey, } if (ailego_unlikely(sparse_count > PARAM_FLAT_SPARSE_MAX_DIM_SIZE)) { - LOG_ERROR("Add vector failed, dim size too larg, dim_size=%u, key=%zu", - sparse_count, (size_t)pkey); + LOG_ERROR( + "Failed to add sparse vector: number of non-zero elements (%u) exceeds " + "maximum allowed (%u), key=%zu", + sparse_count, PARAM_FLAT_SPARSE_MAX_DIM_SIZE, (size_t)pkey); (*stats_.mutable_discarded_count())++; return IndexError_InvalidValue; } diff --git a/src/core/algorithm/flat_sparse/flat_sparse_utility.h b/src/core/algorithm/flat_sparse/flat_sparse_utility.h index 4566b5de..a66e2f63 100644 --- a/src/core/algorithm/flat_sparse/flat_sparse_utility.h +++ b/src/core/algorithm/flat_sparse/flat_sparse_utility.h @@ -19,7 +19,7 @@ namespace zvec { namespace core { -static constexpr uint32_t PARAM_FLAT_SPARSE_MAX_DIM_SIZE = 4096; +static constexpr uint32_t PARAM_FLAT_SPARSE_MAX_DIM_SIZE = 16384; static const std::string PARAM_FLAT_SPARSE_META_SEG_ID = "bruteforce_sparse_meta"; diff --git a/src/core/algorithm/hnsw_sparse/hnsw_sparse_builder_entity.cc b/src/core/algorithm/hnsw_sparse/hnsw_sparse_builder_entity.cc index 48c20d72..25c5c00c 100644 --- a/src/core/algorithm/hnsw_sparse/hnsw_sparse_builder_entity.cc +++ b/src/core/algorithm/hnsw_sparse/hnsw_sparse_builder_entity.cc @@ -88,9 +88,11 @@ int HnswSparseBuilderEntity::add_vector(level_t level, key_t key, const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_vec, node_id_t *id) { - if (ailego_unlikely(sparse_count >= HnswSparseEntity::kSparseMaxDimSize)) { - LOG_WARN("Add vector failed, dim size too larg, dim_size=%u, key=%zu", - sparse_count, (size_t)key); + if (ailego_unlikely(sparse_count > HnswSparseEntity::kSparseMaxDimSize)) { + LOG_WARN( + "Failed to add sparse vector: number of non-zero elements (%u) exceeds " + "maximum allowed (%u), key=%zu", + sparse_count, HnswSparseEntity::kSparseMaxDimSize, (size_t)key); return IndexError_InvalidValue; } diff --git a/src/core/algorithm/hnsw_sparse/hnsw_sparse_entity.h b/src/core/algorithm/hnsw_sparse/hnsw_sparse_entity.h index 37166027..d514e0c2 100644 --- a/src/core/algorithm/hnsw_sparse/hnsw_sparse_entity.h +++ b/src/core/algorithm/hnsw_sparse/hnsw_sparse_entity.h @@ -610,7 +610,7 @@ class HnswSparseEntity { constexpr static uint32_t kSparseMetaSize = 2u * sizeof(uint64_t); constexpr static float kDefaultSparseNeighborRatio = 0.5f; - constexpr static uint32_t kSparseMaxDimSize = 4096; + constexpr static uint32_t kSparseMaxDimSize = 16384; constexpr static float kDefaultQueryFilteringRatio = 0.0f; // turn off protected: diff --git a/src/core/algorithm/hnsw_sparse/hnsw_sparse_streamer.cc b/src/core/algorithm/hnsw_sparse/hnsw_sparse_streamer.cc index f51ebb5e..3abce808 100644 --- a/src/core/algorithm/hnsw_sparse/hnsw_sparse_streamer.cc +++ b/src/core/algorithm/hnsw_sparse/hnsw_sparse_streamer.cc @@ -438,9 +438,11 @@ int HnswSparseStreamer::add_with_id_impl(uint32_t id, return ret; } - if (ailego_unlikely(sparse_count >= HnswSparseEntity::kSparseMaxDimSize)) { - LOG_WARN("Add vector failed, dim size too larg, dim_size=%u, id=%u", - sparse_count, id); + if (ailego_unlikely(sparse_count > HnswSparseEntity::kSparseMaxDimSize)) { + LOG_WARN( + "Failed to add sparse vector: number of non-zero elements (%u) exceeds " + "maximum allowed (%u), id=%u", + sparse_count, HnswSparseEntity::kSparseMaxDimSize, id); return IndexError_InvalidValue; } @@ -523,9 +525,11 @@ int HnswSparseStreamer::add_impl(uint64_t pkey, const uint32_t sparse_count, return ret; } - if (ailego_unlikely(sparse_count >= HnswSparseEntity::kSparseMaxDimSize)) { - LOG_WARN("Add vector failed, dim size too larg, dim_size=%u, key=%zu", - sparse_count, (size_t)pkey); + if (ailego_unlikely(sparse_count > HnswSparseEntity::kSparseMaxDimSize)) { + LOG_WARN( + "Failed to add sparse vector: number of non-zero elements (%u) exceeds " + "maximum allowed (%u), key=%zu", + sparse_count, HnswSparseEntity::kSparseMaxDimSize, (size_t)pkey); return IndexError_InvalidValue; } diff --git a/src/db/common/constants.h b/src/db/common/constants.h index 39b13f44..d07512f3 100644 --- a/src/db/common/constants.h +++ b/src/db/common/constants.h @@ -32,7 +32,7 @@ const std::string GLOBAL_DOC_ID = "_zvec_g_doc_id_"; const std::string USER_ID = "_zvec_uid_"; -const int kSparseMaxDimSize = 4096; +const int kSparseMaxDimSize = 16384; const int64_t kMaxRecordBatchNumRows = 4096; diff --git a/src/db/index/common/doc.cc b/src/db/index/common/doc.cc index 6d411bfb..2f9f12b0 100644 --- a/src/db/index/common/doc.cc +++ b/src/db/index/common/doc.cc @@ -866,6 +866,12 @@ Status Doc::validate(const CollectionSchema::Ptr &schema, "doc validate failed: field[", field_name, "]'s sparse vector indices and values size not match"); } + if (sparse_indices.size() > kSparseMaxDimSize) { + return Status::InvalidArgument( + "doc validate failed: vector[", field_name, + "], the number of sparse indices exceeds the maximum limit ", + kSparseMaxDimSize); + } } break; } @@ -881,6 +887,12 @@ Status Doc::validate(const CollectionSchema::Ptr &schema, "doc validate failed: field[", field_name, "]'s sparse vector indices and values size not match"); } + if (sparse_indices.size() > kSparseMaxDimSize) { + return Status::InvalidArgument( + "doc validate failed: vector[", field_name, + "], the number of sparse indices exceeds the maximum limit ", + kSparseMaxDimSize); + } } break; } @@ -1251,9 +1263,11 @@ Status VectorQuery::validate(const FieldSchema *schema) const { } } else if (schema->is_sparse_vector()) { // validate sparse indices size - if (query_sparse_indices_.size() >= kSparseMaxDimSize * sizeof(uint32_t)) { + if (query_sparse_indices_.size() > kSparseMaxDimSize * sizeof(uint32_t)) { return Status::InvalidArgument( - "query validate failed: sparse indices size is too large"); + "query validate failed: the number of sparse indices exceeds the " + "maximum limit ", + kSparseMaxDimSize); } } else { return Status::InvalidArgument( diff --git a/tests/db/index/common/doc_test.cc b/tests/db/index/common/doc_test.cc index 5047ecb5..9098478b 100644 --- a/tests/db/index/common/doc_test.cc +++ b/tests/db/index/common/doc_test.cc @@ -1255,7 +1255,7 @@ TEST(VectorQuery, Validate) { VectorQuery query; query.field_name_ = "field_name"; query.topk_ = 100; - std::vector query_indices = std::vector(4097); + std::vector query_indices = std::vector(16385); std::string query_indices_str = std::string(reinterpret_cast(query_indices.data()), query_indices.size() * sizeof(uint32_t)); From 3c051051bb5e87e43844b0bd1c0908092fba73dd Mon Sep 17 00:00:00 2001 From: Cuiys Date: Tue, 17 Mar 2026 13:40:56 +0800 Subject: [PATCH 28/34] Update wechat qrcode in README.md (#231) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6632f78b..a153f48d 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ Stay updated and get support — scan or click: | 💬 DingTalk | 📱 WeChat | 🎮 Discord | |:---:|:---:|:---:| -| | | [![Discord](https://img.shields.io/badge/Discord-Join%20Server-5865F2?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/rKddFBBu9z) | +| | | [![Discord](https://img.shields.io/badge/Discord-Join%20Server-5865F2?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/rKddFBBu9z) | | Scan to join | Scan to join | Click to join | From 2ca3d7d6c870f3b86e32cbb86c18f2b49c34cf41 Mon Sep 17 00:00:00 2001 From: rayx Date: Tue, 17 Mar 2026 19:59:27 +0800 Subject: [PATCH 29/34] fix: fix ut for sparse builder dump time (#237) * fix ut for sparse builder dump time --- tests/core/algorithm/flat_sparse/flat_sparse_builder_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/algorithm/flat_sparse/flat_sparse_builder_test.cc b/tests/core/algorithm/flat_sparse/flat_sparse_builder_test.cc index af770255..c89d086b 100644 --- a/tests/core/algorithm/flat_sparse/flat_sparse_builder_test.cc +++ b/tests/core/algorithm/flat_sparse/flat_sparse_builder_test.cc @@ -257,7 +257,7 @@ TEST_F(FlatSparseBuilderTest, TestHalfFloatConverter) { ASSERT_EQ(0UL, stats.discarded_count()); ASSERT_EQ(0UL, stats.trained_costtime()); ASSERT_EQ(stats.built_costtime(), 0UL); - ASSERT_GT(stats.dumped_costtime(), 0UL); + //ASSERT_GT(stats.dumped_costtime(), 0UL); // cleanup and rebuild ASSERT_EQ(0, builder->cleanup()); @@ -298,4 +298,4 @@ TEST_F(FlatSparseBuilderTest, TestHalfFloatConverter) { #if defined(__GNUC__) || defined(__GNUG__) #pragma GCC diagnostic pop -#endif \ No newline at end of file +#endif From 9c9fafe9dbf3e776e787ae014468648e83d910f9 Mon Sep 17 00:00:00 2001 From: egolearner Date: Wed, 18 Mar 2026 14:42:00 +0800 Subject: [PATCH 30/34] feat: local_builder simplify disable idmap (#239) --- tools/core/local_builder.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/core/local_builder.cc b/tools/core/local_builder.cc index 9d502a1e..9a19a079 100644 --- a/tools/core/local_builder.cc +++ b/tools/core/local_builder.cc @@ -932,14 +932,17 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { cout << "Prepare data done!" << endl; ailego::Params params; - if (g_disable_id_map) { - params.set(PARAM_HNSW_STREAMER_USE_ID_MAP, false); - params.set(PARAM_FLAT_USE_ID_MAP, false); - } if (!prepare_params(config_root["BuilderParams"], params)) { LOG_ERROR("Failed to prepare params"); return -1; } + std::vector id_map_param_list = { + PARAM_HNSW_STREAMER_USE_ID_MAP, + PARAM_FLAT_USE_ID_MAP, + }; + for (auto ¶m : id_map_param_list) { + params.set(param, !g_disable_id_map); + } // INIT int ret = From f52c3156ab544aa303e9c3b53157904582dfa5a2 Mon Sep 17 00:00:00 2001 From: feihongxu0824 Date: Wed, 18 Mar 2026 16:16:49 +0800 Subject: [PATCH 31/34] chore: android ci needs lint first (#240) --- .github/workflows/01-ci-pipeline.yml | 55 +++++++++++ .github/workflows/02-lint-check.yml | 52 +++++++++++ .../{main.yml => 03-macos-linux-build.yml} | 91 ++++--------------- ...android_build.yml => 04-android-build.yml} | 17 +--- 4 files changed, 125 insertions(+), 90 deletions(-) create mode 100644 .github/workflows/01-ci-pipeline.yml create mode 100644 .github/workflows/02-lint-check.yml rename .github/workflows/{main.yml => 03-macos-linux-build.yml} (51%) rename .github/workflows/{android_build.yml => 04-android-build.yml} (94%) diff --git a/.github/workflows/01-ci-pipeline.yml b/.github/workflows/01-ci-pipeline.yml new file mode 100644 index 00000000..290b9c8a --- /dev/null +++ b/.github/workflows/01-ci-pipeline.yml @@ -0,0 +1,55 @@ +name: Main + +on: + push: + branches: [ "main" ] + paths-ignore: + - '**.md' + merge_group: + pull_request: + branches: [ "main" ] + paths-ignore: + - '**.md' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + # Code quality checks (fast, run first) + lint: + uses: ./.github/workflows/02-lint-check.yml + + # Main build and test matrix + build-and-test-macos-arm64: + name: Build & Test (macos-arm64) + needs: lint + uses: ./.github/workflows/03-macos-linux-build.yml + with: + platform: macos-arm64 + os: macos-15 + + build-and-test-linux-arm64: + name: Build & Test (linux-arm64) + needs: lint + uses: ./.github/workflows/03-macos-linux-build.yml + with: + platform: linux-arm64 + os: ubuntu-24.04-arm + + build-and-test-linux-x64: + name: Build & Test (linux-x64) + needs: lint + uses: ./.github/workflows/03-macos-linux-build.yml + with: + platform: linux-x64 + os: ubuntu-24.04 + + build-android: + name: Build & Test (android) + needs: lint + uses: ./.github/workflows/04-android-build.yml diff --git a/.github/workflows/02-lint-check.yml b/.github/workflows/02-lint-check.yml new file mode 100644 index 00000000..4f9076c6 --- /dev/null +++ b/.github/workflows/02-lint-check.yml @@ -0,0 +1,52 @@ +name: Lint + +on: + workflow_call: + +jobs: + lint: + name: Code Quality Checks + runs-on: ubuntu-24.04 + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.10' + cache: 'pip' + cache-dependency-path: 'pyproject.toml' + + - name: Install linting tools + run: | + python -m pip install --upgrade pip \ + ruff==v0.14.4 \ + clang-format==18.1.8 + shell: bash + + - name: Run Ruff Linter + run: python -m ruff check . + shell: bash + + - name: Run Ruff Formatter Check + run: python -m ruff format --check . + shell: bash + + - name: Run clang-format Check + run: | + CPP_FILES=$(find . -type f \( -name "*.cpp" -o -name "*.h" -o -name "*.hpp" -o -name "*.cc" -o -name "*.cxx" \) \ + ! -path "./build/*" \ + ! -path "./tests/*" \ + ! -path "./scripts/*" \ + ! -path "./python/*" \ + ! -path "./thirdparty/*" \ + ! -path "./.git/*") + + if [ -z "$CPP_FILES" ]; then + echo "No C++ files found to check." + exit 0 + fi + + clang-format --dry-run --Werror $CPP_FILES + shell: bash diff --git a/.github/workflows/main.yml b/.github/workflows/03-macos-linux-build.yml similarity index 51% rename from .github/workflows/main.yml rename to .github/workflows/03-macos-linux-build.yml index abfd4d73..9ec52cc7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/03-macos-linux-build.yml @@ -1,92 +1,33 @@ -name: Main +name: MacOS & Linux Build on: - push: - branches: [ "main" ] - paths-ignore: - - '**.md' - merge_group: - pull_request: - branches: [ "main" ] - paths-ignore: - - '**.md' - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} - cancel-in-progress: true + workflow_call: + inputs: + platform: + description: 'Platform identifier' + required: true + type: string + os: + description: 'GitHub Actions runner OS' + required: true + type: string permissions: contents: read jobs: - # Code quality checks (fast, run first) - lint: - name: Code Quality Checks - runs-on: ubuntu-24.04 - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: '3.10' - cache: 'pip' - cache-dependency-path: 'pyproject.toml' - - - name: Install linting tools - run: | - python -m pip install --upgrade pip \ - ruff==v0.14.4 \ - clang-format==18.1.8 - shell: bash - - - name: Run Ruff Linter - run: python -m ruff check . - shell: bash - - - name: Run Ruff Formatter Check - run: python -m ruff format --check . - shell: bash - - - name: Run clang-format Check - run: | - CPP_FILES=$(find . -type f \( -name "*.cpp" -o -name "*.h" -o -name "*.hpp" -o -name "*.cc" -o -name "*.cxx" \) \ - ! -path "./build/*" \ - ! -path "./tests/*" \ - ! -path "./scripts/*" \ - ! -path "./python/*" \ - ! -path "./thirdparty/*" \ - ! -path "./.git/*") - - if [ -z "$CPP_FILES" ]; then - echo "No C++ files found to check." - exit 0 - fi - - clang-format --dry-run --Werror $CPP_FILES - shell: bash - # Build and test matrix (parallel execution) build-and-test: - name: Build & Test (${{ matrix.platform }}) - needs: lint - runs-on: ${{ matrix.os }} + name: Build & Test (${{ inputs.platform }}) + runs-on: ${{ inputs.os }} strategy: fail-fast: false matrix: include: - - os: macos-15 - platform: macos-arm64 - arch_flag: "" # ARM64 uses auto-detection - - os: ubuntu-24.04-arm - platform: linux-arm64 - arch_flag: "" # ARM64 uses auto-detection - - os: ubuntu-24.04 - platform: linux-x64 - arch_flag: "" # Use native CPU microarchitecture + - os: ${{ inputs.os }} + platform: ${{ inputs.platform }} + arch_flag: "" # Use appropriate architecture steps: - name: Checkout code diff --git a/.github/workflows/android_build.yml b/.github/workflows/04-android-build.yml similarity index 94% rename from .github/workflows/android_build.yml rename to .github/workflows/04-android-build.yml index 099d4868..5ee6fe34 100644 --- a/.github/workflows/android_build.yml +++ b/.github/workflows/04-android-build.yml @@ -1,20 +1,7 @@ -name: android-cross-build +name: Android Cross Build on: - push: - branches: [ "main" ] - paths-ignore: - - '**.md' - merge_group: - pull_request: - branches: [ "main" ] - paths-ignore: - - '**.md' - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} - cancel-in-progress: true + workflow_call: permissions: contents: read From ca52c0bce587475a527c98b80e3b46077f0c5a31 Mon Sep 17 00:00:00 2001 From: luoxiaojian Date: Wed, 18 Mar 2026 20:43:18 +0800 Subject: [PATCH 32/34] feat: Introduce turbo. (#243) --- examples/c++/CMakeLists.txt | 2 +- src/CMakeLists.txt | 1 + src/core/CMakeLists.txt | 2 +- src/core/metric/CMakeLists.txt | 2 +- src/core/metric/quantized_integer_metric.cc | 41 ++- src/include/zvec/turbo/turbo.h | 55 +++ src/turbo/CMakeLists.txt | 36 ++ .../record_quantized_int8/common.h | 312 ++++++++++++++++++ .../record_quantized_int8/cosine.cc | 144 ++++++++ .../record_quantized_int8/cosine.h | 39 +++ .../squared_euclidean.cc | 138 ++++++++ .../record_quantized_int8/squared_euclidean.h | 41 +++ src/turbo/turbo.cc | 75 +++++ 13 files changed, 881 insertions(+), 7 deletions(-) create mode 100644 src/include/zvec/turbo/turbo.h create mode 100644 src/turbo/CMakeLists.txt create mode 100644 src/turbo/avx512_vnni/record_quantized_int8/common.h create mode 100644 src/turbo/avx512_vnni/record_quantized_int8/cosine.cc create mode 100644 src/turbo/avx512_vnni/record_quantized_int8/cosine.h create mode 100644 src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.cc create mode 100644 src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.h create mode 100644 src/turbo/turbo.cc diff --git a/examples/c++/CMakeLists.txt b/examples/c++/CMakeLists.txt index d0dbf8b6..d1731579 100644 --- a/examples/c++/CMakeLists.txt +++ b/examples/c++/CMakeLists.txt @@ -43,7 +43,7 @@ set(zvec_ailego_deps ) set(zvec_core_deps - # empty + zvec_turbo ) set(zvec_db_deps diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c516187c..00383c99 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -6,6 +6,7 @@ git_version(ZVEC_VERSION ${CMAKE_CURRENT_SOURCE_DIR}) # Add repository cc_directory(ailego) +cc_directory(turbo) cc_directory(core) cc_directory(db) if(BUILD_PYTHON_BINDINGS) diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 7742db59..5f696c08 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -15,7 +15,7 @@ file(GLOB_RECURSE ALL_CORE_SRCS *.cc *.c *.h) cc_library( NAME zvec_core STATIC STRICT PACKED SRCS ${ALL_CORE_SRCS} - LIBS zvec_ailego sparsehash magic_enum + LIBS zvec_ailego zvec_turbo sparsehash magic_enum INCS . ${PROJECT_ROOT_DIR}/src/core VERSION "${GIT_SRCS_VER}" ) \ No newline at end of file diff --git a/src/core/metric/CMakeLists.txt b/src/core/metric/CMakeLists.txt index cbc1049f..55dfc901 100644 --- a/src/core/metric/CMakeLists.txt +++ b/src/core/metric/CMakeLists.txt @@ -5,7 +5,7 @@ cc_library( NAME core_metric STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc - LIBS zvec_ailego core_framework + LIBS zvec_ailego zvec_turbo core_framework INCS . ${PROJECT_ROOT_DIR}/src/core VERSION "${PROXIMA_ZVEC_VERSION}" ) diff --git a/src/core/metric/quantized_integer_metric.cc b/src/core/metric/quantized_integer_metric.cc index 2b4e757a..f2c19cf2 100644 --- a/src/core/metric/quantized_integer_metric.cc +++ b/src/core/metric/quantized_integer_metric.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "metric_params.h" #include "quantized_integer_metric_batch.h" #include "quantized_integer_metric_matrix.h" @@ -95,6 +96,12 @@ class QuantizedIntegerMetric : public IndexMetric { switch (origin_metric_type_) { case MetricType::kSquaredEuclidean: if (meta_.data_type() == IndexMeta::DataType::DT_INT8) { + auto turbo_ret = turbo::get_distance_func( + turbo::MetricType::kSquaredEuclidean, turbo::DataType::kInt8, + turbo::QuantizeType::kDefault); + if (turbo_ret) { + return turbo_ret; + } return DistanceMatrixCompute(m, n); } if (meta_.data_type() == IndexMeta::DataType::DT_INT4) { @@ -118,7 +125,6 @@ class QuantizedIntegerMetric : public IndexMetric { if (meta_.data_type() == IndexMeta::DataType::DT_INT4) { return DistanceMatrixCompute(m, n); } - // TODO: support MipsSquaredEuclidean other injection type break; case MetricType::kNormalizedCosine: @@ -131,6 +137,12 @@ class QuantizedIntegerMetric : public IndexMetric { break; case MetricType::kCosine: if (meta_.data_type() == IndexMeta::DataType::DT_INT8) { + auto turbo_ret = turbo::get_distance_func( + turbo::MetricType::kCosine, turbo::DataType::kInt8, + turbo::QuantizeType::kDefault); + if (turbo_ret) { + return turbo_ret; + } return DistanceMatrixCompute(m, n); } if (meta_.data_type() == IndexMeta::DataType::DT_INT4) { @@ -146,6 +158,12 @@ class QuantizedIntegerMetric : public IndexMetric { switch (origin_metric_type_) { case MetricType::kSquaredEuclidean: if (meta_.data_type() == IndexMeta::DataType::DT_INT8) { + auto turbo_ret = turbo::get_batch_distance_func( + turbo::MetricType::kSquaredEuclidean, turbo::DataType::kInt8, + turbo::QuantizeType::kDefault); + if (turbo_ret) { + return turbo_ret; + } return reinterpret_cast( BaseDistanceBatchWithScoreUnquantized::ComputeBatch); @@ -180,7 +198,6 @@ class QuantizedIntegerMetric : public IndexMetric { BaseDistanceBatchWithScoreUnquantized< MipsSquaredEuclidean, uint8_t, 12, 2>::ComputeBatch); } - // TODO: support MipsSquaredEuclidean other injection type break; case MetricType::kNormalizedCosine: if (meta_.data_type() == IndexMeta::DataType::DT_INT8) { @@ -196,6 +213,12 @@ class QuantizedIntegerMetric : public IndexMetric { break; case MetricType::kCosine: if (meta_.data_type() == IndexMeta::DataType::DT_INT8) { + auto turbo_ret = turbo::get_batch_distance_func( + turbo::MetricType::kCosine, turbo::DataType::kInt8, + turbo::QuantizeType::kDefault); + if (turbo_ret) { + return turbo_ret; + } return reinterpret_cast( BaseDistanceBatchWithScoreUnquantized< CosineMinusInnerProduct, int8_t, 12, 2>::ComputeBatch); @@ -205,7 +228,6 @@ class QuantizedIntegerMetric : public IndexMetric { BaseDistanceBatchWithScoreUnquantized< CosineMinusInnerProduct, uint8_t, 12, 2>::ComputeBatch); } - break; } return nullptr; @@ -264,14 +286,25 @@ class QuantizedIntegerMetric : public IndexMetric { const override { if (origin_metric_type_ == MetricType::kCosine && meta_.data_type() == IndexMeta::DataType::DT_INT8) { + auto turbo_ret = turbo::get_query_preprocess_func( + turbo::MetricType::kCosine, turbo::DataType::kInt8, + turbo::QuantizeType::kDefault); + if (turbo_ret) { + return turbo_ret; + } return CosineMinusInnerProductDistanceBatchWithScoreUnquantized< int8_t, 1, 1>::GetQueryPreprocessFunc(); } else if (origin_metric_type_ == MetricType::kSquaredEuclidean && meta_.data_type() == IndexMeta::DataType::DT_INT8) { + auto turbo_ret = turbo::get_query_preprocess_func( + turbo::MetricType::kSquaredEuclidean, turbo::DataType::kInt8, + turbo::QuantizeType::kDefault); + if (turbo_ret) { + return turbo_ret; + } return SquaredEuclideanDistanceBatchWithScoreUnquantized< int8_t, 1, 1>::GetQueryPreprocessFunc(); } - return nullptr; } diff --git a/src/include/zvec/turbo/turbo.h b/src/include/zvec/turbo/turbo.h new file mode 100644 index 00000000..6ecbfdd1 --- /dev/null +++ b/src/include/zvec/turbo/turbo.h @@ -0,0 +1,55 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include + +namespace zvec::turbo { + +using DistanceFunc = + std::function; +using BatchDistanceFunc = std::function; +using QueryPreprocessFunc = + zvec::ailego::DistanceBatch::DistanceBatchQueryPreprocessFunc; + +enum class MetricType { + kSquaredEuclidean, + kCosine, + kMipsSquaredEuclidean, + kUnknown, +}; + +enum class DataType { + kInt8, + kUnknown, +}; + +enum class QuantizeType { + kDefault, +}; + +DistanceFunc get_distance_func(MetricType metric_type, DataType data_type, + QuantizeType quantize_type); + +BatchDistanceFunc get_batch_distance_func(MetricType metric_type, + DataType data_type, + QuantizeType quantize_type); + +QueryPreprocessFunc get_query_preprocess_func(MetricType metric_type, + DataType data_type, + QuantizeType quantize_type); + +} // namespace zvec::turbo diff --git a/src/turbo/CMakeLists.txt b/src/turbo/CMakeLists.txt new file mode 100644 index 00000000..0aa834a2 --- /dev/null +++ b/src/turbo/CMakeLists.txt @@ -0,0 +1,36 @@ +include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) +include(${PROJECT_ROOT_DIR}/cmake/option.cmake) + +if(NOT ANDROID AND AUTO_DETECT_ARCH) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386|x64") + setup_compiler_march_for_x86(TURBO_MARCH_FLAG_SSE TURBO_MARCH_FLAG_AVX2 TURBO_MARCH_FLAG_AVX512) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|ARM64") + # ARM64 architecture - no special march flags needed for now + # NEON implementations can be added here if needed + message(STATUS "turbo: ARM64 detected, skipping x86-specific optimizations") + endif() +endif() + +file(GLOB_RECURSE ALL_SRCS *.cc *.c *.h) + +# Set per-file compile flags for AVX512-VNNI sources. +# set_source_files_properties is directory-scoped, so it must be called in the +# same directory that adds the sources to a target (i.e. here, not in a +# subdirectory). +if(NOT ANDROID AND AUTO_DETECT_ARCH) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386|x64") + file(GLOB_RECURSE AVX512_VNNI_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/avx512_vnni/*.cc) + set_source_files_properties( + ${AVX512_VNNI_SRCS} + PROPERTIES + COMPILE_FLAGS "${TURBO_MARCH_FLAG_AVX512}" + ) + endif() +endif() + +cc_library( + NAME zvec_turbo STATIC STRICT PACKED + SRCS ${ALL_SRCS} + LIBS zvec_ailego + INCS ${CMAKE_CURRENT_SOURCE_DIR} ${PROJECT_ROOT_DIR}/src/include +) diff --git a/src/turbo/avx512_vnni/record_quantized_int8/common.h b/src/turbo/avx512_vnni/record_quantized_int8/common.h new file mode 100644 index 00000000..55fb5898 --- /dev/null +++ b/src/turbo/avx512_vnni/record_quantized_int8/common.h @@ -0,0 +1,312 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Shared AVX512-VNNI inner product kernels for record_quantized_int8 distance +// implementations (cosine, l2, mips_l2, etc.). +// +// All functions are marked always_inline so that when this header is included +// from a per-file-march .cc translation unit, the compiler can fully inline +// and optimize them under the correct -march flag without any cross-TU call +// overhead. + +#pragma once + +#if defined(__AVX512VNNI__) +#include +#include +#include + +namespace zvec::turbo::avx512_vnni::internal { + +static inline int32_t HorizontalAdd_INT32_V256(__m256i v) { + __m256i x1 = _mm256_hadd_epi32(v, v); + __m256i x2 = _mm256_hadd_epi32(x1, x1); + __m128i x3 = _mm256_extractf128_si256(x2, 1); + __m128i x4 = _mm_add_epi32(_mm256_castsi256_si128(x2), x3); + return _mm_cvtsi128_si32(x4); +} + +#define FMA_INT8_GENERAL(m, q, sum) sum += static_cast(m * q); + +// Compute the raw integer inner product of two int8 vectors of length `size`. +// The result is written to `*distance` as a float. +// Both `a` and `b` must point to int8_t arrays. +static __attribute__((always_inline)) void ip_int8_avx512_vnni( + const void *a, const void *b, size_t size, float *distance) { + const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); + const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001); + + const int8_t *lhs = reinterpret_cast(a); + const int8_t *rhs = reinterpret_cast(b); + + const int8_t *last = lhs + size; + const int8_t *last_aligned = lhs + ((size >> 6) << 6); + + float result = 0.0f; + + __m256i ymm_sum_0 = _mm256_setzero_si256(); + __m256i ymm_sum_1 = _mm256_setzero_si256(); + + if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0)); + __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32)); + __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0)); + __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32)); + + ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); + ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); + ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); + ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); + + ymm_sum_0 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), + ONES_INT16_AVX), + ymm_sum_0); + ymm_sum_1 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), + ONES_INT16_AVX), + ymm_sum_1); + } + + if (last >= last_aligned + 32) { + __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs); + ymm_lhs = _mm256_sign_epi8(ymm_lhs, ymm_rhs); + ymm_rhs = _mm256_abs_epi8(ymm_rhs); + ymm_sum_0 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs, ymm_lhs), + ONES_INT16_AVX), + ymm_sum_0); + lhs += 32; + rhs += 32; + } + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); + xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); + xmm_rhs = _mm_abs_epi8(xmm_rhs); + ymm_sum_0 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), + _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), + ONES_INT16_SSE)), + ymm_sum_0); + lhs += 16; + rhs += 16; + } + } else { + for (; lhs != last_aligned; lhs += 64, rhs += 64) { + __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0)); + __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32)); + __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0)); + __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32)); + + ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); + ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); + ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); + ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); + + ymm_sum_0 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), + ONES_INT16_AVX), + ymm_sum_0); + ymm_sum_1 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), + ONES_INT16_AVX), + ymm_sum_1); + } + + if (last >= last_aligned + 32) { + __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs); + __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs); + ymm_lhs = _mm256_sign_epi8(ymm_lhs, ymm_rhs); + ymm_rhs = _mm256_abs_epi8(ymm_rhs); + ymm_sum_0 = _mm256_add_epi32( + _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs, ymm_lhs), + ONES_INT16_AVX), + ymm_sum_0); + lhs += 32; + rhs += 32; + } + + if (last >= lhs + 16) { + __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); + __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); + xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs); + xmm_rhs = _mm_abs_epi8(xmm_rhs); + ymm_sum_0 = _mm256_add_epi32( + _mm256_set_m128i(_mm_setzero_si128(), + _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), + ONES_INT16_SSE)), + ymm_sum_0); + lhs += 16; + rhs += 16; + } + } + result = static_cast( + HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1))); + + switch (last - lhs) { + case 15: + FMA_INT8_GENERAL(lhs[14], rhs[14], result) + /* FALLTHRU */ + case 14: + FMA_INT8_GENERAL(lhs[13], rhs[13], result) + /* FALLTHRU */ + case 13: + FMA_INT8_GENERAL(lhs[12], rhs[12], result) + /* FALLTHRU */ + case 12: + FMA_INT8_GENERAL(lhs[11], rhs[11], result) + /* FALLTHRU */ + case 11: + FMA_INT8_GENERAL(lhs[10], rhs[10], result) + /* FALLTHRU */ + case 10: + FMA_INT8_GENERAL(lhs[9], rhs[9], result) + /* FALLTHRU */ + case 9: + FMA_INT8_GENERAL(lhs[8], rhs[8], result) + /* FALLTHRU */ + case 8: + FMA_INT8_GENERAL(lhs[7], rhs[7], result) + /* FALLTHRU */ + case 7: + FMA_INT8_GENERAL(lhs[6], rhs[6], result) + /* FALLTHRU */ + case 6: + FMA_INT8_GENERAL(lhs[5], rhs[5], result) + /* FALLTHRU */ + case 5: + FMA_INT8_GENERAL(lhs[4], rhs[4], result) + /* FALLTHRU */ + case 4: + FMA_INT8_GENERAL(lhs[3], rhs[3], result) + /* FALLTHRU */ + case 3: + FMA_INT8_GENERAL(lhs[2], rhs[2], result) + /* FALLTHRU */ + case 2: + FMA_INT8_GENERAL(lhs[1], rhs[1], result) + /* FALLTHRU */ + case 1: + FMA_INT8_GENERAL(lhs[0], rhs[0], result) + } + *distance = result; +} + +#undef FMA_INT8_GENERAL + +// Shift the first `original_dim` bytes of `query` in-place from int8 to uint8 +// by adding 128 to each element. The metadata tail beyond `original_dim` is +// left untouched. This prepares the query for use with dpbusd (uint8 * int8). +static __attribute__((always_inline)) void shift_int8_to_uint8_avx512( + void *query, size_t original_dim) { + const int8_t *input = reinterpret_cast(query); + uint8_t *output = reinterpret_cast(query); + + // 128 represented as int8_t wraps to -128, but two's complement addition + // produces the correct uint8 result. + const __m512i offset = _mm512_set1_epi8(static_cast(128)); + + size_t i = 0; + for (; i + 64 <= original_dim; i += 64) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(input + i)); + __m512i shifted = _mm512_add_epi8(data, offset); + _mm512_storeu_si512(reinterpret_cast<__m512i *>(output + i), shifted); + } + for (; i < original_dim; ++i) { + output[i] = static_cast(static_cast(input[i]) + 128); + } +} + +// Compute raw integer inner products for a batch of int8 vectors against a +// single query. Uses AVX512-VNNI dpbusd instruction. +// `query` is treated as uint8 (preprocessed), `vectors[i]` as int8. +template +__attribute__((always_inline)) void ip_int8_batch_avx512_vnni_impl( + const void *query, const void *const *vectors, + const std::array &prefetch_ptrs, + size_t dimensionality, float *distances) { + __m512i accs[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + accs[i] = _mm512_setzero_si512(); + } + size_t dim = 0; + for (; dim + 64 <= dimensionality; dim += 64) { + __m512i q = _mm512_loadu_si512(reinterpret_cast( + reinterpret_cast(query) + dim)); + __m512i data_regs[batch_size]; + for (size_t i = 0; i < batch_size; ++i) { + data_regs[i] = _mm512_loadu_si512(reinterpret_cast( + reinterpret_cast(vectors[i]) + dim)); + } + for (size_t i = 0; i < batch_size; ++i) { + if (prefetch_ptrs[i]) { + _mm_prefetch( + reinterpret_cast( + reinterpret_cast(prefetch_ptrs[i]) + dim), + _MM_HINT_T0); + } + accs[i] = _mm512_dpbusd_epi32(accs[i], q, data_regs[i]); + } + } + std::array temp_results{}; + for (size_t i = 0; i < batch_size; ++i) { + temp_results[i] = _mm512_reduce_add_epi32(accs[i]); + } + for (; dim < dimensionality; ++dim) { + int q = static_cast(reinterpret_cast(query)[dim]); + for (size_t i = 0; i < batch_size; ++i) { + temp_results[i] += + q * + static_cast(reinterpret_cast(vectors[i])[dim]); + } + } + for (size_t i = 0; i < batch_size; ++i) { + distances[i] = static_cast(temp_results[i]); + } +} + +// Dispatch batched inner product over all `n` vectors with prefetching. +static __attribute__((always_inline)) void ip_int8_batch_avx512_vnni( + const void *const *vectors, const void *query, size_t n, size_t dim, + float *distances) { + static constexpr size_t batch_size = 2; + static constexpr size_t prefetch_step = 2; + size_t i = 0; + for (; i + batch_size <= n; i += batch_size) { + std::array prefetch_ptrs; + for (size_t j = 0; j < batch_size; ++j) { + if (i + j + batch_size * prefetch_step < n) { + prefetch_ptrs[j] = vectors[i + j + batch_size * prefetch_step]; + } else { + prefetch_ptrs[j] = nullptr; + } + } + ip_int8_batch_avx512_vnni_impl( + query, &vectors[i], prefetch_ptrs, dim, distances + i); + } + for (; i < n; i++) { + std::array prefetch_ptrs{nullptr}; + ip_int8_batch_avx512_vnni_impl<1>(query, &vectors[i], prefetch_ptrs, dim, + distances + i); + } +} + +} // namespace zvec::turbo::avx512_vnni::internal + +#endif // defined(__AVX512VNNI__) diff --git a/src/turbo/avx512_vnni/record_quantized_int8/cosine.cc b/src/turbo/avx512_vnni/record_quantized_int8/cosine.cc new file mode 100644 index 00000000..34843ba6 --- /dev/null +++ b/src/turbo/avx512_vnni/record_quantized_int8/cosine.cc @@ -0,0 +1,144 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is compiled with per-file -march=avx512vnni (set in CMakeLists.txt) +// so that all AVX512-VNNI intrinsics and the inlined inner product kernels from +// common.h are compiled with the correct target ISA. + +#include "avx512_vnni/record_quantized_int8/cosine.h" +#include "avx512_vnni/record_quantized_int8/common.h" +#if defined(__AVX512VNNI__) +#include +#endif + +// Tail layout for quantized INT8 cosine vectors: +// +// [ original_dim bytes: int8_t elements ] +// [ float scale_a ] (ma) +// [ float bias_a ] (mb) +// [ float sum_a ] (ms) +// [ float square_sum_a ] (ms2) +// [ int int8_sum ] (sum of raw int8 elements, used when query is +// preprocessed to uint8 via +128 shift) +// +// The query tail has the same layout (qa, qb, qs, qs2) without int8_sum. + +namespace zvec::turbo::avx512_vnni { + +void cosine_int8_distance(const void *a, const void *b, size_t dim, + float *distance) { +#if defined(__AVX512VNNI__) + // `dim` is the full encoded size; the original vector occupies dim-24 bytes. + const int original_dim = dim - 24; + if (original_dim <= 0) { + return; + } + + // Compute raw integer inner product over the original_dim bytes. + // Note: for the single-vector path there is no query preprocessing, so both + // sides are treated as int8_t (same as the non-preprocessed path in + // MinusInnerProductDistanceBatchWithScoreUnquantized). + internal::ip_int8_avx512_vnni(a, b, original_dim, distance); + + const float *a_tail = reinterpret_cast( + reinterpret_cast(a) + original_dim); + const float *b_tail = reinterpret_cast( + reinterpret_cast(b) + original_dim); + + float ma = a_tail[0]; + float mb = a_tail[1]; + float ms = a_tail[2]; + + float qa = b_tail[0]; + float qb = b_tail[1]; + float qs = b_tail[2]; + + // Dequantize and compute cosine distance: + // cosine_dist = -(ma * qa * ip + mb * qa * qs + qb * ma * ms + // + original_dim * qb * mb) + *distance = -(ma * qa * *distance + mb * qa * qs + qb * ma * ms + + static_cast(original_dim) * qb * mb); +#else + (void)a; + (void)b; + (void)dim; + (void)distance; +#endif +} + +void cosine_int8_batch_distance(const void *const *vectors, const void *query, + size_t n, size_t dim, float *distances) { +#if defined(__AVX512VNNI__) + // `dim` is the full encoded size; the original vector occupies dim-24 bytes. + const int original_dim = dim - 24; + if (original_dim <= 0) { + return; + } + + // Compute raw inner products for all vectors. The query has been preprocessed + // (int8 + 128 -> uint8) so dpbusd can be used via ip_int8_batch_avx512_vnni. + internal::ip_int8_batch_avx512_vnni(vectors, query, n, original_dim, + distances); + + const float *q_tail = reinterpret_cast( + reinterpret_cast(query) + original_dim); + float qa = q_tail[0]; + float qb = q_tail[1]; + float qs = q_tail[2]; + + for (int i = 0; i < n; ++i) { + const float *m_tail = reinterpret_cast( + reinterpret_cast(vectors[i]) + original_dim); + float ma = m_tail[0]; + float mb = m_tail[1]; + float ms = m_tail[2]; + // Correct for the +128 shift applied to the query during preprocessing: + // dpbusd computes sum(uint8_query[i] * int8_data[i]) + // = sum((int8_query[i] + 128) * int8_data[i]) + // = true_ip + 128 * sum(int8_data[i]) + // int8_sum is stored as the 5th int-sized field after the 4 floats. + int int8_sum = reinterpret_cast(m_tail)[4]; + float &result = distances[i]; + result -= 128.0f * static_cast(int8_sum); + + // Dequantize and compute cosine distance: + // cosine_dist = -(ma * qa * ip + mb * qa * qs + qb * ma * ms + // + original_dim * qb * mb) + result = -(ma * qa * result + mb * qa * qs + qb * ma * ms + + static_cast(original_dim) * qb * mb); + } +#else + (void)vectors; + (void)query; + (void)n; + (void)dim; + (void)distances; +#endif +} + +void cosine_int8_query_preprocess(void *query, size_t dim) { +#if defined(__AVX512VNNI__) + // The original vector occupies dim-24 bytes; only those bytes are shifted. + const int original_dim = static_cast(dim) - 24; + if (original_dim <= 0) { + return; + } + internal::shift_int8_to_uint8_avx512(query, original_dim); +#else + (void)query; + (void)dim; +#endif +} + +} // namespace zvec::turbo::avx512_vnni diff --git a/src/turbo/avx512_vnni/record_quantized_int8/cosine.h b/src/turbo/avx512_vnni/record_quantized_int8/cosine.h new file mode 100644 index 00000000..836af103 --- /dev/null +++ b/src/turbo/avx512_vnni/record_quantized_int8/cosine.h @@ -0,0 +1,39 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace zvec::turbo::avx512_vnni { + +// Compute cosine distance (negative inner product after normalization) between +// a single quantized INT8 vector pair. +// `dim` includes the original vector bytes plus a 24-byte metadata tail +// (3 floats: scale_a, bias_a, sum_a). +void cosine_int8_distance(const void *a, const void *b, size_t dim, + float *distance); + +// Batch version of cosine_int8_distance. +// The query must have been preprocessed by cosine_int8_query_preprocess +// (int8 -> uint8 via +128 shift) before calling this function. +void cosine_int8_batch_distance(const void *const *vectors, const void *query, + size_t n, size_t dim, float *distances); + +// Preprocess the query vector in-place (shift int8 -> uint8 by adding 128) +// so that the AVX512-VNNI dpbusd instruction can be used for inner product. +// `dim` includes the 24-byte metadata tail. +void cosine_int8_query_preprocess(void *query, size_t dim); + +} // namespace zvec::turbo::avx512_vnni diff --git a/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.cc b/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.cc new file mode 100644 index 00000000..44850238 --- /dev/null +++ b/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.cc @@ -0,0 +1,138 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is compiled with per-file -march=avx512vnni (set in CMakeLists.txt) +// so that all AVX512-VNNI intrinsics and the inlined inner product kernels from +// common.h are compiled with the correct target ISA. + +#include "avx512_vnni/record_quantized_int8/squared_euclidean.h" +#include "avx512_vnni/record_quantized_int8/common.h" +#if defined(__AVX512VNNI__) +#include +#endif + +// Tail layout for quantized INT8 squared Euclidean vectors: +// +// [ original_dim bytes: int8_t elements ] +// [ float scale_a ] (ma) +// [ float bias_a ] (mb) +// [ float sum_a ] (ms) +// [ float sum2_a ] (ms2) +// [ int int8_sum ] (sum of raw int8 elements, used for bias correction +// when the query has been shifted to uint8 via +128) +// +// Total tail size: 4 floats + 1 int = 20 bytes, so dim = original_dim + 20. + +namespace zvec::turbo::avx512_vnni { + +void squared_euclidean_int8_distance(const void *a, const void *b, size_t dim, + float *distance) { +#if defined(__AVX512VNNI__) + const int original_dim = dim - 20; + if (original_dim <= 0) { + return; + } + internal::ip_int8_avx512_vnni(a, b, original_dim, distance); + + const float *a_tail = reinterpret_cast( + reinterpret_cast(a) + original_dim); + const float *b_tail = reinterpret_cast( + reinterpret_cast(b) + original_dim); + + float ma = a_tail[0]; + float mb = a_tail[1]; + float ms = a_tail[2]; + float ms2 = a_tail[3]; + + float qa = b_tail[0]; + float qb = b_tail[1]; + float qs = b_tail[2]; + float qs2 = b_tail[3]; + + const float sum = qa * qs; + const float sum2 = qa * qa * qs2; + + *distance = ma * ma * ms2 + sum2 - 2 * ma * qa * *distance + + (mb - qb) * (mb - qb) * original_dim + + 2 * (mb - qb) * (ms * ma - sum); +#else + (void)a; + (void)b; + (void)dim; + (void)distance; +#endif +} + +void squared_euclidean_int8_batch_distance(const void *const *vectors, + const void *query, size_t n, + size_t dim, float *distances) { +#if defined(__AVX512VNNI__) + const int original_dim = dim - 20; + if (original_dim <= 0) { + return; + } + + internal::ip_int8_batch_avx512_vnni(vectors, query, n, original_dim, + distances); + const float *q_tail = reinterpret_cast( + reinterpret_cast(query) + original_dim); + float qa = q_tail[0]; + float qb = q_tail[1]; + float qs = q_tail[2]; + float qs2 = q_tail[3]; + + const float sum = qa * qs; + const float sum2 = qa * qa * qs2; + for (size_t i = 0; i < n; ++i) { + const float *m_tail = reinterpret_cast( + reinterpret_cast(vectors[i]) + original_dim); + float ma = m_tail[0]; + float mb = m_tail[1]; + float ms = m_tail[2]; + float ms2 = m_tail[3]; + // Correct for the +128 shift applied to the query during preprocessing: + // dpbusd computes sum(uint8_query[i] * int8_data[i]) + // = sum((int8_query[i] + 128) * int8_data[i]) + // = true_ip + 128 * sum(int8_data[i]) + // int8_sum is stored as the 5th int-sized field after the 4 floats. + int int8_sum = reinterpret_cast(m_tail)[4]; + float &result = distances[i]; + result -= 128.0f * static_cast(int8_sum); + result = ma * ma * ms2 + sum2 - 2 * ma * qa * result + + (mb - qb) * (mb - qb) * original_dim + + 2 * (mb - qb) * (ms * ma - sum); + } +#else + (void)vectors; + (void)query; + (void)n; + (void)dim; + (void)distances; +#endif +} + +void squared_euclidean_int8_query_preprocess(void *query, size_t dim) { +#if defined(__AVX512VNNI__) + const int original_dim = static_cast(dim) - 20; + if (original_dim <= 0) { + return; + } + internal::shift_int8_to_uint8_avx512(query, original_dim); +#else + (void)query; + (void)dim; +#endif +} + +} // namespace zvec::turbo::avx512_vnni diff --git a/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.h b/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.h new file mode 100644 index 00000000..c1830912 --- /dev/null +++ b/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.h @@ -0,0 +1,41 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace zvec::turbo::avx512_vnni { + +// Compute squared Euclidean distance between a single quantized INT8 +// vector pair. +// `dim` includes the original vector bytes plus a 20-byte metadata tail +// (4 floats: scale_a, bias_a, sum_a, sum2_a). +void squared_euclidean_int8_distance(const void *a, const void *b, size_t dim, + float *distance); + +// Batch version of squared_euclidean_int8_distance. +// The query must have been preprocessed by +// squared_euclidean_int8_query_preprocess (int8 -> uint8 via +128 shift) +// before calling this function. +void squared_euclidean_int8_batch_distance(const void *const *vectors, + const void *query, size_t n, + size_t dim, float *distances); + +// Preprocess the query vector in-place (shift int8 -> uint8 by adding 128) +// for the batch path. Only the original_dim bytes are shifted; the metadata +// tail is left intact. `dim` includes the 20-byte metadata tail. +void squared_euclidean_int8_query_preprocess(void *query, size_t dim); + +} // namespace zvec::turbo::avx512_vnni diff --git a/src/turbo/turbo.cc b/src/turbo/turbo.cc new file mode 100644 index 00000000..a731cfed --- /dev/null +++ b/src/turbo/turbo.cc @@ -0,0 +1,75 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "avx512_vnni/record_quantized_int8/cosine.h" +#include "avx512_vnni/record_quantized_int8/squared_euclidean.h" + +namespace zvec::turbo { + +DistanceFunc get_distance_func(MetricType metric_type, DataType data_type, + QuantizeType quantize_type) { + if (data_type == DataType::kInt8) { + if (quantize_type == QuantizeType::kDefault) { + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { + if (metric_type == MetricType::kSquaredEuclidean) { + return avx512_vnni::squared_euclidean_int8_distance; + } + if (metric_type == MetricType::kCosine) { + return avx512_vnni::cosine_int8_distance; + } + } + } + } + return nullptr; +} + +BatchDistanceFunc get_batch_distance_func(MetricType metric_type, + DataType data_type, + QuantizeType quantize_type) { + if (data_type == DataType::kInt8) { + if (quantize_type == QuantizeType::kDefault) { + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { + if (metric_type == MetricType::kSquaredEuclidean) { + return avx512_vnni::squared_euclidean_int8_batch_distance; + } + if (metric_type == MetricType::kCosine) { + return avx512_vnni::cosine_int8_batch_distance; + } + } + } + } + return nullptr; +} + +QueryPreprocessFunc get_query_preprocess_func(MetricType metric_type, + DataType data_type, + QuantizeType quantize_type) { + if (data_type == DataType::kInt8) { + if (quantize_type == QuantizeType::kDefault) { + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { + if (metric_type == MetricType::kSquaredEuclidean) { + return avx512_vnni::squared_euclidean_int8_query_preprocess; + } + if (metric_type == MetricType::kCosine) { + return avx512_vnni::cosine_int8_query_preprocess; + } + } + } + } + return nullptr; +} + +} // namespace zvec::turbo From 3b8e9c91e7bee2e1364c8f6a9b2520ac10057b7d Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Thu, 19 Mar 2026 11:24:43 +0800 Subject: [PATCH 33/34] clang-format --- src/core/algorithm/hnsw/hnsw_streamer.cc | 6 ++++-- src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc | 8 +++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_streamer.cc b/src/core/algorithm/hnsw/hnsw_streamer.cc index e08953aa..4ff3b96b 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer.cc @@ -88,7 +88,8 @@ int HnswStreamer::init(const IndexMeta &imeta, const ailego::Params ¶ms) { ef_construction_ = HnswStreamerEntityNew::kDefaultEfConstruction; } if (upper_max_neighbor_cnt_ == 0U) { - upper_max_neighbor_cnt_ = HnswStreamerEntityNew::kDefaultUpperMaxNeighborCnt; + upper_max_neighbor_cnt_ = + HnswStreamerEntityNew::kDefaultUpperMaxNeighborCnt; } if (upper_max_neighbor_cnt_ > HnswStreamerEntityNew::kMaxNeighborCnt) { LOG_ERROR("[%s] must be in range (0,%d)", @@ -342,7 +343,8 @@ int HnswStreamer::dump(const IndexDumper::Pointer &dumper) { shared_mutex_.lock(); AILEGO_DEFER([&]() { shared_mutex_.unlock(); }); - meta_.set_searcher("HnswSearcher", HnswStreamerEntityNew::kRevision, ailego::Params()); + meta_.set_searcher("HnswSearcher", HnswStreamerEntityNew::kRevision, + ailego::Params()); int ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); if (ret != 0) { diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc index e1875fff..564c0b3e 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc @@ -196,12 +196,13 @@ const Neighbors HnswStreamerEntityNew::get_neighbors(level_t level, } const Neighbors HnswStreamerEntityNew::get_neighbors_new(level_t level, - node_id_t id) const { + node_id_t id) const { if (id) { return get_neighbors(level, id); } else { const void *src = neighbors_value_ptr_->data() + id * neighbor_size_; - const NeighborsHeader *header = reinterpret_cast(src); + const NeighborsHeader *header = + reinterpret_cast(src); return Neighbors(header->neighbor_cnt, header->neighbors); } } @@ -489,7 +490,8 @@ int HnswStreamerEntityNew::open(IndexStorage::Pointer stg, neighbors_value_ptr_->reserve(neighbor_size_ * doc_cnt()); for (int i = 0; i < doc_cnt(); i++) { Neighbors neighbor = get_neighbors(0, i); - neighbors_value_ptr_->append((const char *)neighbor.neighbor_block.data(), neighbor_size_); + neighbors_value_ptr_->append((const char *)neighbor.neighbor_block.data(), + neighbor_size_); } stats_.set_loaded_count(doc_cnt()); From 9693c4ffa6a532a6c383917e1d9efd3f06d24699 Mon Sep 17 00:00:00 2001 From: "yinzefeng.yzf" Date: Thu, 19 Mar 2026 17:27:20 +0800 Subject: [PATCH 34/34] fix comment@greptile --- src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc | 2 +- src/core/algorithm/hnsw/hnsw_streamer_entity_new.h | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc index e975f0d5..c186974b 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.cc @@ -145,7 +145,7 @@ int HnswStreamerEntityNew::cleanup() { int HnswStreamerEntityNew::update_neighbors( level_t level, node_id_t id, const std::vector> &neighbors) { - std::vector buffer(neighbor_size_); + std::vector buffer(level == 0 ? neighbor_size_ : upper_neighbor_size_); NeighborsHeader *hd = reinterpret_cast(buffer.data()); hd->neighbor_cnt = neighbors.size(); size_t i = 0; diff --git a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h index 11707a02..dcc27e9b 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h +++ b/src/core/algorithm/hnsw/hnsw_streamer_entity_new.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include #include @@ -396,17 +395,13 @@ class HnswStreamerEntityNew { } void print_key_map() const { - std::cout << "key map begins" << std::endl; - + LOG_DEBUG("key map begins"); auto iter = keys_map_->begin(); while (iter != keys_map_->end()) { - std::cout << "key: " << iter->first << ", id: " << iter->second - << std::endl; - ; + LOG_DEBUG("key: %lld, id: %u", (long long)iter->first, iter->second); iter++; } - - std::cout << "key map ends" << std::endl; + LOG_DEBUG("key map ends"); } //! Get l0 neighbors size @@ -709,7 +704,7 @@ class HnswStreamerEntityNew { private: HnswStreamerEntityNew(const HnswStreamerEntityNew &) = delete; HnswStreamerEntityNew &operator=(const HnswStreamerEntityNew &) = delete; - static constexpr uint64_t kUpperHashMemoryInflateRatio = 2.0f; + static constexpr float kUpperHashMemoryInflateRatio = 2.0f; private: IndexStreamer::Stats &stats_;