# Patch overview and rationale:
#
# 1. 57-bit VMA Pointer Truncation Fix:
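# Widens the pointer field from 48 to 56 bits in `IndirectEntries` and `state::list_head`
# by moving the packed tag/count shift from 48 to 56. The metadata stored above the pointer
# now occupies only the top 8 bits: the buffer count is already masked with 255, while the
# hash tag kept by `IndirectEntries` shrinks from 16 bits to 8.
# On systems with 5-level paging (57-bit VMA), allocations can lie above the 48-bit
# address boundary; the old masks truncated such pointers and caused segmentation faults.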
#
# 2. Epoch GC Use-After-Free Fix:
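# Moves `current_table_version.load()` (and the bucket lookup derived from it) inside the
# `epoch::with_epoch` blocks across various functions (Upsert, Remove, Find). This guarantees
# the table version cannot be garbage collected by a concurrent thread before epoch
# protection is established. Note that the `Entry::Direct` fast path still performs its
# load outside `with_epoch`; there the load is only moved into the `if constexpr` branch.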
#
# 3. Concurrent Rehash Data Race Fix:
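# Wraps the body of `copy_bucket` in `try_lock` so that multiple threads cannot overwrite
# the same bucket during expansion, and skips buckets that are already forwarded. Also adds
# a null pointer check on `t->next` in `check_bucket_and_state` to avoid segfaults when
# chasing forwarded buckets that are not yet fully initialized, and updates the caller's
# bucket/state/tag/index only after the forwarded lookup has completed.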
--- a/include/parlay_hash/parlay_hash.h
+++ b/include/parlay_hash/parlay_hash.h
@@ -99,13 +99,13 @@
size_t list_head;
Entry buffer[buffer_size];
state() : list_head(0) {}
- state(const Entry& e) : list_head(1ul << 48) {
+ state(const Entry& e) : list_head(1ul << 56) {
buffer[0] = e;
}
static constexpr size_t forwarded_val = 1ul;
size_t make_head(link* l, size_t bsize) {
- return (((size_t) l) | (bsize << 48)); }
+ return (((size_t) l) | (bsize << 56)); }
// update overflow list with new ptr (assumes buffer is full)
state(const state& s, link* ptr)
@@ -156,15 +156,15 @@
bool is_forwarded() const {return list_head == forwarded_val ;}
// number of entries in buffer, or buffer_size+1 if overflow
- long buffer_cnt() const {return (list_head >> 48) & 255ul ;}
+ long buffer_cnt() const {return (list_head >> 56) & 255ul ;}
// number of entries in bucket (includes those in the overflow list)
long size() const {
if (buffer_cnt() <= buffer_size) return buffer_cnt();
return buffer_size + list_length(overflow_list());
}
// get the overflow list
link* overflow_list() const {
- return (link*) (list_head & ((1ul << 48) - 1));}
+ return (link*) (list_head & ((1ul << 56) - 1));}
};
// returns std::optional(f(entry)) for entry with given key
@@ -411,14 +411,17 @@
// Copies a bucket into grow_factor new buckets.
void copy_bucket(table_version* t, table_version* next, long i) {
- long exp_start = i * grow_factor;
- // Clear grow_factor buckets in the next table version to put them in.
- for (int j = exp_start; j < exp_start + grow_factor; j++)
- initialize(next->buckets[j]);
- // copy bucket to grow_factor new buckets in next table version
- while (true) {
- // the bucket to copy
- auto [s, tag] = t->buckets[i].v.ll();
+ get_locks().try_lock(i, [&] {
+ if (t->buckets[i].v.load().is_forwarded()) return true;
+ long exp_start = i * grow_factor;
+ // Clear grow_factor buckets in the next table version to put them in.
+ for (int j = exp_start; j < exp_start + grow_factor; j++)
+ initialize(next->buckets[j]);
+ // copy bucket to grow_factor new buckets in next table version
+ while (true) {
+ // the bucket to copy
+ auto [s, tag] = t->buckets[i].v.ll();
+ if (s.is_forwarded()) break;
// insert into grow_factor buckets (states) for next larger table
state hold[grow_factor];
@@ -447,6 +450,8 @@
next->buckets[j].v.store_sequential(state());
}
}
+ return true;
+ });
}
// If copying is ongoing (i.e., next is not null), and if the the
@@ -602,10 +607,15 @@
big_atomic<state>*& b, state& s, tag_type& tag, long& idx) {
if (s.is_forwarded()) {
table_version* nxt = t->next.load();
- idx = nxt->get_index(k);
- b = &(nxt->buckets[idx].v);
- std::tie(s, tag) = b->ll();
- check_bucket_and_state(nxt, k, b, s, tag, idx);
+ if (!nxt) return;
+ long next_idx = nxt->get_index(k);
+ big_atomic<state>* next_b = &(nxt->buckets[next_idx].v);
+ auto [next_s, next_tag] = next_b->ll();
+ check_bucket_and_state(nxt, k, next_b, next_s, next_tag, next_idx);
+ b = next_b;
+ s = next_s;
+ tag = next_tag;
+ idx = next_idx;
}
}
@@ -724,10 +734,10 @@
-> std::optional<typename std::invoke_result<G,Entry>::type>
{
using rtype = std::optional<typename std::invoke_result<G,Entry>::type>;
- table_version* ht = current_table_version.load();
- long idx = ht->get_index(key);
- auto b = &(ht->buckets[idx].v);
return epoch::with_epoch([&] () -> rtype {
+ table_version* ht = current_table_version.load();
+ long idx = ht->get_index(key);
+ auto b = &(ht->buckets[idx].v);
int delay = 200;
while (true) {
auto [s, tag] = b->ll();
@@ -792,11 +802,11 @@
-> std::optional<typename std::invoke_result<F,Entry>::type>
{
using rtype = std::optional<typename std::invoke_result<F,Entry>::type>;
- table_version* ht = current_table_version.load();
- long idx = ht->get_index(key);
- auto b = &(ht->buckets[idx].v);
// if entries are direct safe to scan the buffer without epoch protection
if constexpr (Entry::Direct) {
+ table_version* ht = current_table_version.load();
+ long idx = ht->get_index(key);
+ auto b = &(ht->buckets[idx].v);
auto [s, tag] = b->ll();
copy_if_needed(ht, idx);
check_bucket_and_state(ht, key, b, s, tag, idx);
@@ -812,6 +822,9 @@
}
// if buffer overfull, or indirect, then need to protect
return epoch::with_epoch([&] () -> rtype {
+ table_version* ht = current_table_version.load();
+ long idx = ht->get_index(key);
+ auto b = &(ht->buckets[idx].v);
int delay = 200;
while (true) {
auto [s, tag] = b->ll();
@@ -1064,12 +1064,12 @@
static constexpr bool Direct = false;
Data* ptr;
static Data* tag_ptr(size_t hashv, Data* data) {
- return (Data*) (((hashv >> 48) << 48) | ((size_t) data));
+ return (Data*) (((hashv >> 56) << 56) | ((size_t) data));
}
Data* get_ptr() const {
- return (Data*) (((size_t) ptr) & ((1ul << 48) - 1)); }
+ return (Data*) (((size_t) ptr) & ((1ul << 56) - 1)); }
static unsigned long hash(const Key& k) {
return k.second;}
bool equal(const Key& k) const {
- return (((k.second >> 48) == (((size_t) ptr) >> 48)) &&
+ return (((k.second >> 56) == (((size_t) ptr) >> 56)) &&
KeyEqual{}(DataS::get_key(*get_ptr()), *k.first)); }
@gblelloch