resplit with Custom MPI Datatypes and AlltoAllW #1493
Conversation
…ne test case, thanks to recursive vector datatypes
Some results from HoreKa using the following code. With the old resplit, doubling the size of the array overflowed the GPU memory (H100 with 94 GB), and the run on 4 nodes (16 GPUs) hit the 1-hour walltime.

```python
from mpi4py import MPI
import heat as ht
import argparse
import perun
import torch

from heat.core.communication import CUDA_AWARE_MPI

print(f"CUDA_AWARE_MPI: {CUDA_AWARE_MPI}")


@perun.monitor()
def cpu_contiguous(a):
    a = a.resplit(4)
    a.resplit_(3)


@perun.monitor()
def cpu_noncontiguous(a):
    a = a.resplit(0)
    a.resplit_(2)


@perun.monitor()
def gpu_contiguous(a):
    a = a.resplit(4)
    a.resplit_(3)


@perun.monitor()
def gpu_noncontiguous(a):
    a = a.resplit(0)
    a.resplit_(2)


if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # Contiguous data creation
    shape = [100, 50, 50, 20, 250]
    n_elements = ht.array(shape).prod().item()
    Mem = n_elements * 8 / 1e9
    base_array = torch.arange(0, n_elements, dtype=torch.float64).reshape(shape) * (rank + 1)
    print(f"Rank {rank} - Local Shape: {shape} - Memory: {Mem * size} GB - Per rank: {Mem} GB")

    # CPU contiguous data
    print("CPU contiguous data")
    a = ht.array(base_array, dtype=ht.float64, is_split=0, copy=True)
    print(f"Rank {rank} - Shape: {a.shape} - Split: {a.split} - Lshape: {a.lshape} - Device: {a.device}")
    cpu_contiguous(a)
    del a

    # CPU non-contiguous data
    print("CPU non-contiguous data")
    a = ht.array(base_array, dtype=ht.float64, is_split=1, copy=True).transpose((1, 0, 4, 3, 2))
    print(f"Rank {rank} - Shape: {a.shape} - Split: {a.split} - Lshape: {a.lshape} - Device: {a.device}")
    cpu_noncontiguous(a)
    del a

    # GPU contiguous data
    print("GPU contiguous data")
    a = ht.array(base_array, dtype=ht.float64, device="cuda", is_split=0, copy=True)
    print(f"Rank {rank} - Shape: {a.shape} - Split: {a.split} - Lshape: {a.lshape} - Device: {a.device}")
    gpu_contiguous(a)
    del a

    # GPU non-contiguous data
    print("GPU non-contiguous data")
    a = ht.array(base_array, dtype=ht.float64, device="cuda", is_split=1, copy=True).transpose((1, 0, 4, 3, 2))
    print(f"Rank {rank} - Shape: {a.shape} - Split: {a.split} - Lshape: {a.lshape} - Device: {a.device}")
    gpu_noncontiguous(a)
    del a

    torch.cuda.empty_cache()
```

[Plots: Runtime, Memory, GPU Memory]
And the runtime axis is logarithmic? 😁 Fantastic @JuanPedroGHM !
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@ Coverage Diff @@
##             main    #1493      +/-   ##
==========================================
- Coverage   91.80%   91.78%   -0.02%
==========================================
  Files          80       80
  Lines       11772    11810      +38
==========================================
+ Hits        10807    10840      +33
- Misses        965      970       +5
```
@JuanPedroGHM I will review next week, but in the meantime - was `resplit` ever tested on column-major arrays? To me it looks like we forgot that one. If so, would you add a test for DNDarrays with order="F"? There are some examples in `test_dndarray.test_stride_and_strides`.

Thanks again for the fantastic job!
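For illustration, a minimal sketch of such a test (hypothetical name and shapes, assuming `ht.array` accepts `order="F"` as exercised in `test_stride_and_strides`):

```python
import numpy as np
import heat as ht


def test_resplit_order_F():
    # column-major input, initially split along axis 0
    data = np.arange(24.0).reshape(4, 6)
    a = ht.array(data, order="F", split=0)

    # resplit to axis 1 and check the content survives
    b = ht.resplit(a, 1)
    assert b.split == 1
    assert (b.numpy() == data).all()
```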
@JuanPedroGHM thank you so much for this. I have a few comments, mostly from the point of view of maintainability. Great job!
heat/core/communication.py (outdated)

```
sendbuf: Union[DNDarray, torch.Tensor, Any]
    Buffer address of the send message
recvbuf: Union[DNDarray, torch.Tensor, Any]
    Buffer address where to store the result
```
Can we give more details on how `sendbuf` and `recvbuf` should be constructed? Related: #1072 (which we haven't addressed so far)
That is definitely a big problem in the whole communication.py file. I guess it was made like this originally to support `_alltoall_like` and similar methods that need a flexible function signature.

For me, what we need is to define a more consistent communication interface, for example, one where all the communication buffers are a tuple of a torch tensor and a collection of views of that buffer that we use to define the datatypes. What do you think?

For now, I'll expand on the actual contents of the buffer in the docstring.
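For reference, a sketch of the buffer-spec shapes mpi4py itself accepts for collectives (an illustration of mpi4py conventions, not Heat's internal format):

```python
from mpi4py import MPI
import numpy as np

buf = np.zeros(8, dtype="d")

spec_a = buf                # bare buffer: count and datatype inferred
spec_b = [buf, MPI.DOUBLE]  # [buffer, datatype]
spec_c = [buf, 8, MPI.DOUBLE]  # [buffer, count, datatype]
# vector variant used by Alltoallv/Alltoallw:
# [buffer, counts, displacements, datatype(s)]
spec_d = [buf, (4, 4), (0, 4), MPI.DOUBLE]
```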
heat/core/communication.py (outdated)

```
Notes
-----
This function creates a recursive vector datatype by defining vectors out of the previous datatype with specified strides and sizes. The extent of the new datatype is set to the extent of the basic datatype to allow interleaving of data.
```
stupid question for sure, but what is the "extent of a datatype"?
It is a weird name for the span, in bytes, that the datatype covers. Every MPI datatype has a symbolic extent, which we use to read non-contiguous data in the right order, and a true extent (the bytes it actually occupies). More here: https://enccs.github.io/intermediate-mpi/derived-datatypes-pt2/
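To make this concrete, a small mpi4py illustration (not code from this PR): a strided vector type spans its whole memory footprint by default, so `Create_resized` is needed before consecutive instances of it can interleave.

```python
from mpi4py import MPI

base = MPI.DOUBLE

# 3 blocks of 1 double, placed 4 doubles apart
vec = base.Create_vector(3, 1, 4)
lb, extent = vec.Get_extent()
# default extent spans the full footprint: ((3 - 1) * 4 + 1) * 8 = 72 bytes
print(lb, extent)

# shrink the extent to a single double so repeated uses of the type
# interleave instead of stacking end to end
interleaved = vec.Create_resized(0, base.Get_extent()[1])
interleaved.Commit()
```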
```python
recv_buffer = torch.empty(
    tuple(new_lshape), dtype=self.dtype.torch_type(), device=self.device.torch_device
)
```
Is this still in-place? 😬 Not sure this function ever was in place, really. Expand docs?

What's the difference in memory usage between `array.resplit_(newaxis)` and `array = ht.resplit(array, newaxis)`?
Well, it is the same kind of in-place resplit as before, which is not really in place with respect to the `larray`. The DNDarray is reused and all the other properties are rewritten, but a new `larray` is created.

We could keep the same `larray` if the incoming and outgoing data have the same length, but it would have to be written in a non-contiguous way, which is not impossible, but not ideal in my opinion.
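A rough, empirical way to compare the two (a sketch, assuming a CUDA-enabled Heat build; shapes are arbitrary):

```python
import torch
import heat as ht

a = ht.random.randn(2000, 2000, split=0, device="gpu")
torch.cuda.reset_peak_memory_stats()
a.resplit_(1)  # rewrites the existing DNDarray, allocates a new larray
peak_inplace = torch.cuda.max_memory_allocated()

b = ht.random.randn(2000, 2000, split=0, device="gpu")
torch.cuda.reset_peak_memory_stats()
b = ht.resplit(b, 1)  # new DNDarray; the old one is freed on rebinding
peak_outofplace = torch.cuda.max_memory_allocated()

print(f"in-place: {peak_inplace / 1e9:.2f} GB, out-of-place: {peak_outofplace / 1e9:.2f} GB")
```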
```
@@ -3537,34 +3537,15 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
        gathered, is_split=axis, device=arr.device, comm=arr.comm, dtype=arr.dtype
    )
    return new_arr
    arr_tiles = tiling.SplitTiles(arr)
```
Should we update the docstring here, and mention the Dalcin paper?
Do you mean the FFT paper? I would reference it inside `Alltoallw`, as much of the relevant code landed there, or in `tiling.py`.
Due Diligence
Description
Rewrote most of `resplit` to use `Alltoallw` and custom datatypes.
Issue/s resolved: #
Changes proposed:
- `Alltoallw` operation that mimics the MPI `Alltoallw` interface.
- `mpi_type_of` class method for ease of use.
- `_create_recursive_vector` to handle subarray datatype creation for non-contiguous send buffers.
- `_axis2axis` method to handle all non-trivial resplits.
- `get_subarray_params` method to calculate MPI subarray type parameters (see the sketch after this list).
Type of change
Memory requirements
Coming soon...
Performance
Coming soon...
Does this change modify the behaviour of other functions? If so, which?
Yes, probably most high-level ones.