@@ -306,6 +306,7 @@ typedef struct cxl_window {
306306 MPI_Comm comm ; // Associated communicator
307307 bool cxl_enabled ; // Whether CXL acceleration is enabled
308308 uint32_t win_id ; // Window ID
309+ _Atomic int pending_mpi_rma ; // Count of RMA ops that fell back to MPI (not CXL direct)
309310 struct cxl_window * next ; // Linked list for tracking
310311} cxl_window_t ;
311312
@@ -1836,6 +1837,7 @@ static cxl_window_t *register_cxl_window(MPI_Win win, void *base, size_t size,
18361837 cxl_win -> win_id = atomic_fetch_add (& g_next_win_id , 1 );
18371838 cxl_win -> cxl_enabled = false;
18381839 cxl_win -> shm = NULL ;
1840+ atomic_store (& cxl_win -> pending_mpi_rma , 0 );
18391841
18401842 // Allocate shared memory for window metadata if CXL is available
18411843 if (g_cxl .initialized && g_cxl .cxl_comm_enabled ) {
@@ -1993,11 +1995,29 @@ int MPI_Win_allocate(MPI_Aint size, int disp_unit, MPI_Info info,
19931995 MPI_Comm comm , void * baseptr , MPI_Win * win ) {
19941996 LOG_DEBUG ("MPI_Win_allocate: size=%ld, disp_unit=%d\n" , (long )size , disp_unit );
19951997
1996- // Try to allocate from CXL memory first
1998+ int rank , comm_size ;
1999+ MPI_Comm_rank (comm , & rank );
2000+ MPI_Comm_size (comm , & comm_size );
2001+
2002+ // Try to allocate from CXL memory first.
2003+ // Rank 0 allocates a single block for ALL ranks to avoid cross-VM
2004+ // allocation races (atomic_fetch_add may not be coherent on CXL Type-3).
19972005 void * cxl_base = NULL ;
19982006 if (g_cxl .initialized && g_cxl .cxl_comm_enabled ) {
1999- cxl_base = allocate_cxl_memory (size );
2000- if (cxl_base ) {
2007+ LOAD_ORIGINAL (MPI_Bcast );
2008+ cxl_rptr_t block_rptr = CXL_RPTR_NULL ;
2009+ MPI_Aint alloc_size = size > 0 ? size : (MPI_Aint )CXL_ALIGNMENT ;
2010+
2011+ if (rank == 0 ) {
2012+ void * block = allocate_cxl_memory ((size_t )alloc_size * comm_size );
2013+ if (block ) {
2014+ block_rptr = ptr_to_rptr (block );
2015+ }
2016+ }
2017+ orig_MPI_Bcast (& block_rptr , sizeof (block_rptr ), MPI_BYTE , 0 , comm );
2018+
2019+ if (block_rptr != CXL_RPTR_NULL ) {
2020+ cxl_base = (char * )rptr_to_ptr (block_rptr ) + (size_t )rank * alloc_size ;
20012021 LOG_DEBUG ("MPI_Win_allocate: using CXL memory at %p (rptr=0x%lx)\n" ,
20022022 cxl_base , ptr_to_rptr (cxl_base ));
20032023 }
@@ -2016,10 +2036,6 @@ int MPI_Win_allocate(MPI_Aint size, int disp_unit, MPI_Info info,
20162036 }
20172037
20182038 if (ret == MPI_SUCCESS ) {
2019- int rank , comm_size ;
2020- MPI_Comm_rank (comm , & rank );
2021- MPI_Comm_size (comm , & comm_size );
2022-
20232039 cxl_window_t * cxl_win = register_cxl_window (* win , * (void * * )baseptr , size ,
20242040 rank , comm_size , comm );
20252041 if (cxl_win && cxl_win -> shm ) {
@@ -2099,7 +2115,9 @@ int MPI_Put(const void *origin_addr, int origin_count, MPI_Datatype origin_datat
20992115 }
21002116
21012117put_fallback :
2102- // Fallback to MPI
2118+ // Fallback to MPI - track so flush knows to call orig
2119+ if (cxl_win )
2120+ atomic_fetch_add (& cxl_win -> pending_mpi_rma , 1 );
21032121 return orig_MPI_Put (origin_addr , origin_count , origin_datatype ,
21042122 target_rank , target_disp , target_count ,
21052123 target_datatype , win );
@@ -2157,7 +2175,9 @@ int MPI_Get(void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
21572175 }
21582176
21592177get_fallback :
2160- // Fallback to MPI
2178+ // Fallback to MPI - track so flush knows to call orig
2179+ if (cxl_win )
2180+ atomic_fetch_add (& cxl_win -> pending_mpi_rma , 1 );
21612181 return orig_MPI_Get (origin_addr , origin_count , origin_datatype ,
21622182 target_rank , target_disp , target_count ,
21632183 target_datatype , win );
@@ -2214,7 +2234,9 @@ int MPI_Accumulate(const void *origin_addr, int origin_count, MPI_Datatype origi
22142234 }
22152235
22162236acc_fallback :
2217- // Fallback to MPI
2237+ // Fallback to MPI - track so flush knows to call orig
2238+ if (cxl_win )
2239+ atomic_fetch_add (& cxl_win -> pending_mpi_rma , 1 );
22182240 return orig_MPI_Accumulate (origin_addr , origin_count , origin_datatype ,
22192241 target_rank , target_disp , target_count ,
22202242 target_datatype , op , win );
@@ -2286,6 +2308,7 @@ int MPI_Win_lock(int lock_type, int rank, int assert, MPI_Win win) {
22862308 }
22872309
22882310 __atomic_thread_fence (__ATOMIC_ACQUIRE );
2311+ atomic_store (& cxl_win -> pending_mpi_rma , 0 );
22892312 LOG_DEBUG ("MPI_Win_lock: CXL lock acquired (type=%d, rank=%d)\n" , lock_type , rank );
22902313 }
22912314
@@ -2321,11 +2344,24 @@ int MPI_Win_unlock(int rank, MPI_Win win) {
23212344int MPI_Win_flush (int rank , MPI_Win win ) {
23222345 LOG_DEBUG ("MPI_Win_flush: rank=%d\n" , rank );
23232346
2324- LOAD_ORIGINAL (MPI_Win_flush );
2325-
23262347 // Memory fence for CXL
23272348 __atomic_thread_fence (__ATOMIC_SEQ_CST );
23282349
2350+ // If all RMA ops were handled via CXL direct, skip the MPI-level flush.
2351+ // Calling orig_MPI_Win_flush on windows created via MPI_Win_create with CXL
2352+ // memory can trigger UCX OSC errors and eventually crash the transport.
2353+ cxl_window_t * cxl_win = find_cxl_window (win );
2354+ if (cxl_win && cxl_win -> cxl_enabled ) {
2355+ int pending = atomic_load (& cxl_win -> pending_mpi_rma );
2356+ if (pending == 0 ) {
2357+ LOG_DEBUG ("MPI_Win_flush: CXL-only epoch, skipping MPI flush\n" );
2358+ return MPI_SUCCESS ;
2359+ }
2360+ // Some ops fell back to MPI - must flush those
2361+ atomic_store (& cxl_win -> pending_mpi_rma , 0 );
2362+ }
2363+
2364+ LOAD_ORIGINAL (MPI_Win_flush );
23292365 return orig_MPI_Win_flush (rank , win );
23302366}
23312367
0 commit comments