diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index a77fbb34..cd0f99e2 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -13,7 +13,8 @@ set(CHEETAH_SOURCES global.cpp init.cpp internal-malloc.cpp - local-hypertable.cpp + # local-hypertable.cpp + local-hyper-pagetable.cpp local-reducer-api.cpp pedigree_globals.cpp personality.cpp diff --git a/runtime/closure.h b/runtime/closure.h index 6f5af1a9..7844218f 100644 --- a/runtime/closure.h +++ b/runtime/closure.h @@ -2,7 +2,8 @@ #define _CLOSURE_TYPE_H #include "cilk-internal.h" -#include "local-hypertable.h" +// #include "local-hypertable.h" +#include "local-hyper-pagetable.h" #include "rts-config.h" #include diff --git a/runtime/init.cpp b/runtime/init.cpp index ff5c4dc4..2f5fa16b 100644 --- a/runtime/init.cpp +++ b/runtime/init.cpp @@ -697,7 +697,8 @@ static void worker_terminate(__cilkrts_worker *w, [[maybe_unused]] void *data) { cilk_fiber_pool_per_worker_terminate(w); hyper_table *ht = w->hyper_table; if (ht) { - local_hyper_table_free(ht); + // local_hyper_table_free(ht); + delete ht; w->hyper_table = nullptr; } worker_local_destroy(w->l, w->g); diff --git a/runtime/local-hyper-pagetable.cpp b/runtime/local-hyper-pagetable.cpp new file mode 100644 index 00000000..6fe016d0 --- /dev/null +++ b/runtime/local-hyper-pagetable.cpp @@ -0,0 +1,183 @@ +#include "local-hyper-pagetable.h" +#include "cilk/cilk_api.h" +#include "cilk/reducer" +#include "hyperobject_base.h" +#include "internal-malloc.h" +#include + +/////////////////////////////////////////////////////////////////////////// +// Implementations of methods from PageTableTy. + +template +V *PageTableTy::lookupInnerNode(uintptr_t Addr) { + if (auto *Page = (*std::get_if(&Table))->lookup(Addr)) { + return (*Page)->lookup(Addr); + } + return nullptr; +} + +template +EntryTy *PageTableTy::findInnerNode(uintptr_t Addr) { + if (auto *Page = (*std::get_if(&Table))->lookup(Addr)) { + return (*Page)->find(Addr); + } + return nullptr; +} + +template +bool PageTableTy::removeInnerNode(uintptr_t Addr) { + if (std::holds_alternative(Table)) { + InnerNodeTy *InnerNode = *std::get_if(&Table); + if (auto *Page = InnerNode->lookup(Addr)) { + return (*Page)->remove(Addr); + } + } + + return false; +} + +// Ensure the PageTableTy is fully instantiated for hyper_table. +template struct PageTableTy; + +/////////////////////////////////////////////////////////////////////////// +// Query, insert, and delete methods for the hash table. + +hyper_table *__cilkrts_local_hyper_table_alloc(void) { + auto *Tmp = new hyper_table(); + // fprintf(stderr, "cilkrts_local_hyper_table_alloc -> %p\n", Tmp); + return Tmp; +} + +__reducer_base *__cilkrts_insert_new_view_0(hyper_table *table, + __reducer_base *key) { + // Create a new view and initialize it with the identity function. + size_t size = key->size(); + void *new_view = cilk_aligned_alloc(64, round_size_to_alignment(64, size)); + __reducer_base *base = key->identity(new_view); + + // Insert the new view into the local hypertable. + [[maybe_unused]] bool success = + table->insert((uintptr_t)key, {new_view, base}); + assert(success && "Failed to insert reducer data"); + return base; +} + +void *__cilkrts_insert_new_view_1(hyper_table *table, uintptr_t key, + const __reducer_callbacks &callbacks) { + // Create a new view and initialize it with the identity function. + void *new_view = + cilk_aligned_alloc(64, round_size_to_alignment(64, callbacks.size)); + callbacks.identity(new_view); + + // Insert the new view into the local hypertable. + [[maybe_unused]] bool success = + table->insert((uintptr_t)key, {new_view, &callbacks.reduce}); + assert(success && "Failed to insert reducer data"); + return new_view; +} + +void *__cilkrts_insert_new_view_2(hyper_table *table, uintptr_t key, + size_t size, __cilk_c_identity_fn *identity, + __cilk_c_reduce_fn *reduce) { + // Create a new view and initialize it with the identity function. + void *new_view = cilk_aligned_alloc(64, round_size_to_alignment(64, size)); + identity(new_view); + + // Insert the new view into the local hypertable. + [[maybe_unused]] bool success = + table->insert((uintptr_t)key, {new_view, reduce}); + // fprintf(stderr, "cilkrts_insert_new_view_2: inserted %lx -> %p into + // %p\n", + // key, new_view, table); + assert(success && "Failed to insert reducer data"); + return new_view; +} + +void bucket_reduce(bucket *Left, bucket *Right) { + assert(Left->data.extra.index() == Right->data.extra.index()); + void *LeftView = Left->data.view, *RightView = Right->data.view; + if (std::holds_alternative<__reducer_base *>(Left->data.extra)) { + __reducer_base *Leftmost = static_cast<__reducer_base *>( + reinterpret_cast(getAddrFromKey(Left->key))); + __reducer_base *LeftR = std::get<__reducer_base *>(Left->data.extra); + __reducer_base *RightR = std::get<__reducer_base *>(Right->data.extra); + Leftmost->reduce(LeftR, RightR); + RightR->~__reducer_base(); + } else if (std::holds_alternative( + Left->data.extra)) { + (*std::get(Left->data.extra))(LeftView, + RightView); + } else { + // fprintf(stderr, "bucket_reduce %p, %p\n", LeftView, RightView); + std::get<__cilk_c_reduce_fn *>(Left->data.extra)(LeftView, RightView); + } + Right->data.extra = (__reducer_base *)nullptr; + Right->data.view = nullptr; + free(RightView); +} + +// Merge two hypertables, left and right. Returns the merged hypertable and +// deletes the other. +hyper_table *merge_two_hts(hyper_table *__restrict Left, + hyper_table *__restrict Right) { + // fprintf(stderr, "merge_two_hts %p (%zu), %p (%zu)\n", Left, Left->size(), + // Right, Right->size()); + // In the trivial case of an empty hyper_table, return the other + // hyper_table. + if (!Left) + return Right; + if (!Right) + return Left; + if (Left->size() == 0) { + delete Left; + return Right; + } + if (Right->size() == 0) { + delete Right; + return Left; + } + + // Pick the smaller hyper_table to be the source, which we will iterate + // over. + bool LeftDst; + hyper_table *Src, *Dst; + if (Left->size() >= Right->size()) { + Src = Right; + Dst = Left; + LeftDst = true; + } else { + Src = Left; + Dst = Right; + LeftDst = false; + } + + for (bucket B : *Src) { + uintptr_t Addr = getAddrFromKey(B.key); + bucket *DstB = Dst->find(Addr); + if (DstB == nullptr) { + // fprintf(stderr, "merge_two_hts: inserting %lx -> %p into %p\n", + // Addr, B.Data.view, Dst); + Dst->insert(Addr, std::forward(B.data)); + } else { + if (LeftDst) { + // fprintf(stderr, "merge_two_hts: reduction (%d): %p -> %p and + // %p -> %p\n", + // LeftDst, Dst, DstB->Data.view, Src, B.Data.view); + bucket_reduce(DstB, &B); + } else { + // fprintf(stderr, "merge_two_hts: reduction (%d): %p -> %p and + // %p -> %p\n", + // LeftDst, Src, B.Data.view, Dst, DstB->Data.view); + bucket_reduce(&B, DstB); + DstB->data = B.data; + B.data.extra = (__reducer_base *)nullptr; + B.data.view = nullptr; + } + } + } + + // Destroy the source hyper_table, and return the destination. + delete Src; + + return Dst; +} \ No newline at end of file diff --git a/runtime/local-hyper-pagetable.h b/runtime/local-hyper-pagetable.h new file mode 100644 index 00000000..07933fd0 --- /dev/null +++ b/runtime/local-hyper-pagetable.h @@ -0,0 +1,666 @@ +#ifndef _LOCAL_HYPER_PAGETABLE_H +#define _LOCAL_HYPER_PAGETABLE_H + +#include "cilk/cilk_api.h" +#include "hyperobject_base.h" +#include "rts-config.h" +#include +#include +#include +#include +#include +#include + +template struct EntryTy { + uintptr_t key = 0; + V data; + + void reset() { + key = 0; + if (std::is_destructible_v) + data.~V(); + } +}; + +template struct SmallEntrySetTy { + using EntryTy = EntryTy; + + ssize_t Occupied = 0; + EntryTy Entries[Capacity]; + + static_assert(Capacity <= 8 * sizeof(Occupied)); + + EntryTy *get(uintptr_t Key) { + for (ssize_t i = Occupied - 1; i >= 0; --i) { + if (Entries[i].key == Key) + return &Entries[i]; + } + return nullptr; + } + const EntryTy *get(uintptr_t Key) const { + return const_cast(this).get(Key); + } + + bool insert(uintptr_t Key, V &&Value) { + if (Occupied < Capacity) { + Entries[Occupied++] = {Key, Value}; + return true; + } + return false; + } + + bool remove(uintptr_t Key) { + for (ssize_t i = Occupied - 1; i >= 0; --i) { + if (Entries[i].key == Key) { + Entries[i].reset(); + if (i != Occupied - 1) + Entries[i] = std::forward((Entries[Occupied - 1])); + --Occupied; + return true; + } + } + return false; + } + + EntryTy &operator[](size_t Idx) { return Entries[Idx]; } +}; + +template struct SmallEntrySetTy { + using EntryTy = EntryTy; + EntryTy Entry; + + EntryTy *get(uintptr_t Key) { + if (Entry.key == Key) + return &Entry; + return nullptr; + } + const EntryTy *get(uintptr_t Key) const { + return const_cast(this).get(Key); + } + + bool insert(uintptr_t Key, const V &Value) { + auto OldKey = Entry.key; + if (OldKey && OldKey != Key) + // This slot is occupied by a different entry. + return false; + + // Insert into this slot. + Entry = {Key, Value}; + return true; + } + + bool remove(uintptr_t Key) { + if (Entry.key && Entry.key == Key) { + Entry.reset(); + return true; + } + return false; + } +}; + +template struct TableSizeTy { + static_assert(LgSz <= 64); + static_assert(RShift < 64); + static constexpr size_t LgSize = + (LgSz + RShift < 48) ? LgSz : (48 - RShift); + static constexpr size_t Size = (size_t)1 << LgSize; + static constexpr uintptr_t AddrMask = Size - 1; + static constexpr size_t Bits = LgSize + RShift; + static constexpr size_t KeyMask = ((size_t)1 << Bits) - 1; + static constexpr size_t toIndex(uintptr_t Addr) { + return (Addr >> RShift) & AddrMask; + } +}; + +template +static inline uintptr_t makeKey(uintptr_t Addr) { + return ~((Addr >> RShift) << RShift) & Mask; +} + +static inline uintptr_t getAddrFromKey(uintptr_t Key) { return ~Key; } + +template > +struct TableTy : public TableSizeTy { + using SizeTy = TableSizeTy; + using EntryTy = EntryTy; + + SmallEntrySetTy Entries[SizeTy::Size]; + + // Insert a value associated with an address. + bool insert(uintptr_t Addr, V &&Value) { + return Entries[SizeTy::toIndex(Addr)].insert(makeKey(Addr), + std::forward(Value)); + } + + // Remove the value associated with the given address. + bool remove(uintptr_t Addr) { + return Entries[SizeTy::toIndex(Addr)].remove(makeKey(Addr)); + } + + // Lookup the value associated with an address. + V *lookup(uintptr_t Addr) { + if (EntryTy *Entry = Entries[SizeTy::toIndex(Addr)].get(makeKey(Addr))) + return &Entry->data; + return nullptr; + } + + // Get the entry at the associated address. + EntryTy *find(uintptr_t Addr) { + return Entries[SizeTy::toIndex(Addr)].get(makeKey(Addr)); + } +}; + +// Leaf tables are indexed simply by the least significant 12 bits of an +// address. +static constexpr size_t LeafLgSz = 9; +static constexpr size_t LeafLgSetCapacity = 3; +static constexpr size_t LeafSetCapacity = 1 << LeafLgSetCapacity; +template +struct LeafTableTy + : public TableTy { + using TableTy = TableTy; + using SizeTy = TableTy::SizeTy; + + // Bit set to track which sets in this table contain elements. + static constexpr size_t AccessedFieldSize = 8 * sizeof(uint64_t); + static constexpr size_t AccessedSize = + (1UL << LeafLgSz) / AccessedFieldSize; + uint64_t Accessed[AccessedSize] = {0UL}; + + private: + static size_t getAccessedIdx(size_t Idx) { return Idx / AccessedFieldSize; } + static uint64_t getAccessedMask(size_t Idx) { + return 1UL << (Idx % AccessedFieldSize); + } + + public: + bool insert(uintptr_t Addr, V &&Value) { + if (TableTy::insert(Addr, std::forward(Value))) { + const size_t Idx = SizeTy::toIndex(Addr); + if (this->Entries[Idx].Occupied == 1) + Accessed[getAccessedIdx(Idx)] |= getAccessedMask(Idx); + return true; + } + return false; + } + + bool remove(uintptr_t Addr) { + if (TableTy::remove(Addr)) { + const size_t Idx = SizeTy::toIndex(Addr); + if (this->Entries[Idx].Occupied == 0) + Accessed[getAccessedIdx(Idx)] &= ~getAccessedMask(Idx); + return true; + } + return false; + } + + // Constants and methods for iterating through the entries in this table. + static constexpr uintptr_t EndIteratorValue = + 1UL << (LeafLgSz + LeafLgSetCapacity); + + static ssize_t getSetIdx(uintptr_t It) { + return It & (LeafSetCapacity - 1); + } + static uintptr_t getEntryIdx(uintptr_t It) { + return It >> LeafLgSetCapacity; + } + + uintptr_t advanceToNextEntry(uintptr_t It) { + if (this->Entries[getEntryIdx(It)].Occupied > getSetIdx(It)) + return It; + + uintptr_t NextEntryIdx = getEntryIdx(It) + 1; + size_t AccessedIdx = LeafTableTy::getAccessedIdx(NextEntryIdx); + uint64_t AccessedMask = LeafTableTy::getAccessedMask(NextEntryIdx); + for (; AccessedIdx < AccessedSize; ++AccessedIdx) { + uint64_t AccessedField = + Accessed[AccessedIdx] & ~(AccessedMask - 1); + if (AccessedField) { + It = ((AccessedIdx * AccessedFieldSize) + + __builtin_ctzl(AccessedField)) + << LeafLgSetCapacity; + return It; + } + AccessedMask = 1; + } + // Return the end iterator + It = EndIteratorValue; + return It; + } + + EntryTy &getEntryAt(uintptr_t It) { + return this->Entries[getEntryIdx(It)].Entries[getSetIdx(It)]; + } +}; + +// Pages are indexed by the 12 more significant bits of the address than the +// `RShift` template parameter. +template using PageSizeTy = TableSizeTy<12, RShift>; + +template +struct PageTy : public TableTy::LgSize, RShift, 1, + makeKey::KeyMask>> { + using PageSizeTy = PageSizeTy; + using TableTy = TableTy>; + using SizeTy = typename TableTy::SizeTy; + using EntryTy = EntryTy; + + // List of addresses in this page that have been inserted into. Used for + // destroying the higher-level PageTableTy. + std::vector Accessed; + // Add the given address to the list of addresses accessed in this page. + void recordAccess(uintptr_t Addr) { Accessed.push_back(Addr); } + + // Pages can be quite large. Use mmap and munmap to manage their physical + // memory on demand. + void *operator new(size_t Size) { + // Use MAP_ANONYMOUS to guarantee the page is initialized to zero. + return mmap(nullptr, sizeof(PageTy), PROT_READ | PROT_WRITE, + MAP_ANONYMOUS | MAP_PRIVATE, -1, 0); + } + void operator delete(void *Ptr) { munmap(Ptr, sizeof(PageTy)); } + + V &operator[](uintptr_t Addr) { + return this->Entries[PageSizeTy::toIndex(Addr)].Entry.data; + } + const V &operator[](uintptr_t Addr) const { + return const_cast(this)[Addr]; + } +}; + +// A page table is a dynamic tree structure, where each node is either a leaf +// table or a page of page tables. +template ::Bits> struct PageTableTy { + static_assert(Bits < 48); + using LeafTableTy = LeafTableTy; + // Subtables are page tables indexed by more significant bits of the + // address. + using SubTableTy = PageTableTy::Bits>; + // An inner node is a page of pointers to page tables indexed using more + // significant bits than the `Bits` parameter. + using InnerNodeTy = PageTy; + using InnerNodeSizeTy = typename InnerNodeTy::PageSizeTy; + // The table is either a leaf table or a pointer to an inner node. + using NodeTy = std::variant; + NodeTy Table; + + private: + // Helper method to insert an address-value pair into an inner node. + static void insertIntoInnerNode(InnerNodeTy *Node, uintptr_t Addr, + V &&Value) { + // NOTE: This method is defined in the header in order to ensure it is + // properly instantiated in all recursive PageTableTy instantiations. + + // Get the subtable corresponding with Addr. + SubTableTy *Page = (*Node)[Addr]; + if (Page == nullptr) { + // Create and insert a new subtable. + Page = new SubTableTy; + Node->recordAccess(Addr); + [[maybe_unused]] bool Result = + Node->insert(Addr, std::forward(Page)); + assert(Result && "Failed to add new subtable to node."); + } + // Insert into subtable. + [[maybe_unused]] bool Result = + Page->insert(Addr, std::forward(Value)); + assert(Result && "Failed to add address to to subtable."); + } + + // Promote a leaf table to an inner node and then insert the given value + // associated with the given address. + [[clang::noinline]] + static InnerNodeTy *promoteLeafNodeAndInsert(LeafTableTy &LeafTable, + uintptr_t Addr, V &&Value) { + // NOTE: This method is defined in the header in order to ensure it is + // properly instantiated in all recursive PageTableTy instantiations. + + // The leaf table could not insert the new entry. Convert the leaf + // table into an inner node. + InnerNodeTy *NewNode = new InnerNodeTy; + // Insert all entries in the leaf table into the new inner node. + for (auto EntrySet : LeafTable.Entries) { + for (ssize_t Idx = 0; Idx < EntrySet.Occupied; ++Idx) { + auto &Entry = EntrySet[Idx]; + insertIntoInnerNode(NewNode, getAddrFromKey(Entry.key), + std::forward(Entry.data)); + } + } + + // Insert the new entry into the new inner node. + insertIntoInnerNode(NewNode, Addr, std::forward(Value)); + return NewNode; + } + + // Insert the given value associated with the given address into an inner + // node. + [[clang::noinline]] + bool insertInnerNode(uintptr_t Addr, V &&Value) { + // NOTE: This method is defined in the header in order to ensure it is + // properly instantiated in all recursive PageTableTy instantiations. + if (std::holds_alternative(Table)) { + InnerNodeTy *Node = std::get(Table); + // Insert this entry into the inner node. + insertIntoInnerNode(Node, Addr, std::forward(Value)); + return true; + } + + return false; + } + + // Lookup the value at the given address from an inner node. + V *lookupInnerNode(uintptr_t Addr); + // Get the entry for the given address from an inner node. + EntryTy *findInnerNode(uintptr_t Addr); + + // Remove the entry for the given address from an inner node. + bool removeInnerNode(uintptr_t Addr); + + public: + ~PageTableTy() { + if (std::holds_alternative(Table)) { + InnerNodeTy *Node = std::get(Table); + for (size_t Addr : Node->Accessed) { + delete (*Node)[Addr]; + } + delete Node; + } + } + + // Get the value at the given address. + V *lookup(uintptr_t Addr) { + if (std::holds_alternative(Table)) { + return std::get(Table).lookup(Addr); + } + return lookupInnerNode(Addr); + } + + // Get the table entry at the given address. + EntryTy *find(uintptr_t Addr) { + if (std::holds_alternative(Table)) { + return std::get(Table).find(Addr); + } + return findInnerNode(Addr); + } + + // Insert the given value at the given address. + bool insert(uintptr_t Addr, V &&Value) { + if (std::holds_alternative(Table)) { + // Try to insert into this leaf table. + auto &LeafTable = std::get(Table); + if (LeafTable.insert(Addr, std::forward(Value))) + return true; + + InnerNodeTy *NewNode = promoteLeafNodeAndInsert( + LeafTable, Addr, std::forward(Value)); + // Replace this table with new inner node. + Table = NewNode; + return true; + } + + return insertInnerNode(Addr, std::forward(Value)); + } + + // Remove the value at the given address. + bool remove(uintptr_t Addr) { + if (std::holds_alternative(Table)) { + auto &LeafTable = std::get(Table); + return LeafTable.remove(Addr); + } + + return removeInnerNode(Addr); + } + + // Iterator to traverse the elements in the table. This iterator is used + // for merging two tables. + struct Iterator { + using value_type = EntryTy; + using difference_type = ptrdiff_t; + using InnerIterator = std::vector::iterator; + + PageTableTy *PageTable = nullptr; + std::variant It = + LeafTableTy::EndIteratorValue; + SubTableTy::Iterator SubIt; + + Iterator() : SubIt() {} + Iterator(PageTableTy &PageTable, bool MakeEnd = false) + : PageTable(&PageTable) { + if (MakeEnd) { + // Create an end iterator for this table. + if (std::holds_alternative(PageTable.Table)) { + It = LeafTableTy::EndIteratorValue; + } else { + InnerNodeTy *Node = + std::get(PageTable.Table); + It = Node->Accessed.end(); + } + return; + } + if (std::holds_alternative(PageTable.Table)) { + // Get the first entry in this leaf table. + It = std::get(PageTable.Table) + .advanceToNextEntry(0); + } else { + // Get a pointer to the first value within this inner-node. + InnerNodeTy *Node = std::get(PageTable.Table); + auto InnerIt = Node->Accessed.begin(); + auto EndInnerIt = Node->Accessed.end(); + // Because inner nodes are not depopulated when all elements + // within a page are removed, a scan is necessary to find the + // first element in any subtable in this inner node. + do { + SubIt = (*Node)[*InnerIt]->begin(); + } while (SubIt.atEnd() && ++InnerIt != EndInnerIt); + It = InnerIt; + } + } + + value_type &operator*() { + if (std::holds_alternative(PageTable->Table)) { + // Return the entry in this leaf table corresponding with the + // iterator value. + LeafTableTy &Leaf = std::get(PageTable->Table); + return Leaf.getEntryAt(std::get(It)); + } + // Dereference the subtable iterator to get the value. + return *SubIt; + } + + Iterator &operator++() { + if (std::holds_alternative(PageTable->Table)) { + // Advance the leaf iterator to the next entry. + LeafTableTy &Leaf = std::get(PageTable->Table); + uintptr_t LeafIt = std::get(It); + It = Leaf.advanceToNextEntry(++LeafIt); + return *this; + } + // Advance the subtable iterator. + ++SubIt; + if (SubIt.atEnd()) { + // The subtable iterator reached the end of its subtable. Find + // the next subtable with elements. + InnerNodeTy *Node = std::get(PageTable->Table); + InnerIterator InnerIt = std::get(It); + InnerIterator EndInnerIt = Node->Accessed.end(); + // Scan entries of this inner node until a valid entry is found. + while (SubIt.atEnd() && ++InnerIt != EndInnerIt) { + SubIt = (*Node)[*InnerIt]->begin(); + } + It = InnerIt; + if (InnerIt == EndInnerIt) + // This inner node has no more entries. Set the subtable + // iterator to the end-iterator value. + SubIt = typename SubTableTy::Iterator(); + } + return *this; + } + Iterator operator++(int) { + auto Tmp = *this; + ++*this; + return Tmp; + } + + bool atEnd() const { + if (std::holds_alternative(PageTable->Table)) { + // Check if this leaf-table iterator has the end value of a leaf + // table. + uintptr_t LeafIt = std::get(It); + return LeafIt == LeafTableTy::EndIteratorValue; + } + // Check if this inner-node iterator is pointing to the end of the + // node's accessed list. + InnerNodeTy *Node = std::get(PageTable->Table); + const InnerIterator &InnerIt = std::get(It); + return InnerIt == Node->Accessed.end(); + } + + bool operator==(const Iterator &Other) const { + return It == Other.It && SubIt == Other.SubIt; + } + }; + static_assert(std::input_or_output_iterator); + + Iterator begin() { return Iterator(*this); } + Iterator end() { return Iterator(*this, /*MakeEnd=*/true); } +}; + +// Template instantiation to prevent infinite recursion in template expansion. +template struct PageTableTy { + // This version of a PageTableTy should simply be a leaf table. + using LeafTableTy = LeafTableTy; + using NodeTy = LeafTableTy; + + NodeTy Table; + + // Get the value at the given address. + V *lookup(uintptr_t Addr) { return Table.lookup(Addr); } + // Get the table entry at the given address. + EntryTy *find(uintptr_t Addr) { return Table.find(Addr); } + + // Insert the given value at the given address. + bool insert(uintptr_t Addr, V &&Value) { + return Table.insert(Addr, std::forward(Value)); + } + + // Remove the value at the given address. + bool remove(uintptr_t Addr) { return Table.remove(Addr); } + + // Iterator type with the same methods as the general PageTableTy::Iterator, + // to support recursive template instantiation. + struct Iterator { + // Because this particular instantiation of PageTableTy simply contains + // a leaf table, this iterator simply handles the leaf table. + using value_type = EntryTy; + using difference_type = ptrdiff_t; + LeafTableTy *Table = nullptr; + uintptr_t It = 0; + + public: + Iterator() = default; + Iterator(LeafTableTy &Table, bool MakeEnd = false) + : Table(&Table), It(MakeEnd ? LeafTableTy::EndIteratorValue + : Table.advanceToNextEntry(0)) {} + Iterator(const Iterator &Other) : Table(Other.Table), It(Other.It) {} + Iterator &operator=(const Iterator &Other) { + Table = Other.Table; + It = Other.It; + return *this; + } + + value_type &operator*() const { return Table->getEntryAt(It); } + + Iterator &operator++() { + It = Table->advanceToNextEntry(++It); + return *this; + } + Iterator operator++(int) { + auto Tmp = *this; + ++*this; + return Tmp; + } + + bool atEnd() const { return It == LeafTableTy::EndIteratorValue; } + bool operator==(const Iterator &Other) const { return It == Other.It; } + }; + static_assert(std::input_or_output_iterator); + + Iterator begin() { return Iterator(Table); } + Iterator end() { return Iterator(Table, /*MakeEnd=*/true); } +}; + +using bucket = EntryTy; + +struct hyper_table : public PageTableTy { + using PageTableTy = PageTableTy; + using V = reducer_data; + + size_t NumEntries = 0; + + size_t size() const { return NumEntries; } + + bool insert(uintptr_t Addr, V &&Value) { + if (PageTableTy::insert(Addr, std::forward(Value))) { + ++NumEntries; + return true; + } + return false; + } + + bool remove(uintptr_t Addr) { + if (PageTableTy::remove(Addr)) { + --NumEntries; + return true; + } + return false; + } +}; + +CHEETAH_API +hyper_table *__cilkrts_local_hyper_table_alloc(void); + +static inline bucket *find_hyperobject(hyper_table *table, + uintptr_t key) noexcept { + auto *Tmp = table->find(key); + return Tmp; +} + +CHEETAH_INTERNAL +static inline bool remove_hyperobject(hyper_table *table, + uintptr_t key) noexcept { + auto Tmp = table->remove(key); + return Tmp; +} + +CHEETAH_INTERNAL +static inline bool insert_hyperobject(hyper_table *table, uintptr_t key, + reducer_data &&data) noexcept { + return table->insert(key, std::forward(data)); +} + +CHEETAH_API +bucket *__cilkrts_find_hyperobject_hash(hyper_table *table, uintptr_t key); + +CHEETAH_API +__reducer_base *__cilkrts_insert_new_view_0(hyper_table *table, + struct __reducer_base *key) + __attribute__((nonnull, returns_nonnull)); + +CHEETAH_API +void *__cilkrts_insert_new_view_1(hyper_table *table, uintptr_t key, + const __reducer_callbacks &callbacks) + __attribute__((nonnull, returns_nonnull)); + +CHEETAH_API +void *__cilkrts_insert_new_view_2(hyper_table *table, uintptr_t key, + size_t size, __cilk_c_identity_fn identity, + __cilk_c_reduce_fn reduce) + __attribute__((nonnull, returns_nonnull)); + +CHEETAH_INTERNAL +hyper_table *merge_two_hts(hyper_table *__restrict left, + hyper_table *__restrict right); + +#endif // _LOCAL_HYPER_PAGETABLE_H diff --git a/runtime/local-reducer-api.cpp b/runtime/local-reducer-api.cpp index 0eaa3395..46b56b37 100644 --- a/runtime/local-reducer-api.cpp +++ b/runtime/local-reducer-api.cpp @@ -1,6 +1,7 @@ #include "cilk-internal.h" #include "cilk2c_inlined.h" -#include "local-hypertable.h" +// #include "local-hypertable.h" +#include "local-hyper-pagetable.h" #include "local-reducer-api.h" #include "rts-config.h" @@ -17,34 +18,26 @@ __reducer_base::__reducer_base() { __reducer_base::~__reducer_base() {} -static void reducer_register(bucket &b) __CILKRTS_NOTHROW { +__attribute__((always_inline)) +static void reducer_register(uintptr_t key, reducer_data &&data) __CILKRTS_NOTHROW { struct hyper_table *table = get_local_hyper_table(__cilkrts_get_tls_worker()); - [[maybe_unused]] bool success = insert_hyperobject(table, b); + [[maybe_unused]] bool success = insert_hyperobject(table, key, std::move(data)); CILK_ASSERT(success && "Failed to register reducer."); } void __cilkrts_reducer_register_0(__reducer_base *key) __CILKRTS_NOTHROW { - bucket b{.key = (uintptr_t)key, .data = {.view = nullptr, .extra = key}}; - reducer_register(b); + reducer_register((uintptr_t)key, {.view = nullptr, .extra = key}); } void __cilkrts_reducer_register_1(void *key, __reducer_callbacks *cb) __CILKRTS_NOTHROW { - bucket b{ - .key = (uintptr_t)key, - .data = {.view = key, .extra = &cb->reduce}, - }; - reducer_register(b); + reducer_register((uintptr_t)key, {.view = key, .extra = &cb->reduce}); } void __cilkrts_reducer_register_2(void *key, __cilk_c_reduce_fn *reduce) __CILKRTS_NOTHROW { - bucket b{ - .key = (uintptr_t)key, - .data = {.view = key, .extra = reduce}, - }; - reducer_register(b); + reducer_register((uintptr_t)key, {.view = key, .extra = reduce}); } void __cilkrts_reducer_unregister(void *key) noexcept { @@ -63,7 +56,7 @@ __reducer_base *internal_reducer_lookup(__cilkrts_worker *w, struct hyper_table *table = get_local_hyper_table(w); bucket *b = find_hyperobject(table, (uintptr_t)key); if (__builtin_expect(!!b, true)) { - CILK_ASSERT_POINTER_EQUAL(key, (void *)b->key); + CILK_ASSERT_POINTER_EQUAL(key, (void *)getAddrFromKey(b->key)); // Return the existing view. return std::get<__reducer_base *>(b->data.extra); } diff --git a/runtime/local-reducer-api.h b/runtime/local-reducer-api.h index bb22c55c..e1c912de 100644 --- a/runtime/local-reducer-api.h +++ b/runtime/local-reducer-api.h @@ -3,7 +3,8 @@ #include "cilk-internal.h" #include "global.h" -#include "local-hypertable.h" +// #include "local-hypertable.h" +#include "local-hyper-pagetable.h" static inline hyper_table * get_local_hyper_table(__cilkrts_worker *w) { diff --git a/runtime/local.h b/runtime/local.h index 78e0ae91..e67b547f 100644 --- a/runtime/local.h +++ b/runtime/local.h @@ -4,7 +4,8 @@ #include "fiber.h" #include "internal-malloc-impl.h" /* for cilk_im_desc */ #include "jmpbuf.h" -#include "local-hypertable.h" +// #include "local-hypertable.h" +#include "local-hyper-pagetable.h" enum __cilkrts_worker_state : unsigned char { WORKER_IDLE = 10, diff --git a/runtime/personality.cpp b/runtime/personality.cpp index 8c50d328..8354d5e2 100644 --- a/runtime/personality.cpp +++ b/runtime/personality.cpp @@ -7,6 +7,7 @@ #include "fiber.h" #include "frame.h" #include "init.h" +#include "local-hyper-pagetable.h" #include "local-reducer-api.h" #include #include @@ -90,7 +91,7 @@ closure_exception *get_exception_reducer_or_null(__cilkrts_worker *w) noexcept { bucket *b = find_hyperobject(table, (uintptr_t)key); if (b) { - CILK_ASSERT_POINTER_EQUAL(key, (void *)b->key); + CILK_ASSERT_POINTER_EQUAL(key, (void *)getAddrFromKey(b->key)); // Return the existing view. __reducer_base *base = std::get<__reducer_base *>(b->data.extra); return static_cast(base); @@ -151,6 +152,7 @@ sync_in_personality(__cilkrts_worker *w, __cilkrts_stack_frame *sf, __cilkrts_sync(sf); } else { sanitizer_finish_switch_fiber(); + __cilkrts_do_reductions(sf); } } diff --git a/runtime/scheduler.cpp b/runtime/scheduler.cpp index c98b73f4..31f4d261 100644 --- a/runtime/scheduler.cpp +++ b/runtime/scheduler.cpp @@ -25,7 +25,8 @@ #include "frame.h" #include "global.h" #include "jmpbuf.h" -#include "local-hypertable.h" +// #include "local-hypertable.h" +#include "local-hyper-pagetable.h" #include "local.h" #include "scheduler.h" #include "worker.h" @@ -395,9 +396,9 @@ static Closure *Closure_return(__cilkrts_worker *const w, worker_id self, hyper_table **lht_ptr; Closure *const left_sib = child->left_sib; if (left_sib != nullptr) { - lht_ptr = &left_sib->right_ht; + lht_ptr = &left_sib->right_ht; } else { - lht_ptr = &parent->child_ht; + lht_ptr = &parent->child_ht; } hyper_table *lht = *lht_ptr; *lht_ptr = nullptr; @@ -416,8 +417,10 @@ static Closure *Closure_return(__cilkrts_worker *const w, worker_id self, l->lht = lht; setup_for_execution(w, child); - l->provably_good_steal = true; // Use the existing SP in the frame + l->provably_good_steal = true; // Use the existing SP in the frame + // Return this closure, so it will be scheduled again to perform the + // reduction. return child; } @@ -491,7 +494,6 @@ static Closure *Closure_return(__cilkrts_worker *const w, worker_id self, hyper_table *active_ht = parent->user_ht; parent->child_ht = nullptr; parent->user_ht = nullptr; - // w->hyper_table = merge_two_hts(child_ht, active_ht); CILK_ASSERT_NULL(l->lht); l->lht = child_ht; w->hyper_table = active_ht; diff --git a/unittests/test-hyper-pagetable.cpp b/unittests/test-hyper-pagetable.cpp new file mode 100644 index 00000000..f862737c --- /dev/null +++ b/unittests/test-hyper-pagetable.cpp @@ -0,0 +1,73 @@ +// #include "test-hypertable-common.h" +#include "../runtime/hyperobject_base.h" +#include "../runtime/local-hyper-pagetable.h" +#include +#include + +namespace cilk { +template static void zero(void *v) { + *static_cast(v) = static_cast(0); +} +template static void plus(void *l, void *r) { + *static_cast(l) += *static_cast(r); +} +} // namespace cilk + +// using HyperPageTableTy = PageTableTy; + +int main(int argc, char *argv[]) { + reducer_data Value{nullptr, cilk::plus}; + // HyperPageTableTy T; + hyper_table T; + // PageTableTy T; + + for (uintptr_t i = 0; i < 640; i += 64) { + // T.insert(i, getValueFor(i)); + assert(T.insert(i, Value)); + assert(T.insert(i + ((uintptr_t)1 << 27), Value)); + } + + bool BadEntry = false; + for (uintptr_t i = 0; i < 640; i += 64) { + auto *RD = T.lookup(i); + if (!RD) { + fprintf(stderr, "Missing entry at index %lx\n", i); + BadEntry = true; + } else if (RD->view != Value.view || RD->extra != Value.extra) { + fprintf(stderr, "Incorrect entry at index %lx\n", i); + BadEntry = true; + } + // } else if (*RD != Value) { + // std::cerr << "Incorrect entry at index " << i << ": Expected " << + // getValueFor(i) << ", found " << *RD << "\n"; + // } + // if (RD) { + // std::cout << "T[" << i << "] = " << *RD << "\n"; + // } else { + // std::cout << "T[" << i << "] is empty\n"; + // } + } + for (uintptr_t ii = 0; ii < 640; ii += 64) { + uintptr_t i = ii + ((uintptr_t)1 << 27); + auto *RD = T.lookup(i); + if (!RD) { + fprintf(stderr, "Missing entry at index %lx\n", i); + BadEntry = true; + } else if (RD->view != Value.view || RD->extra != Value.extra) { + fprintf(stderr, "Incorrect entry at index %lx\n", i); + BadEntry = true; + } + // } else if (*RD != Value) { + // std::cerr << "Incorrect entry at index " << i << ": Expected " << + // getValueFor(i) << ", found " << *RD << "\n"; + // } + // if (RD) { + // std::cout << "T[" << i << "] = " << *RD << "\n"; + // } else { + // std::cout << "T[" << i << "] is empty\n"; + // } + } + std::cout << (BadEntry ? "Test failed" : "Test passed") << "\n"; + + return 0; +} \ No newline at end of file