diff --git a/heat/core/communication.py b/heat/core/communication.py index 7cbef8ff95..57246e110c 100644 --- a/heat/core/communication.py +++ b/heat/core/communication.py @@ -11,7 +11,9 @@ import warnings from mpi4py import MPI -from typing import Any, Callable, Optional, List, Tuple, Union +from abc import ABC, abstractmethod +from typing import Any +from collections.abc import Callable from .stride_tricks import sanitize_axis @@ -32,17 +34,17 @@ class MPIRequest: The buffer to the receive data tensor: torch.Tensor Internal Data - permutation: Tuple[int,...] + permutation: tuple[int,...] Permutation of the tensor axes """ def __init__( self, - handle, - sendbuf: Union[DNDarray, torch.Tensor, Any] = None, - recvbuf: Union[DNDarray, torch.Tensor, Any] = None, + handle: MPI.Request, + sendbuf: Any | None = None, + recvbuf: Any | None = None, tensor: torch.Tensor = None, - permutation: Tuple[int, ...] = None, + permutation: tuple[int, ...] = None, ): self.handle = handle self.tensor = tensor @@ -50,10 +52,12 @@ def __init__( self.sendbuf = sendbuf self.permutation = permutation - def Wait(self, status: MPI.Status = None): + def Wait(self, status: MPI.Status | None = None): """ Waits for an MPI request to complete """ + if self.handle is None: + return self.handle.Wait(status) if self.tensor is not None and isinstance(self.tensor, torch.Tensor): if self.permutation is not None: @@ -73,36 +77,50 @@ def __getattr__(self, name: str) -> Callable: return getattr(self.handle, name) -class Communication: +class Communication(ABC): """ Base class for Communications (inteded for other backends) """ @staticmethod - def is_distributed() -> NotImplementedError: + @abstractmethod + def is_distributed(): """ Whether or not the Communication is distributed """ - raise NotImplementedError() - - def __init__(self) -> NotImplementedError: - raise NotImplementedError() + pass - def chunk(self, shape, split) -> NotImplementedError: + @abstractmethod + def chunk( + self, + shape: 
tuple[int, ...], + split: int, + rank: int = None, + w_size: int = None, + sparse: bool = False, + ) -> tuple[int, tuple[int, ...], tuple[slice, ...]]: """ Calculates the chunk of data that will be assigned to this compute node given a global data shape and a split - axis. Returns ``(offset, local_shape, slices)``: the offset in the split dimension, the resulting local shape if the + axis. + Returns ``(offset, local_shape, slices)``: the offset in the split dimension, the resulting local shape if the global input shape is chunked on the split axis and the chunk slices with respect to the given shape Parameters ---------- - shape : Tuple[int,...] + shape : tuple[int,...] The global shape of the data to be split split : int The axis along which to chunk the data - + rank : int, optional + Process for which the chunking is calculated for, defaults to ``self.rank``. + Intended for creating chunk maps without communication + w_size : int, optional + The MPI world size, defaults to ``self.size``. 
+ Intended for creating chunk maps without communication + sparse : bool, optional + Specifies whether the array is a sparse matrix """ - raise NotImplementedError() + pass class MPICommunication(Communication): @@ -150,11 +168,11 @@ class MPICommunication(Communication): if hasattr(torch, type_str): __mpi_dtype2ctype[getattr(torch, type_str)] = getattr(ctypes, f"c_{type_str}") - def __init__(self, handle=MPI.COMM_WORLD): + def __init__(self, handle: MPI.Intracomm = MPI.COMM_WORLD): self.handle = handle try: - self.rank: Optional[int] = handle.Get_rank() - self.size: Optional[int] = handle.Get_size() + self.rank = handle.Get_rank() + self.size = handle.Get_size() except MPI.Exception: # ranks not within the group will fail with an MPI.Exception, this is expected self.rank = None @@ -178,12 +196,12 @@ def is_distributed(self) -> bool: def chunk( self, - shape: Tuple[int], + shape: tuple[int, ...], split: int, - rank: int = None, - w_size: int = None, + rank: int | None = None, + w_size: int | None = None, sparse: bool = False, - ) -> Tuple[int, Tuple[int], Tuple[slice]]: + ) -> tuple[int, tuple[int, ...], tuple[slice, ...]]: """ Calculates the chunk of data that will be assigned to this compute node given a global data shape and a split axis. @@ -192,7 +210,7 @@ def chunk( Parameters ---------- - shape : Tuple[int,...] + shape : tuple[int,...] The global shape of the data to be split split : int The axis along which to chunk the data @@ -236,15 +254,15 @@ def chunk( ) def counts_displs_shape( - self, shape: Tuple[int], axis: int - ) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]: + self, shape: tuple[int, ...], axis: int + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Calculates the item counts, displacements and output shape for a variable sized all-to-all MPI-call (e.g. ``MPI_Alltoallv``). The passed shape is regularly chunk along the given axis and for all nodes. Parameters ---------- - shape : Tuple[int,...] + shape : tuple[int,...] 
The object for which to calculate the chunking. axis : int The axis along which the chunking is performed. @@ -277,7 +295,7 @@ def mpi_type_of(cls, dtype: torch.dtype) -> MPI.Datatype: return cls.__mpi_type_mappings[dtype] @classmethod - def _handle_large_count(cls, mpi_type: MPI.Datatype, elements: int) -> Tuple[MPI.Datatype, int]: + def _handle_large_count(cls, mpi_type: MPI.Datatype, elements: int) -> tuple[MPI.Datatype, int]: """ Handles large counts for MPI data types by creating vector types to circumvent the MAX_INT limit on certain MPI implementations. @@ -290,7 +308,7 @@ def _handle_large_count(cls, mpi_type: MPI.Datatype, elements: int) -> Tuple[MPI Returns ------- - Tuple[MPI.Datatype, int] + tuple[MPI.Datatype, int] A tuple containing the constructed MPI data type and the count (always 1 in this case) Raises @@ -326,11 +344,11 @@ def _handle_large_count(cls, mpi_type: MPI.Datatype, elements: int) -> Tuple[MPI @classmethod def mpi_type_and_elements_of( cls, - obj: Union[DNDarray, torch.Tensor], - counts: Optional[Tuple[int]], - displs: Tuple[int], - is_contiguous: Optional[bool], - ) -> Tuple[MPI.Datatype, Tuple[int, ...]]: + obj: DNDarray | torch.Tensor, + counts: tuple[int, ...] | None = None, + displs: tuple[int, ...] | None = None, + is_contiguous: bool | None = None, + ) -> tuple[MPI.Datatype, tuple[int, ...]]: """ Determines the MPI data type and number of respective elements for the given tensor (:class:`~heat.core.dndarray.DNDarray` or ``torch.Tensor). In case the tensor is contiguous in memory, a native MPI data type can be used. @@ -340,9 +358,9 @@ def mpi_type_and_elements_of( ---------- obj : DNDarray or torch.Tensor The object for which to construct the MPI data type and number of elements - counts : Tuple[ints,...], optional + counts : tuple[ints,...], optional Optional counts arguments for variable MPI-calls (e.g. 
Alltoallv) - displs : Tuple[ints,...], optional + displs : tuple[ints,...], optional Optional displacements arguments for variable MPI-calls (e.g. Alltoallv) is_contiguous: bool Information on global contiguity of the memory-distributed object. If `None`, it will be set to local contiguity via ``torch.Tensor.is_contiguous()``. @@ -407,10 +425,10 @@ def as_mpi_memory(cls, obj: torch.Tensor) -> MPI.memory: def as_buffer( cls, obj: torch.Tensor, - counts: Optional[Tuple[int]] = None, - displs: Optional[Tuple[int]] = None, - is_contiguous: Optional[bool] = None, - ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: + counts: tuple[int, ...] | None = None, + displs: tuple[int, ...] | None = None, + is_contiguous: bool | None = None, + ) -> list[MPI.memory, tuple[int, int], MPI.Datatype]: """ Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type. @@ -418,9 +436,9 @@ def as_buffer( ---------- obj : torch.Tensor The object to be converted into a buffer representation. - counts : Tuple[int,...], optional + counts : tuple[int,...], optional Optional counts arguments for variable MPI-calls (e.g. Alltoallv) - displs : Tuple[int,...], optional + displs : tuple[int,...], optional Optional displacements arguments for variable MPI-calls (e.g. Alltoallv) is_contiguous: bool, optional Optional information on global contiguity of the memory-distributed object. @@ -456,7 +474,7 @@ def _moveToCompDevice(self, x: torch.Tensor, func: Callable | None) -> torch.Ten def alltoall_sendbuffer( self, obj: torch.Tensor - ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: + ) -> list[MPI.memory, tuple[list[int], list[int]], list[MPI.Datatype]]: """ Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type. XXX: might not work for all MPI stacks. 
Might require multiple type commits or so @@ -521,7 +539,7 @@ def alltoall_sendbuffer( def alltoall_recvbuffer( self, obj: torch.Tensor - ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: + ) -> list[MPI.memory, tuple[list[int], list[int]], list[MPI.Datatype]]: """ Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type. XXX: might not work for all MPI stacks. Might require multiple type commits or so @@ -578,7 +596,7 @@ def Split(self, color: int = 0, key: int = 0) -> MPICommunication: def Irecv( self, - buf: Union[DNDarray, torch.Tensor, Any], + buf: Any, source: int = MPI.ANY_SOURCE, tag: int = MPI.ANY_TAG, ) -> MPIRequest: @@ -587,7 +605,7 @@ def Irecv( Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address where to place the received message source: int, optional Rank of source process, that send the message @@ -606,17 +624,17 @@ def Irecv( def Recv( self, - buf: Union[DNDarray, torch.Tensor, Any], + buf: Any, source: int = MPI.ANY_SOURCE, tag: int = MPI.ANY_TAG, - status: MPI.Status = None, - ): + status: MPI.Status | None = None, + ) -> None: """ Blocking receive Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address where to place the received message source: int, optional Rank of the source process, that send the message @@ -640,8 +658,8 @@ def Recv( Recv.__doc__ = MPI.Comm.Recv.__doc__ def __send_like( - self, func: Callable, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int - ) -> Tuple[Optional[Union[DNDarray, torch.Tensor]]]: + self, func: Callable, buf: Any, dest: int, tag: int + ) -> tuple[MPI.Request | None, torch.Tensor | None]: """ Generic function for sending a message to process with rank "dest" @@ -649,7 +667,7 @@ def __send_like( ---------- func: Callable The respective MPI sending function - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, 
optional Rank of the destination process, that receives the message @@ -666,13 +684,13 @@ def __send_like( return func(self.as_buffer(sbuf), dest, tag), sbuf - def Bsend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0): + def Bsend(self, buf: Any, dest: int, tag: int = 0) -> None: """ Blocking buffered send Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, optional Index of the destination process, that receives the message @@ -683,15 +701,13 @@ def Bsend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 Bsend.__doc__ = MPI.Comm.Bsend.__doc__ - def Ibsend( - self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 - ) -> MPIRequest: + def Ibsend(self, buf: Any, dest: int, tag: int = 0) -> MPIRequest: """ Nonblocking buffered send Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, optional Rank of the destination process, that receives the message @@ -702,15 +718,13 @@ def Ibsend( Ibsend.__doc__ = MPI.Comm.Ibsend.__doc__ - def Irsend( - self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 - ) -> MPIRequest: + def Irsend(self, buf: Any, dest: int, tag: int = 0) -> MPIRequest: """ Nonblocking ready send Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, optional Rank of the destination process, that receives the message @@ -721,13 +735,13 @@ def Irsend( Irsend.__doc__ = MPI.Comm.Irsend.__doc__ - def Isend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0) -> MPIRequest: + def Isend(self, buf: Any, dest: int, tag: int = 0) -> MPIRequest: """ Nonblocking send Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, optional Rank of the destination process, that receives the message 
@@ -738,15 +752,13 @@ def Isend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 Isend.__doc__ = MPI.Comm.Isend.__doc__ - def Issend( - self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 - ) -> MPIRequest: + def Issend(self, buf: Any, dest: int, tag: int = 0) -> MPIRequest: """ Nonblocking synchronous send Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, optional Rank of the destination process, that receives the message @@ -757,13 +769,13 @@ def Issend( Issend.__doc__ = MPI.Comm.Issend.__doc__ - def Rsend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0): + def Rsend(self, buf: Any, dest: int, tag: int = 0) -> None: """ Blocking ready send Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, optional Rank of the destination process, that receives the message @@ -774,13 +786,13 @@ def Rsend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 Rsend.__doc__ = MPI.Comm.Rsend.__doc__ - def Ssend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0): + def Ssend(self, buf: Any, dest: int, tag: int = 0) -> None: """ Blocking synchronous send Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, optional Rank of the destination process, that receives the message @@ -791,13 +803,13 @@ def Ssend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 Ssend.__doc__ = MPI.Comm.Ssend.__doc__ - def Send(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0): + def Send(self, buf: Any, dest: int, tag: int = 0) -> None: """ Blocking send Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be send dest: int, optional Rank of the destination process, that receives the message 
@@ -809,8 +821,8 @@ def Send(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0) Send.__doc__ = MPI.Comm.Send.__doc__ def __broadcast_like( - self, func: Callable, buf: Union[DNDarray, torch.Tensor, Any], root: int - ) -> Tuple[Optional[DNDarray, torch.Tensor]]: + self, func: Callable, buf: Any, root: int + ) -> tuple[MPI.Request | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: """ Generic function for broadcasting a message from the process with rank "root" to all other processes of the communicator @@ -819,11 +831,13 @@ def __broadcast_like( ---------- func: Callable The respective MPI broadcast function - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be broadcasted root: int Rank of the root process, that broadcasts the message """ + if not self.is_distributed(): + return None, None, None, None # unpack the buffer if it is a HeAT tensor if isinstance(buf, DNDarray): buf = buf.larray @@ -835,13 +849,13 @@ def __broadcast_like( return func(self.as_buffer(srbuf), root), srbuf, srbuf, buf - def Bcast(self, buf: Union[DNDarray, torch.Tensor, Any], root: int = 0) -> None: + def Bcast(self, buf: Any, root: int = 0) -> None: """ Blocking Broadcast Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be broadcasted root: int Rank of the root process, that broadcasts the message @@ -853,13 +867,13 @@ def Bcast(self, buf: Union[DNDarray, torch.Tensor, Any], root: int = 0) -> None: Bcast.__doc__ = MPI.Comm.Bcast.__doc__ - def Ibcast(self, buf: Union[DNDarray, torch.Tensor, Any], root: int = 0) -> MPIRequest: + def Ibcast(self, buf: Any, root: int = 0) -> MPIRequest: """ Nonblocking Broadcast Parameters ---------- - buf: Union[DNDarray, torch.Tensor, Any] + buf: Any Buffer address of the message to be broadcasted root: int Rank of the root process, that broadcasts the message @@ -915,8 +929,8 @@ def _minmax_op( self, dtype: torch.dtype, 
total_count: int, - shape: Tuple[int], - stride: Tuple[int], + shape: tuple[int, ...], + stride: tuple[int, ...], offset: int = 0, ) -> Callable[[MPI.memory, MPI.memory, MPI.Datatype], None]: """ @@ -928,10 +942,10 @@ def _minmax_op( torch.dtype of underlying elements total_count: int Number of elements per mins OR per max (so recv buffer has 2*total_count elements) - shape: Tuple[int] + shape: tuple[int] Shape of the packed buffer that the MPI callback will operate on. This describes the logical shape of the concatenated buffer [mins; maxs] - stride: Tuple[int] + stride: tuple[int] Stride (in elements) of the packed buffer's storage, matching the layout offset: int, optional Storage offset (if needed), default 0 @@ -968,12 +982,12 @@ def op(sendbuf: MPI.memory, recvbuf: MPI.memory, datatype): def __reduce_like( self, func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: MPI.Op, *args: Any, **kwargs: Any, - ) -> Tuple[Optional[DNDarray, torch.Tensor]]: + ) -> tuple[MPI.Request | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: """ Generic function for reduction operations. @@ -981,9 +995,10 @@ def __reduce_like( ---------- func: Callable The respective MPI reduction operation - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any + Buffer address of the send message. If MPI.IN_PLACE is set, + recvbuf is also used as send buffer. + recvbuf: Any Buffer address where to store the result of the reduction op: MPI.Op Operation to apply during the reduction. 
@@ -1003,6 +1018,16 @@ def __reduce_like( if isinstance(recvbuf, DNDarray): recvbuf = recvbuf.larray + if not self.is_distributed(): + if sendbuf is not MPI.IN_PLACE: + # set the contiguousness like in multiprocess + sendbuf.set_(sendbuf.contiguous()) + if func in (self.handle.Exscan, self.handle.Iexscan): + recvbuf.set_(recvbuf.contiguous()) + else: + recvbuf.set_(sendbuf.clone()) + return None, None, None, None + # harmonize the input and output buffers # MPI requires send and receive buffers to be of same type and length. If the torch tensors are either not both # contiguous or differently strided, they have to be made matching (if possible) first. @@ -1048,7 +1073,7 @@ def __reduce_like( # Datatype and count shall be derived from the recv buffer, and applied to both, as they should match after the last code block buf = recvbuf rbuf = self._moveToCompDevice(buf, func) - recvbuf: Tuple[MPI.memory, int, MPI.Datatype] = self.as_buffer(rbuf, is_contiguous=True) + recvbuf: tuple[MPI.memory, int, MPI.Datatype] = self.as_buffer(rbuf, is_contiguous=True) if not recvbuf[2].is_predefined: # If using a derived datatype, we need to define the reduce operation to be able to handle the it. 
derived_op = self.__derived_op(rbuf, recvbuf[2], op) @@ -1067,18 +1092,18 @@ def __reduce_like( def Allreduce( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: MPI.Op = MPI.SUM, - ): + ) -> None: """ Combines values from all processes and distributes the result back to all processes Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result of the reduction op: MPI.Op The operation to perform upon reduction @@ -1092,18 +1117,18 @@ def Allreduce( def Exscan( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: MPI.Op = MPI.SUM, - ): + ) -> None: """ Computes the exclusive scan (partial reductions) of data on a collection of processes Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result of the reduction op: MPI.Op The operation to perform upon reduction @@ -1117,8 +1142,8 @@ def Exscan( def Iallreduce( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: MPI.Op = MPI.SUM, ) -> MPIRequest: """ @@ -1126,9 +1151,9 @@ def Iallreduce( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result of the reduction op: MPI.Op The operation to perform upon reduction @@ -1139,8 +1164,8 @@ def Iallreduce( def Iexscan( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: 
MPI.Op = MPI.SUM, ) -> MPIRequest: """ @@ -1148,9 +1173,9 @@ def Iexscan( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result of the reduction op: MPI.Op The operation to perform upon reduction @@ -1161,8 +1186,8 @@ def Iexscan( def Iscan( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: MPI.Op = MPI.SUM, ) -> MPIRequest: """ @@ -1170,9 +1195,9 @@ def Iscan( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result of the reduction op: MPI.Op The operation to perform upon reduction @@ -1183,8 +1208,8 @@ def Iscan( def Ireduce( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: MPI.Op = MPI.SUM, root: int = 0, ) -> MPIRequest: @@ -1193,9 +1218,9 @@ def Ireduce( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result of the reduction op: MPI.Op The operation to perform upon reduction @@ -1208,19 +1233,19 @@ def Ireduce( def Reduce( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: MPI.Op = MPI.SUM, root: int = 0, - ): + ) -> None: """ Reduce values from all processes to a single value on process "root" Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result of the 
reduction op: MPI.Op The operation to perform upon reduction @@ -1236,18 +1261,18 @@ def Reduce( def Scan( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, op: MPI.Op = MPI.SUM, - ): + ) -> None: """ Computes the scan (partial reductions) of data on a collection of processes in a nonblocking way Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result of the reduction op: MPI.Op The operation to perform upon reduction @@ -1262,11 +1287,17 @@ def Scan( def __allgather_like( self, func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, axis: int, **kwargs, - ): + ) -> tuple[ + MPI.Request | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + list | None, + ]: """ Generic function for allgather operations. @@ -1274,9 +1305,9 @@ def __allgather_like( ---------- func: Callable Type of MPI Allgather function (i.e. 
allgather, allgatherv, iallgather) - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result axis: int Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks @@ -1306,6 +1337,18 @@ def __allgather_like( f"recvbuf of type {type(recvbuf)} does not support concatenation axis != 0" ) + if not self.is_distributed(): + if isinstance(recvbuf, torch.Tensor): + if isinstance(sendbuf, np.ndarray): + sendbuf = torch.from_numpy(sendbuf) + recvbuf.copy_(sendbuf) + elif isinstance(recvbuf, np.ndarray): + if isinstance(sendbuf, torch.Tensor): + sendbuf = sendbuf.cpu().numpy() + np.copyto(sendbuf, recvbuf) + + return None, None, None, None, None + # keep a reference to the original buffer object original_recvbuf = recvbuf sbuf_is_contiguous, rbuf_is_contiguous = None, None @@ -1348,18 +1391,18 @@ def __allgather_like( def Allgather( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, recv_axis: int = 0, - ): + ) -> None: """ Gathers data from all tasks and distribute the combined data to all tasks Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result recv_axis: int Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks @@ -1377,18 +1420,18 @@ def Allgather( def Allgatherv( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, recv_axis: int = 0, - ): + ) -> None: """ v-call of Allgather: Each process may contribute a different amount of data. 
Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result recv_axis: int Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks @@ -1406,8 +1449,8 @@ def Allgatherv( def Iallgather( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, recv_axis: int = 0, ) -> MPIRequest: """ @@ -1415,9 +1458,9 @@ def Iallgather( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result recv_axis: int Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks @@ -1430,8 +1473,8 @@ def Iallgather( def Iallgatherv( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, recv_axis: int = 0, ): """ @@ -1439,9 +1482,9 @@ def Iallgatherv( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result recv_axis: int Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks @@ -1455,12 +1498,18 @@ def Iallgatherv( def __alltoall_like( self, func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, send_axis: int, recv_axis: int, **kwargs, - ): + ) -> tuple[ + MPI.Request | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + list[int] | None, + ]: """ Generic function 
for alltoall operations. @@ -1468,9 +1517,9 @@ def __alltoall_like( ---------- func: Callable Specific alltoall function - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result send_axis: int Future split axis, along which data blocks will be created that will be send to individual ranks @@ -1509,6 +1558,24 @@ def __alltoall_like( if not isinstance(recvbuf, torch.Tensor) and send_axis != 0: raise TypeError(f"recvbuf of type {type(recvbuf)} does not support send_axis != 0") + if not self.is_distributed(): + if recv_axis > 1 and send_axis > 1 and recv_axis == send_axis: + raise NotImplementedError( + "AllToAll for same axes not supported. Please choose send_axis and recv_axis to be different." + ) + if sendbuf.shape != recvbuf.shape: + sendbuf = sendbuf.swapaxes(send_axis, recv_axis) + if isinstance(recvbuf, torch.Tensor): + if isinstance(sendbuf, np.ndarray): + sendbuf = torch.from_numpy(sendbuf) + recvbuf.copy_(sendbuf) + elif isinstance(recvbuf, np.ndarray): + if isinstance(sendbuf, torch.Tensor): + sendbuf = sendbuf.cpu().numpy() + np.copyto(sendbuf, recvbuf) + + return None, None, None, None, None + # keep a reference to the original buffer object original_recvbuf = recvbuf @@ -1578,20 +1645,20 @@ def __alltoall_like( def Alltoall( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, send_axis: int = 0, recv_axis: int = None, - ): + ) -> None: """ All processes send data to all processes: The jth block sent from process i is received by process j and is placed in the ith block of recvbuf. 
Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result send_axis: int Future split axis, along which data blocks will be created that will be send to individual ranks @@ -1614,20 +1681,20 @@ def Alltoall( def Alltoallv( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, send_axis: int = 0, recv_axis: int = None, - ): + ) -> None: """ v-call of Alltoall: All processes send different amount of data to, and receive different amount of data from, all processes Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result send_axis: int Future split axis, along which data blocks will be created that will be send to individual ranks @@ -1650,19 +1717,18 @@ def Alltoallv( def Alltoallw( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - ): + sendbuf: Any, + recvbuf: Any, + ) -> None: """ Generalized All-to-All communication allowing different counts, displacements and datatypes for each partner. See MPI standard for more information. Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message. The buffer is expected to be a tuple of the form (buffer, (counts, displacements), subarray_params_list), where subarray_params_list is a list of tuples of the form (lshape, subsizes, substarts). - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result. 
The buffer is expected to be a tuple of the form (buffer, (counts, displacements), subarray_params_list), where subarray_params_list is a list of tuples of the form (lshape, subsizes, substarts). - """ # Unpack sendbuffer information sendbuf_tensor, (send_counts, send_displs), subarray_params_list = sendbuf @@ -1765,9 +1831,9 @@ def Alltoallw( def _create_recursive_vectortype( self, datatype: MPI.Datatype, - tensor_stride: Tuple[int], - subarray_sizes: List[int], - start: List[int], + tensor_stride: tuple[int, ...], + subarray_sizes: list[int], + start: list[int], ) -> MPI.Datatype: """ Create a recursive vector to handle non-contiguous tensor data. The created datatype will be a recursively defined vector datatype that will enable the collection of non-contiguous tensor data in the specified subarray sizes. @@ -1776,11 +1842,11 @@ def _create_recursive_vectortype( ---------- datatype : MPI.Datatype The base datatype to create the recursive vector datatype from. - tensor_stride : Tuple[int] + tensor_stride : tuple[int] A list of tensor strides for each dimension. subarray_sizes : List[int] A list of subarray sizes for each dimension. - start: List[int] + start: list[int] Index of the first element of the subarray in the original array. 
Notes @@ -1848,8 +1914,8 @@ def _create_recursive_vectortype( def Ialltoall( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, send_axis: int = 0, recv_axis: int = None, ) -> MPIRequest: @@ -1858,9 +1924,9 @@ def Ialltoall( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result send_axis: int Future split axis, along which data blocks will be created that will be send to individual ranks @@ -1878,8 +1944,8 @@ def Ialltoall( def Ialltoallv( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, send_axis: int = 0, recv_axis: int = None, ) -> MPIRequest: @@ -1889,9 +1955,9 @@ def Ialltoallv( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result send_axis: int Future split axis, along which data blocks will be created that will be send to individual ranks @@ -1910,14 +1976,20 @@ def Ialltoallv( def __gather_like( self, func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, send_axis: int, recv_axis: int, send_factor: int = 1, recv_factor: int = 1, **kwargs, - ): + ) -> tuple[ + MPI.Request | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + list[int] | None, + ]: """ Generic function for gather operations. 
@@ -1925,9 +1997,9 @@ def __gather_like( ---------- func: Callable Type of MPI Scatter/Gather function - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result send_axis: int The axis along which ``sendbuf`` is packed @@ -1965,6 +2037,19 @@ def __gather_like( if not isinstance(recvbuf, torch.Tensor) and recv_axis != 0: raise TypeError(f"recvbuf of type {type(recvbuf)} does not support recv_axis != 0") + if not self.is_distributed(): + if sendbuf is not MPI.IN_PLACE: + if isinstance(recvbuf, torch.Tensor): + if isinstance(sendbuf, np.ndarray): + sendbuf = torch.from_numpy(sendbuf) + recvbuf.copy_(sendbuf) + elif isinstance(recvbuf, np.ndarray): + if isinstance(sendbuf, torch.Tensor): + sendbuf = sendbuf.numpy() + np.copyto(recvbuf, sendbuf) + + return None, None, None, None, None + # keep a reference to the original buffer object original_recvbuf = recvbuf @@ -2011,20 +2096,20 @@ def __gather_like( def Gather( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, root: int = 0, axis: int = 0, recv_axis: int = None, - ): + ) -> None: """ Gathers together values from a group of processes Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result root: int Rank of receiving process @@ -2046,20 +2131,20 @@ def Gather( def Gatherv( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, root: int = 0, axis: int = 0, recv_axis: int = None, - ): + ) -> None: """ v-call for Gather: All processes send different amount of data Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of
the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result root: int Rank of receiving process @@ -2081,8 +2166,8 @@ def Gatherv( def Igather( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, root: int = 0, axis: int = 0, recv_axis: int = None, @@ -2092,9 +2177,9 @@ def Igather( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result root: int Rank of receiving process @@ -2119,8 +2204,8 @@ def Igather( def Igatherv( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, root: int = 0, axis: int = 0, recv_axis: int = None, @@ -2130,9 +2215,9 @@ def Igatherv( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result root: int Rank of receiving process @@ -2158,14 +2243,20 @@ def Igatherv( def __scatter_like( self, func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, send_axis: int, recv_axis: int, send_factor: int = 1, recv_factor: int = 1, **kwargs, - ): + ) -> tuple[ + MPI.Request | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + list[int] | None, + ]: """ Generic function for scatter operations. 
@@ -2173,9 +2264,9 @@ def __scatter_like( ---------- func: Callable Type of MPI Scatter/Gather function - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result send_axis: int The axis along which ``sendbuf`` is packed @@ -2213,6 +2304,19 @@ def __scatter_like( if not isinstance(recvbuf, torch.Tensor) and recv_axis != 0: raise TypeError(f"recvbuf of type {type(recvbuf)} does not support recv_axis != 0") + if not self.is_distributed(): + if sendbuf is not MPI.IN_PLACE: + if isinstance(recvbuf, torch.Tensor): + if isinstance(sendbuf, np.ndarray): + sendbuf = torch.from_numpy(sendbuf) + recvbuf.copy_(sendbuf) + elif isinstance(recvbuf, np.ndarray): + if isinstance(sendbuf, torch.Tensor): + sendbuf = sendbuf.numpy() + np.copyto(recvbuf, sendbuf) + + return None, None, None, None, None + # keep a reference to the original buffer object original_recvbuf = recvbuf @@ -2262,8 +2366,8 @@ def __scatter_like( def Iscatter( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, root: int = 0, axis: int = 0, recv_axis: int = None, @@ -2273,9 +2377,9 @@ def Iscatter( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result root: int Rank of sending process @@ -2300,8 +2404,8 @@ def Iscatter( def Iscatterv( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, root: int = 0, axis: int = 0, recv_axis: int = None, @@ -2311,9 +2415,9 @@ def Iscatterv( Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any
Buffer address where to store the result root: int Rank of sending process @@ -2338,20 +2442,20 @@ def Iscatterv( def Scatter( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, + recvbuf: Any, root: int = 0, axis: int = 0, recv_axis: int = None, - ): + ) -> None: """ Sends data parts from one process to all other processes in a communicator Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result root: int Rank of sending process @@ -2373,20 +2477,20 @@ def Scatter( def Scatterv( self, - sendbuf: Union[DNDarray, torch.Tensor, Any], + sendbuf: Any, recvbuf: int, root: int = 0, axis: int = 0, recv_axis: int = None, - ): + ) -> None: """ v-call for Scatter: Sends different amounts of data to different processes Parameters ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] + sendbuf: Any Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] + recvbuf: Any Buffer address where to store the result root: int Rank of sending process @@ -2442,7 +2546,7 @@ def get_comm() -> Communication: return __default_comm -def sanitize_comm(comm: Optional[Communication]) -> Communication: +def sanitize_comm(comm: Communication | None) -> Communication: """ Sanitizes a device or device identifier, i.e. checks whether it is already an instance of :class:`heat.core.devices.Device` or a string with known device identifier and maps it to a proper ``Device``. @@ -2465,7 +2569,7 @@ def sanitize_comm(comm: Optional[Communication]) -> Communication: raise TypeError(f"Unknown communication, must be instance of {Communication}") -def use_comm(comm: Communication = None): +def use_comm(comm: Communication | None = None): """ Sets the globally used default communicator. 
diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 16ce355700..722e503999 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -613,7 +613,7 @@ def __cast(self, cast_function) -> Union[float, int]: """ if np.prod(self.shape) == 1: - if self.split is None: + if not self.is_distributed(): return cast_function(self.__array) is_empty = np.prod(self.__array.shape) == 0 @@ -714,7 +714,7 @@ def create_lshape_map(self, force_check: bool = False) -> torch.Tensor: lshape_map = torch.zeros( (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device ) - if not self.is_distributed: + if not self.is_distributed(): lshape_map[:] = torch.tensor(self.gshape, device=self.device.torch_device) return lshape_map if self.is_balanced(force_check=True): @@ -852,7 +852,7 @@ def fill_diagonal(self, value: float) -> DNDarray: if len(self.shape) != 2: raise ValueError("Only 2D tensors supported at the moment") - if self.split is not None and self.comm.is_distributed: + if self.is_distributed(): counts, displ, _ = self.comm.counts_displs_shape(self.shape, self.split) k = min(self.shape[0], self.shape[1]) for p in range(self.comm.size): diff --git a/heat/core/linalg/eigh.py b/heat/core/linalg/eigh.py index 955ab48865..c7e9f8b944 100644 --- a/heat/core/linalg/eigh.py +++ b/heat/core/linalg/eigh.py @@ -140,7 +140,7 @@ def _eigh( """ n = A.shape[0] global_comm = A.comm - nprocs = global_comm.Get_size() + nprocs = global_comm.size rank = global_comm.rank # direct solution in torch if the problem is small enough diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 437abe6871..8d58286c34 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -485,7 +485,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: dev = A.device tdev = dev.torch_device - nprocs = comm.Get_size() + nprocs = comm.size if A.split is None: # A not split if b.split is None: diff --git a/heat/core/linalg/svdtools.py 
b/heat/core/linalg/svdtools.py index 9dbb97d302..1ef35931d0 100644 --- a/heat/core/linalg/svdtools.py +++ b/heat/core/linalg/svdtools.py @@ -297,7 +297,7 @@ def hsvd( transposeflag = True A = A.T - no_procs = A.comm.Get_size() + no_procs = A.comm.size Anorm = vector_norm(A) diff --git a/heat/core/random.py b/heat/core/random.py index 02e8bfad7d..6c75e93e5d 100644 --- a/heat/core/random.py +++ b/heat/core/random.py @@ -118,8 +118,8 @@ def __counter_sequence( # Share this initial local state to update it correctly later tmp_counter = __counter - rank = comm.Get_rank() - size = comm.Get_size() + rank = comm.rank + size = comm.size max_count = 0xFFFFFFFF if dtype == torch.int32 else 0xFFFFFFFFFFFFFFFF # extract the counter state of the random number generator diff --git a/heat/core/statistics.py b/heat/core/statistics.py index c9f49b9936..d0c9d4d66d 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -6,7 +6,7 @@ import torch from typing import Any, Callable, Union, Tuple, List, Optional -from .communication import MPI +from .communication import MPI, MPI_WORLD from . import arithmetics from . import exponential from . import factories @@ -1050,7 +1050,7 @@ def median( axis: Optional[int] = None, keepdims: bool = False, sketched: bool = False, - sketch_size: Optional[float] = 1.0 / MPI.COMM_WORLD.size, + sketch_size: float = 1.0 / MPI_WORLD.size, ) -> DNDarray: """ Compute the median of the data along the specified axis. 
@@ -1084,8 +1084,8 @@ def median( DNDarray.median: Callable[[DNDarray, int, bool, bool, float], DNDarray] = ( - lambda x, axis=None, keepdims=False, sketched=False, sketch_size=1.0 / MPI.COMM_WORLD.size: ( - median(x, axis, keepdims, sketched=sketched, sketch_size=sketch_size) + lambda x, axis=None, keepdims=False, sketched=False, sketch_size=1.0 / MPI_WORLD.size: median( + x, axis, keepdims, sketched=sketched, sketch_size=sketch_size ) ) DNDarray.median.__doc__ = median.__doc__ @@ -1465,7 +1465,7 @@ def percentile( interpolation: str = "linear", keepdims: bool = False, sketched: bool = False, - sketch_size: Optional[float] = 1.0 / MPI.COMM_WORLD.size, + sketch_size: float = 1.0 / MPI_WORLD.size, ) -> DNDarray: r""" Compute the q-th percentile of the data along the specified axis/axes. @@ -1668,7 +1668,7 @@ def _create_sketch( if ( not isinstance(sketch_size, float) or sketch_size <= 0 - or (MPI.COMM_WORLD.size > 1 and sketch_size == 1) + or (MPI_WORLD.size > 1 and sketch_size == 1) or sketch_size > 1 ): raise ValueError(