Skip to content

Commit 47f1386

Browse files
committed
try resolving one rma
1 parent 551173a commit 47f1386

File tree

1 file changed

+48
-12
lines changed

1 file changed

+48
-12
lines changed

workloads/gromacs/mpi_cxl_shim.c

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ typedef struct cxl_window {
306306
MPI_Comm comm; // Associated communicator
307307
bool cxl_enabled; // Whether CXL acceleration is enabled
308308
uint32_t win_id; // Window ID
309+
_Atomic int pending_mpi_rma; // Count of RMA ops that fell back to MPI (not CXL direct)
309310
struct cxl_window *next; // Linked list for tracking
310311
} cxl_window_t;
311312

@@ -1836,6 +1837,7 @@ static cxl_window_t *register_cxl_window(MPI_Win win, void *base, size_t size,
18361837
cxl_win->win_id = atomic_fetch_add(&g_next_win_id, 1);
18371838
cxl_win->cxl_enabled = false;
18381839
cxl_win->shm = NULL;
1840+
atomic_store(&cxl_win->pending_mpi_rma, 0);
18391841

18401842
// Allocate shared memory for window metadata if CXL is available
18411843
if (g_cxl.initialized && g_cxl.cxl_comm_enabled) {
@@ -1993,11 +1995,29 @@ int MPI_Win_allocate(MPI_Aint size, int disp_unit, MPI_Info info,
19931995
MPI_Comm comm, void *baseptr, MPI_Win *win) {
19941996
LOG_DEBUG("MPI_Win_allocate: size=%ld, disp_unit=%d\n", (long)size, disp_unit);
19951997

1996-
// Try to allocate from CXL memory first
1998+
int rank, comm_size;
1999+
MPI_Comm_rank(comm, &rank);
2000+
MPI_Comm_size(comm, &comm_size);
2001+
2002+
// Try to allocate from CXL memory first.
2003+
// Rank 0 allocates a single block for ALL ranks to avoid cross-VM
2004+
// allocation races (atomic_fetch_add may not be coherent on CXL Type-3).
19972005
void *cxl_base = NULL;
19982006
if (g_cxl.initialized && g_cxl.cxl_comm_enabled) {
1999-
cxl_base = allocate_cxl_memory(size);
2000-
if (cxl_base) {
2007+
LOAD_ORIGINAL(MPI_Bcast);
2008+
cxl_rptr_t block_rptr = CXL_RPTR_NULL;
2009+
MPI_Aint alloc_size = size > 0 ? size : (MPI_Aint)CXL_ALIGNMENT;
2010+
2011+
if (rank == 0) {
2012+
void *block = allocate_cxl_memory((size_t)alloc_size * comm_size);
2013+
if (block) {
2014+
block_rptr = ptr_to_rptr(block);
2015+
}
2016+
}
2017+
orig_MPI_Bcast(&block_rptr, sizeof(block_rptr), MPI_BYTE, 0, comm);
2018+
2019+
if (block_rptr != CXL_RPTR_NULL) {
2020+
cxl_base = (char *)rptr_to_ptr(block_rptr) + (size_t)rank * alloc_size;
20012021
LOG_DEBUG("MPI_Win_allocate: using CXL memory at %p (rptr=0x%lx)\n",
20022022
cxl_base, ptr_to_rptr(cxl_base));
20032023
}
@@ -2016,10 +2036,6 @@ int MPI_Win_allocate(MPI_Aint size, int disp_unit, MPI_Info info,
20162036
}
20172037

20182038
if (ret == MPI_SUCCESS) {
2019-
int rank, comm_size;
2020-
MPI_Comm_rank(comm, &rank);
2021-
MPI_Comm_size(comm, &comm_size);
2022-
20232039
cxl_window_t *cxl_win = register_cxl_window(*win, *(void **)baseptr, size,
20242040
rank, comm_size, comm);
20252041
if (cxl_win && cxl_win->shm) {
@@ -2099,7 +2115,9 @@ int MPI_Put(const void *origin_addr, int origin_count, MPI_Datatype origin_datat
20992115
}
21002116

21012117
put_fallback:
2102-
// Fallback to MPI
2118+
// Fallback to MPI - track so flush knows to call orig
2119+
if (cxl_win)
2120+
atomic_fetch_add(&cxl_win->pending_mpi_rma, 1);
21032121
return orig_MPI_Put(origin_addr, origin_count, origin_datatype,
21042122
target_rank, target_disp, target_count,
21052123
target_datatype, win);
@@ -2157,7 +2175,9 @@ int MPI_Get(void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
21572175
}
21582176

21592177
get_fallback:
2160-
// Fallback to MPI
2178+
// Fallback to MPI - track so flush knows to call orig
2179+
if (cxl_win)
2180+
atomic_fetch_add(&cxl_win->pending_mpi_rma, 1);
21612181
return orig_MPI_Get(origin_addr, origin_count, origin_datatype,
21622182
target_rank, target_disp, target_count,
21632183
target_datatype, win);
@@ -2214,7 +2234,9 @@ int MPI_Accumulate(const void *origin_addr, int origin_count, MPI_Datatype origi
22142234
}
22152235

22162236
acc_fallback:
2217-
// Fallback to MPI
2237+
// Fallback to MPI - track so flush knows to call orig
2238+
if (cxl_win)
2239+
atomic_fetch_add(&cxl_win->pending_mpi_rma, 1);
22182240
return orig_MPI_Accumulate(origin_addr, origin_count, origin_datatype,
22192241
target_rank, target_disp, target_count,
22202242
target_datatype, op, win);
@@ -2286,6 +2308,7 @@ int MPI_Win_lock(int lock_type, int rank, int assert, MPI_Win win) {
22862308
}
22872309

22882310
__atomic_thread_fence(__ATOMIC_ACQUIRE);
2311+
atomic_store(&cxl_win->pending_mpi_rma, 0);
22892312
LOG_DEBUG("MPI_Win_lock: CXL lock acquired (type=%d, rank=%d)\n", lock_type, rank);
22902313
}
22912314

@@ -2321,11 +2344,24 @@ int MPI_Win_unlock(int rank, MPI_Win win) {
23212344
int MPI_Win_flush(int rank, MPI_Win win) {
23222345
LOG_DEBUG("MPI_Win_flush: rank=%d\n", rank);
23232346

2324-
LOAD_ORIGINAL(MPI_Win_flush);
2325-
23262347
// Memory fence for CXL
23272348
__atomic_thread_fence(__ATOMIC_SEQ_CST);
23282349

2350+
// If all RMA ops were handled via CXL direct, skip the MPI-level flush.
2351+
// Calling orig_MPI_Win_flush on windows created via MPI_Win_create with CXL
2352+
// memory can trigger UCX OSC errors and eventually crash the transport.
2353+
cxl_window_t *cxl_win = find_cxl_window(win);
2354+
if (cxl_win && cxl_win->cxl_enabled) {
2355+
int pending = atomic_load(&cxl_win->pending_mpi_rma);
2356+
if (pending == 0) {
2357+
LOG_DEBUG("MPI_Win_flush: CXL-only epoch, skipping MPI flush\n");
2358+
return MPI_SUCCESS;
2359+
}
2360+
// Some ops fell back to MPI - must flush those
2361+
atomic_store(&cxl_win->pending_mpi_rma, 0);
2362+
}
2363+
2364+
LOAD_ORIGINAL(MPI_Win_flush);
23292365
return orig_MPI_Win_flush(rank, win);
23302366
}
23312367

0 commit comments

Comments
 (0)