Skip to content

Patch for a segmentation fault when inserting and rehashing data (based on the main branch). #14

@MoFHeka

Description

@MoFHeka
# Patch overview and rationale:
#
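# 1. 57-bit VMA Pointer Truncation Fix:
#    Increases the pointer-packing shift and mask from 48 to 56 bits in `IndirectEntries` and
#    `state::list_head`. On systems with 5-level paging (57-bit VMA), memory allocations can lie
#    above the 48-bit address boundary; packing metadata starting at bit 48 truncated such
#    pointers and caused segmentation faults. User-space addresses on these systems still fit in
#    56 bits, so the top 8 bits remain available for the buffer count / hash tag.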
#
# 2. Epoch GC Use-After-Free Fix:
#    Moves `current_table_version.load()` inside `epoch::with_epoch` blocks across various 
#    functions (Upsert, Remove, Find). This guarantees the table version is not garbage 
#    collected by a concurrent thread before the epoch protection is established.
#
# 3. Concurrent Rehash Data Race Fix:
#    Adds `try_lock` protection in `copy_bucket` to prevent multiple threads from overwriting
#    the same bucket during expansion. Also adds null pointer checks in `check_bucket_and_state` 
#    to avoid segfaults when chasing forwarded buckets that are not yet fully initialized.

--- a/include/parlay_hash/parlay_hash.h
+++ b/include/parlay_hash/parlay_hash.h
@@ -99,13 +99,13 @@
     size_t list_head;
     Entry buffer[buffer_size];
     state() : list_head(0) {}
-    state(const Entry& e) : list_head(1ul << 48) {
+    state(const Entry& e) : list_head(1ul << 56) {
       buffer[0] = e;
     }
     static constexpr size_t forwarded_val = 1ul;
     
     size_t make_head(link* l, size_t bsize) {
-      return (((size_t) l) | (bsize << 48)); }
+      return (((size_t) l) | (bsize << 56)); }
 
     // update overflow list with new ptr (assumes buffer is full)
     state(const state& s, link* ptr)
@@ -156,15 +156,15 @@
     bool is_forwarded() const {return list_head == forwarded_val ;}
 
     // number of entries in buffer, or buffer_size+1 if overflow
-    long buffer_cnt() const {return (list_head >> 48) & 255ul ;}
+    long buffer_cnt() const {return (list_head >> 56) & 255ul ;}
 
     // number of entries in bucket (includes those in the overflow list)
     long size() const {
       if (buffer_cnt() <= buffer_size) return buffer_cnt();
       return buffer_size + list_length(overflow_list());
     }
 
     // get the overflow list
     link* overflow_list() const {
-      return (link*) (list_head & ((1ul << 48) - 1));}
+      return (link*) (list_head & ((1ul << 56) - 1));}
   };
 
   // returns std::optional(f(entry)) for entry with given key
@@ -411,14 +411,17 @@
 
   // Copies a bucket into grow_factor new buckets.
   void copy_bucket(table_version* t, table_version* next, long i) {
-    long exp_start = i * grow_factor;
-    // Clear grow_factor buckets in the next table version to put them in.
-    for (int j = exp_start; j < exp_start + grow_factor; j++)
-      initialize(next->buckets[j]); 
-    // copy bucket to grow_factor new buckets in next table version
-    while (true) {
-      // the bucket to copy
-      auto [s, tag] = t->buckets[i].v.ll();
+    get_locks().try_lock(i, [&] {
+      if (t->buckets[i].v.load().is_forwarded()) return true;
+      long exp_start = i * grow_factor;
+      // Clear grow_factor buckets in the next table version to put them in.
+      for (int j = exp_start; j < exp_start + grow_factor; j++)
+        initialize(next->buckets[j]); 
+      // copy bucket to grow_factor new buckets in next table version
+      while (true) {
+        // the bucket to copy
+        auto [s, tag] = t->buckets[i].v.ll();
+        if (s.is_forwarded()) break;
 
       // insert into grow_factor buckets (states) for next larger table
       state hold[grow_factor];
@@ -447,6 +450,8 @@
 	next->buckets[j].v.store_sequential(state());
       }
     }
+    return true;
+    });
   }
 
   // If copying is ongoing (i.e., next is not null), and if the the
@@ -602,10 +607,15 @@
 			      big_atomic<state>*& b, state& s, tag_type& tag, long& idx) {
     if (s.is_forwarded()) {
       table_version* nxt = t->next.load();
-      idx = nxt->get_index(k);
-      b = &(nxt->buckets[idx].v);
-      std::tie(s, tag) = b->ll();
-      check_bucket_and_state(nxt, k, b, s, tag, idx);
+      if (!nxt) return;
+      long next_idx = nxt->get_index(k);
+      big_atomic<state>* next_b = &(nxt->buckets[next_idx].v);
+      auto [next_s, next_tag] = next_b->ll();
+      check_bucket_and_state(nxt, k, next_b, next_s, next_tag, next_idx);
+      b = next_b;
+      s = next_s;
+      tag = next_tag;
+      idx = next_idx;
     }
   }
 
@@ -724,10 +734,10 @@
     -> std::optional<typename std::invoke_result<G,Entry>::type>
   {
     using rtype = std::optional<typename std::invoke_result<G,Entry>::type>;
-    table_version* ht = current_table_version.load();
-    long idx = ht->get_index(key);
-    auto b = &(ht->buckets[idx].v);
     return epoch::with_epoch([&] () -> rtype {
+      table_version* ht = current_table_version.load();
+      long idx = ht->get_index(key);
+      auto b = &(ht->buckets[idx].v);
       int delay = 200;
       while (true) {
 	auto [s, tag] = b->ll();
@@ -792,11 +802,11 @@
     -> std::optional<typename std::invoke_result<F,Entry>::type>
   {
     using rtype = std::optional<typename std::invoke_result<F,Entry>::type>;
-    table_version* ht = current_table_version.load();
-    long idx = ht->get_index(key);
-    auto b = &(ht->buckets[idx].v);
     // if entries are direct safe to scan the buffer without epoch protection
     if constexpr (Entry::Direct) {
+      table_version* ht = current_table_version.load();
+      long idx = ht->get_index(key);
+      auto b = &(ht->buckets[idx].v);
       auto [s, tag] = b->ll();
       copy_if_needed(ht, idx);
       check_bucket_and_state(ht, key, b, s, tag, idx);
@@ -812,6 +822,9 @@
     }
     // if buffer overfull, or indirect, then need to protect
     return epoch::with_epoch([&] () -> rtype {
+      table_version* ht = current_table_version.load();
+      long idx = ht->get_index(key);
+      auto b = &(ht->buckets[idx].v);
       int delay = 200;
       while (true) {
         auto [s, tag] = b->ll();
@@ -1064,12 +1064,12 @@
       static constexpr bool Direct = false;
       Data* ptr;
       static Data* tag_ptr(size_t hashv, Data* data) {
-	return (Data*) (((hashv >> 48) << 48) | ((size_t) data));
+	return (Data*) (((hashv >> 56) << 56) | ((size_t) data));
       }
       Data* get_ptr() const {
-	return (Data*) (((size_t) ptr) & ((1ul << 48) - 1)); }
+	return (Data*) (((size_t) ptr) & ((1ul << 56) - 1)); }
       static unsigned long hash(const Key& k) {
 	return k.second;}
       bool equal(const Key& k) const {
-	return (((k.second >> 48) == (((size_t) ptr) >> 48)) &&
+	return (((k.second >> 56) == (((size_t) ptr) >> 56)) &&
 		KeyEqual{}(DataS::get_key(*get_ptr()), *k.first)); }

@gblelloch

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions