resplit with Custom MPI Datatypes and AlltoAllW #1493

Open

wants to merge 22 commits into main
Conversation

@JuanPedroGHM (Member) commented May 23, 2024

Due Diligence

  • General:
  • Implementation:
    • unit tests: all split configurations tested
    • unit tests: multiple dtypes tested
    • documentation updated where needed

Description

Rewrote most of resplit to use Alltoallw and custom MPI datatypes.

Issue/s resolved: #

Changes proposed:

  • MPICommunicator
    • Alltoallw operation that mimics the MPI Alltoallw interface (see the sketch after this list).
    • mpi_type_of class method for ease of use
    • _create_recursive_vector to handle subarray datatype creation for non-contiguous send buffers.
  • Manipulations
    • _axis2axis method to handle all non-trivial resplits
  • Tiling
    • get_subarray_params method to calculate MPI subarray type params
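
For readers new to the pattern, here is a minimal, self-contained sketch of the core idea (illustrative only, not Heat's actual implementation; the shapes and variable names are made up): redistributing a 2D float64 array from a row split to a column split with a single Alltoallw call, where MPI subarray datatypes describe the non-contiguous pieces so that no manual packing loops are needed.

# Illustrative sketch (not Heat's code): row split -> column split of a 2D
# float64 array with one Alltoallw call and MPI subarray datatypes.
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

gshape = (size * 4, size * 6)                 # global shape, divisible for simplicity
lrows, lcols = gshape[0] // size, gshape[1] // size

# local chunk of the row-split (split=0) array
full = np.arange(gshape[0] * gshape[1], dtype=np.float64).reshape(gshape)
send = full[rank * lrows:(rank + 1) * lrows, :].copy()
recv = np.empty((gshape[0], lcols), dtype=np.float64)   # column-split (split=1) result

send_types, recv_types = [], []
for p in range(size):
    # the piece of my row block that belongs to rank p's column block
    st = MPI.DOUBLE.Create_subarray(send.shape, (lrows, lcols), (0, p * lcols))
    st.Commit()
    send_types.append(st)
    # where rank p's contribution lands inside my column block
    rt = MPI.DOUBLE.Create_subarray(recv.shape, (lrows, lcols), (p * lrows, 0))
    rt.Commit()
    recv_types.append(rt)

counts = [1] * size
displs = [0] * size        # offsets are already encoded in the subarray types
comm.Alltoallw([send, counts, displs, send_types],
               [recv, counts, displs, recv_types])

for t in send_types + recv_types:
    t.Free()

assert np.array_equal(recv, full[:, rank * lcols:(rank + 1) * lcols])

In the PR itself, the subarray parameters come from the tiling layer (get_subarray_params) and the datatype handling lives in MPICommunicator; the sketch only illustrates the communication pattern.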

Type of change

  • Breaking change (fix or feature that would cause existing functionality to not work as expected)

Memory requirements

Coming soon...

Performance

Coming soon...

Does this change modify the behaviour of other functions? If so, which?

Yes.

Probably most high-level ones.

@JuanPedroGHM JuanPedroGHM self-assigned this May 23, 2024
@ClaudiaComito ClaudiaComito changed the title from "Reshape with Custom MPI Datatypes and AlltoAllW" to "resplit with Custom MPI Datatypes and AlltoAllW" May 23, 2024
@JuanPedroGHM JuanPedroGHM marked this pull request as ready for review June 5, 2024 19:55
@JuanPedroGHM JuanPedroGHM added the MPI and communication labels Jun 5, 2024
@JuanPedroGHM (Member, Author) commented Jun 6, 2024

Some results from Horeka using the code below. With the old resplit, doubling the size of the array overflowed the GPU memory (H100 with 94 GB of memory), and when trying on 4 nodes (16 GPUs) it hit the 1-hour walltime.

from mpi4py import MPI
import heat as ht
import argparse
import perun
import torch

from heat.core.communication import CUDA_AWARE_MPI
print(f"CUDA_AWARE_MPI: {CUDA_AWARE_MPI}")

@perun.monitor()
def cpu_contiguous(a):
    a = a.resplit(4)
    a.resplit_(3)


@perun.monitor()
def cpu_noncontiguous(a):
    a = a.resplit(0)
    a.resplit_(2)


@perun.monitor()
def gpu_contiguous(a):
    a = a.resplit(4)
    a.resplit_(3)

@perun.monitor()
def gpu_noncontiguous(a):
    a = a.resplit(0)
    a.resplit_(2)

if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    
    # Contiguous data creation
    shape = [100, 50, 50, 20, 250]
    n_elements = ht.array(shape).prod().item()
    Mem = n_elements * 8 / 1e9
    base_array = torch.arange(0, n_elements, dtype=torch.float64).reshape(shape) * (rank+1)
    
    print(f"Rank {rank} - Local Shape: {shape} - Memory: {Mem * size} GB - Per rank: {Mem} GB")
    
    # CPU contiguous data
    print("CPU contiguous data")
    a = ht.array(base_array, dtype=ht.float64, is_split=0, copy=True)
    print(f"Rank {rank} - Shape: {a.shape} - Split: {a.split} - Lshape: {a.lshape} - Device: {a.device}")
    cpu_contiguous(a)
    del a
    
    # CPU non-contiguous data
    print("CPU non-contiguous data")
    a = ht.array(base_array, dtype=ht.float64, is_split=1, copy=True).transpose((1,0,4,3,2))
    print(f"Rank {rank} - Shape: {a.shape} - Split: {a.split} - Lshape: {a.lshape} - Device: {a.device}")
    cpu_noncontiguous(a)
    del a

    # GPU contiguous data
    print("GPU contiguous data")
    a = ht.array(base_array, dtype=ht.float64, device="cuda", is_split=0, copy=True)
    print(f"Rank {rank} - Shape: {a.shape} - Split: {a.split} - Lshape: {a.lshape} - Device: {a.device}")
    gpu_contiguous(a)
    del a
    
    # GPU non-contiguous data
    print("GPU non-contiguous data")
    a = ht.array(base_array, dtype=ht.float64, device="cuda", is_split=1, copy=True).transpose((1,0,4,3,2))
    print(f"Rank {rank} - Shape: {a.shape} - Split: {a.split} - Lshape: {a.lshape} - Device: {a.device}")
    gpu_noncontiguous(a)
    del a

    torch.cuda.empty_cache()
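
For context, this runs with one MPI rank per GPU. On a SLURM system like Horeka the launch would look roughly like the line below (the node/task counts and the script name are illustrative, not the exact job configuration used for these measurements):

srun -N 1 --ntasks-per-node=4 --gpus-per-node=4 python resplit_benchmark.py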

Runtime, Memory, and GPU Memory plots attached as images (not reproduced here).

@ClaudiaComito (Contributor)

And the runtime axis is logarithmic? 😁 Fantastic @JuanPedroGHM !


codecov bot commented Jun 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.78%. Comparing base (6f5fa1f) to head (48939e1).

Current head 48939e1 differs from the pull request's most recent head c083678.

Please upload reports for the commit c083678 to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1493      +/-   ##
==========================================
- Coverage   91.80%   91.78%   -0.02%     
==========================================
  Files          80       80              
  Lines       11772    11810      +38     
==========================================
+ Hits        10807    10840      +33     
- Misses        965      970       +5     
Flag: unit, Coverage: 91.78% <100.00%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown.


@ClaudiaComito (Contributor) left a comment

@JuanPedroGHM I will review next week, but in the meantime - was resplit ever tested on column-major arrays? To me it looks like we forgot that one. If so, would you add a test for DNDarrays with order="F"? There are some examples in test_dndarray.test_stride_and_strides.

Thanks again for the fantastic job!
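
For what it's worth, a rough sketch of what such a test could look like (hypothetical; it assumes the order="F" keyword of ht.array used in test_stride_and_strides and would sit inside the existing resplit test class — names and shapes are illustrative):

import numpy as np
import heat as ht

def test_resplit_column_major(self):
    # hypothetical sketch: resplit a column-major (order="F") DNDarray across
    # several axes and compare against a freshly split reference
    data = np.arange(4 * 5 * 6, dtype=np.float64).reshape(4, 5, 6)
    a = ht.array(data, split=0, order="F")
    for new_split in (1, 2, None):
        b = ht.resplit(a, new_split)
        reference = ht.array(data, split=new_split, order="F")
        self.assertEqual(b.split, new_split)
        self.assertTrue(ht.equal(b, reference))
        # optionally, if preserving the memory layout is part of the contract:
        # self.assertEqual(b.larray.stride(), reference.larray.stride())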

@ClaudiaComito (Contributor) left a comment

@JuanPedroGHM thank you so much for this. I have a few comments, mostly from the point of view of maintainability. Great job!

heat/core/communication.py (3 outdated review threads, resolved)
Comment on lines 1472 to 1475
sendbuf: Union[DNDarray, torch.Tensor, Any]
Buffer address of the send message
recvbuf: Union[DNDarray, torch.Tensor, Any]
Buffer address where to store the result
Contributor:

Can we give more details on how sendbuf and recvbuf should be constructed? Related: #1072 (which we haven't addressed so far)

Member Author:

That is definitely a big problem throughout the communication.py file. I guess it was made like this originally to support _alltoall_like and similar methods that need a flexible function signature.

For me, what we need is a more consistent communication interface, for example one where every communication buffer is a tuple of a torch tensor and a collection of views into that buffer that we use to define the datatypes (rough sketch below). What do you think?

For now, I'll expand on the actual contents of the buffer in the docstring.
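
Purely as a strawman for that discussion (hypothetical, not existing Heat code; CommBuffer and its fields are made-up names), such an interface could look roughly like this:

# Hypothetical strawman (not Heat's API): bundle the backing torch tensor
# with the per-rank datatypes/counts/displacements that describe it.
from dataclasses import dataclass
from typing import List
import torch
from mpi4py import MPI

@dataclass
class CommBuffer:
    tensor: torch.Tensor            # contiguous backing storage
    datatypes: List[MPI.Datatype]   # one derived datatype per rank
    counts: List[int]               # usually all 1 when the types encode the layout
    displacements: List[int]        # byte offsets into the storage

    def as_msg(self):
        # the [buf, counts, displs, types] form expected by mpi4py's Alltoallw
        nbytes = self.tensor.element_size() * self.tensor.nelement()
        raw = MPI.memory.fromaddress(self.tensor.data_ptr(), nbytes)
        return [raw, self.counts, self.displacements, self.datatypes]

Something along these lines would make the expected buffer contents explicit in the signature instead of having to document them per method.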


Notes
-----
This function creates a recursive vector datatype by defining vectors out of the previous datatype with specified strides and sizes. The extent of the new datatype is set to the extent of the basic datatype to allow interweaving of data.
Contributor:

stupid question for sure, but what is the "extent of a datatype"?

Member Author:

It is a weird name for the length, i.e. how many bytes, the datatype spans. Every MPI datatype has a symbolic extent, which we use to read non-contiguous data in the right order, and a true extent. More here: https://enccs.github.io/intermediate-mpi/derived-datatypes-pt2/
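
A tiny mpi4py illustration of the difference (standalone example, not taken from the PR): a strided vector of 3 doubles naturally spans 5 doubles, but resizing its extent to a single double lets consecutive instances of the type interleave.

from mpi4py import MPI

basic = MPI.DOUBLE
lb, ext = basic.Get_extent()            # ext == 8 bytes for a double

# every other element of a row: 3 doubles with stride 2
vec = basic.Create_vector(3, 1, 2)      # count=3, blocklength=1, stride=2
print(vec.Get_extent())                 # (0, 40): the type spans 5 doubles
print(vec.Get_true_extent())            # (0, 40): bytes actually touched

# shrink only the *extent* to one double, so the "next" instance of the type
# starts a single element further along and the pieces interleave
vec_interleaved = vec.Create_resized(0, ext)
print(vec_interleaved.Get_extent())     # (0, 8)

vec.Free()
vec_interleaved.Free()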

heat/core/tests/test_dndarray.py (outdated review thread, resolved)
heat/core/tiling.py (outdated review thread, resolved)
heat/core/communication.py (outdated review thread, resolved)
Comment on lines +1481 to 1483
recv_buffer = torch.empty(
tuple(new_lshape), dtype=self.dtype.torch_type(), device=self.device.torch_device
)
Contributor:

Is this still in-place? 😬 Not sure this function ever was in place, really. Expand docs?

What's the difference in memory usage between array.resplit_(newaxis) and array = ht.resplit(array, newaxis)?

Member Author:

Well, it is the same kind of in-place resplit as before, which is not really in place with respect to the larray: the DNDarray object is reused and all its other properties are rewritten, but a new larray is created.

We could keep the same larray if the incoming and outgoing data have the same length, but it would have to be written in a non-contiguous way, which is not impossible, but not ideal in my opinion.

@@ -3537,34 +3537,15 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
gathered, is_split=axis, device=arr.device, comm=arr.comm, dtype=arr.dtype
)
return new_arr
arr_tiles = tiling.SplitTiles(arr)

Contributor:

Should we update the docstring here, and mention the Dalcin paper?

Member Author:

Do you mean the FFT paper? I would reference it inside Alltoallw, since much of the relevant code landed there, or in tiling.py.

Labels: benchmark PR, communication, MPI (anything related to MPI communication)
Projects: none yet
Linked issues: none yet
2 participants