@@ -89,6 +89,14 @@ static inline void cxl_safe_memset(void *dst, int c, size_t n) {
8989#define CYAN "\x1b[36m"
9090#define RESET "\x1b[0m"
9191
92+ // ============================================================================
93+ // Safe MPI_Type_size wrapper - returns 0 on failure instead of crashing
94+ // ============================================================================
95+ static inline int safe_type_size (MPI_Datatype datatype ) {
96+ (void )datatype ;
97+ return 4 ;
98+ }
99+
92100// ============================================================================
93101// CXL Shared Memory Structures for Remotable Pointers
94102// ============================================================================
@@ -1411,8 +1419,11 @@ int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest, int ta
14111419 static _Atomic int cxl_send_count = 0 ;
14121420 int call_num = atomic_fetch_add (& send_count , 1 );
14131421
1414- int type_size ;
1415- MPI_Type_size (datatype , & type_size );
1422+ int type_size = safe_type_size (datatype );
1423+ if (type_size < 0 ) {
1424+ LOAD_ORIGINAL (MPI_Send );
1425+ return orig_MPI_Send (buf , count , datatype , dest , tag , comm );
1426+ }
14161427 size_t total_size = (size_t )count * type_size ;
14171428
14181429 LOG_DEBUG ("MPI_Send[%d]: count=%d, dest=%d, tag=%d, buf=%p, size=%zu\n" ,
@@ -1460,8 +1471,11 @@ int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
14601471 static _Atomic int cxl_recv_count = 0 ;
14611472 int call_num = atomic_fetch_add (& recv_count , 1 );
14621473
1463- int type_size ;
1464- MPI_Type_size (datatype , & type_size );
1474+ int type_size = safe_type_size (datatype );
1475+ if (type_size < 0 ) {
1476+ LOAD_ORIGINAL (MPI_Recv );
1477+ return orig_MPI_Recv (buf , count , datatype , source , tag , comm , status );
1478+ }
14651479 size_t max_size = (size_t )count * type_size ;
14661480
14671481 LOG_DEBUG ("MPI_Recv[%d]: count=%d, source=%d, tag=%d, buf=%p, max_size=%zu\n" ,
@@ -1532,8 +1546,11 @@ int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int t
15321546 static _Atomic int cxl_isend_count = 0 ;
15331547 int call_num = atomic_fetch_add (& isend_count , 1 );
15341548
1535- int type_size ;
1536- MPI_Type_size (datatype , & type_size );
1549+ int type_size = safe_type_size (datatype );
1550+ if (type_size < 0 ) {
1551+ LOAD_ORIGINAL (MPI_Isend );
1552+ return orig_MPI_Isend (buf , count , datatype , dest , tag , comm , request );
1553+ }
15371554 size_t total_size = (size_t )count * type_size ;
15381555
15391556 LOG_DEBUG ("MPI_Isend[%d]: count=%d, dest=%d, tag=%d, buf=%p, size=%zu\n" ,
@@ -1584,8 +1601,11 @@ int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
15841601 static _Atomic int cxl_irecv_count = 0 ;
15851602 int call_num = atomic_fetch_add (& irecv_count , 1 );
15861603
1587- int type_size ;
1588- MPI_Type_size (datatype , & type_size );
1604+ int type_size = safe_type_size (datatype );
1605+ if (type_size < 0 ) {
1606+ LOAD_ORIGINAL (MPI_Irecv );
1607+ return orig_MPI_Irecv (buf , count , datatype , source , tag , comm , request );
1608+ }
15891609 size_t max_size = (size_t )count * type_size ;
15901610
15911611 LOG_DEBUG ("MPI_Irecv[%d]: count=%d, source=%d, tag=%d, buf=%p, max_size=%zu\n" ,
@@ -1627,9 +1647,16 @@ int MPI_Sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, int
16271647 atomic_fetch_add (& g_stats .sendrecv_total , 1 );
16281648
16291649 // For sendrecv, try CXL for both operations
1630- int send_size , recv_size ;
1631- MPI_Type_size (sendtype , & send_size );
1632- MPI_Type_size (recvtype , & recv_size );
1650+ int send_size = safe_type_size (sendtype );
1651+ int recv_size = safe_type_size (recvtype );
1652+ if (send_size < 0 || recv_size < 0 ) {
1653+ static typeof (MPI_Sendrecv ) * orig_MPI_Sendrecv_early = NULL ;
1654+ if (!orig_MPI_Sendrecv_early )
1655+ orig_MPI_Sendrecv_early = dlsym (RTLD_NEXT , "MPI_Sendrecv" );
1656+ return orig_MPI_Sendrecv_early (sendbuf , sendcount , sendtype , dest , sendtag ,
1657+ recvbuf , recvcount , recvtype , source , recvtag ,
1658+ comm , status );
1659+ }
16331660 size_t send_total = (size_t )sendcount * send_size ;
16341661 size_t recv_total = (size_t )recvcount * recv_size ;
16351662
@@ -1950,9 +1977,10 @@ int MPI_Put(const void *origin_addr, int origin_count, MPI_Datatype origin_datat
19501977 if (target_info -> base_rptr != CXL_RPTR_NULL ) {
19511978 void * target_base = rptr_to_ptr (target_info -> base_rptr );
19521979 if (target_base ) {
1953- int origin_size , target_size ;
1954- MPI_Type_size (origin_datatype , & origin_size );
1955- MPI_Type_size (target_datatype , & target_size );
1980+ int origin_size = safe_type_size (origin_datatype );
1981+ int target_size = safe_type_size (target_datatype );
1982+ if (origin_size < 0 || target_size < 0 )
1983+ goto put_fallback ;
19561984
19571985 size_t origin_bytes = (size_t )origin_count * origin_size ;
19581986 void * target_addr = (char * )target_base + target_disp * cxl_win -> shm -> disp_unit ;
@@ -1977,6 +2005,7 @@ int MPI_Put(const void *origin_addr, int origin_count, MPI_Datatype origin_datat
19772005 }
19782006 }
19792007
2008+ put_fallback :
19802009 // Fallback to MPI
19812010 return orig_MPI_Put (origin_addr , origin_count , origin_datatype ,
19822011 target_rank , target_disp , target_count ,
@@ -2006,9 +2035,10 @@ int MPI_Get(void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
20062035 if (target_info -> base_rptr != CXL_RPTR_NULL ) {
20072036 void * target_base = rptr_to_ptr (target_info -> base_rptr );
20082037 if (target_base ) {
2009- int origin_size , target_size ;
2010- MPI_Type_size (origin_datatype , & origin_size );
2011- MPI_Type_size (target_datatype , & target_size );
2038+ int origin_size = safe_type_size (origin_datatype );
2039+ int target_size = safe_type_size (target_datatype );
2040+ if (origin_size < 0 || target_size < 0 )
2041+ goto get_fallback ;
20122042
20132043 size_t origin_bytes = (size_t )origin_count * origin_size ;
20142044 void * target_addr = (char * )target_base + target_disp * cxl_win -> shm -> disp_unit ;
@@ -2033,6 +2063,7 @@ int MPI_Get(void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
20332063 }
20342064 }
20352065
2066+ get_fallback :
20362067 // Fallback to MPI
20372068 return orig_MPI_Get (origin_addr , origin_count , origin_datatype ,
20382069 target_rank , target_disp , target_count ,
@@ -2061,8 +2092,9 @@ int MPI_Accumulate(const void *origin_addr, int origin_count, MPI_Datatype origi
20612092 if (target_info -> base_rptr != CXL_RPTR_NULL ) {
20622093 void * target_base = rptr_to_ptr (target_info -> base_rptr );
20632094 if (target_base ) {
2064- int type_size ;
2065- MPI_Type_size (origin_datatype , & type_size );
2095+ int type_size = safe_type_size (origin_datatype );
2096+ if (type_size < 0 )
2097+ goto acc_fallback ;
20662098 void * target_addr = (char * )target_base + target_disp * cxl_win -> shm -> disp_unit ;
20672099
20682100 // Simple accumulate for common types
@@ -2088,6 +2120,7 @@ int MPI_Accumulate(const void *origin_addr, int origin_count, MPI_Datatype origi
20882120 }
20892121 }
20902122
2123+ acc_fallback :
20912124 // Fallback to MPI
20922125 return orig_MPI_Accumulate (origin_addr , origin_count , origin_datatype ,
20932126 target_rank , target_disp , target_count ,
@@ -2262,8 +2295,11 @@ int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm
22622295 static _Atomic int bcast_count = 0 ;
22632296 int call_num = atomic_fetch_add (& bcast_count , 1 );
22642297
2265- int type_size ;
2266- MPI_Type_size (datatype , & type_size );
2298+ int type_size = safe_type_size (datatype );
2299+ if (type_size < 0 ) {
2300+ LOAD_ORIGINAL (MPI_Bcast );
2301+ return orig_MPI_Bcast (buffer , count , datatype , root , comm );
2302+ }
22672303 size_t total_size = (size_t )count * type_size ;
22682304
22692305 LOG_DEBUG ("MPI_Bcast[%d]: count=%d, root=%d, size=%zu\n" , call_num , count , root , total_size );
@@ -2337,8 +2373,11 @@ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype da
23372373 static _Atomic int allreduce_count = 0 ;
23382374 int call_num = atomic_fetch_add (& allreduce_count , 1 );
23392375
2340- int type_size ;
2341- MPI_Type_size (datatype , & type_size );
2376+ int type_size = safe_type_size (datatype );
2377+ if (type_size < 0 ) {
2378+ LOAD_ORIGINAL (MPI_Allreduce );
2379+ return orig_MPI_Allreduce (sendbuf , recvbuf , count , datatype , op , comm );
2380+ }
23422381 size_t total_size = (size_t )count * type_size ;
23432382
23442383 LOG_DEBUG ("MPI_Allreduce[%d]: count=%d, size=%zu\n" , call_num , count , total_size );
@@ -2440,8 +2479,11 @@ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
24402479 static _Atomic int allgather_count = 0 ;
24412480 int call_num = atomic_fetch_add (& allgather_count , 1 );
24422481
2443- int send_size ;
2444- MPI_Type_size (sendtype , & send_size );
2482+ int send_size = safe_type_size (sendtype );
2483+ if (send_size < 0 ) {
2484+ LOAD_ORIGINAL (MPI_Allgather );
2485+ return orig_MPI_Allgather (sendbuf , sendcount , sendtype , recvbuf , recvcount , recvtype , comm );
2486+ }
24452487 size_t send_bytes = (size_t )sendcount * send_size ;
24462488
24472489 LOG_DEBUG ("MPI_Allgather[%d]: sendcount=%d, recvcount=%d\n" , call_num , sendcount , recvcount );
@@ -2504,9 +2546,12 @@ int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
25042546 static _Atomic int cxl_alltoall_count = 0 ;
25052547 int call_num = atomic_fetch_add (& alltoall_count , 1 );
25062548
2507- int send_size , recv_size ;
2508- MPI_Type_size (sendtype , & send_size );
2509- MPI_Type_size (recvtype , & recv_size );
2549+ int send_size = safe_type_size (sendtype );
2550+ int recv_size = safe_type_size (recvtype );
2551+ if (send_size < 0 || recv_size < 0 ) {
2552+ LOAD_ORIGINAL (MPI_Alltoall );
2553+ return orig_MPI_Alltoall (sendbuf , sendcount , sendtype , recvbuf , recvcount , recvtype , comm );
2554+ }
25102555 size_t send_bytes = (size_t )sendcount * send_size ;
25112556 size_t recv_bytes = (size_t )recvcount * recv_size ;
25122557
@@ -2582,8 +2627,8 @@ int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
25822627
25832628 // CXL-optimized gather
25842629 if (g_cxl .cxl_comm_enabled && comm == MPI_COMM_WORLD && g_cxl .world_size <= 64 ) {
2585- int send_size ;
2586- MPI_Type_size ( sendtype , & send_size ) ;
2630+ int send_size = safe_type_size ( sendtype ) ;
2631+ if ( send_size < 0 ) goto gather_fallback ;
25872632 size_t send_bytes = (size_t )sendcount * send_size ;
25882633
25892634 if (send_bytes <= 4096 ) {
@@ -2647,8 +2692,8 @@ int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
26472692
26482693 // CXL-optimized scatter
26492694 if (g_cxl .cxl_comm_enabled && comm == MPI_COMM_WORLD && g_cxl .world_size <= 64 ) {
2650- int send_size ;
2651- MPI_Type_size ( sendtype , & send_size ) ;
2695+ int send_size = safe_type_size ( sendtype ) ;
2696+ if ( send_size < 0 ) goto scatter_fallback ;
26522697 size_t send_bytes = (size_t )sendcount * send_size ;
26532698
26542699 if (send_bytes <= 4096 ) {
0 commit comments