
TMA support for Pointwise and Normalization Kernels #2190

Open
rdspring1 opened this issue May 2, 2024 · 0 comments

Here are the current TODOs for implementing pointwise and normalization schedulers with TMA.
NOTE: The action items are neither exhaustive nor ordered by priority.

1. Initialize and invalidate mbarriers at the start and end of the kernel, respectively.

  • Currently, we initialize and invalidate an mbarrier for each TMA operation, which adds unnecessary overhead.
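
A sketch of the intended structure, reusing the hypothetical helper names from the snippets below; `mbarrier_init`/`mbarrier_inval` stand in for the PTX `mbarrier.init`/`mbarrier.inval` instructions:

  if (elect_sync()) {
    // Kernel prologue: initialize each mbarrier exactly once.
    mbarrier_init(&smem_barrier, /*arrival_count=*/1);
  }
  __syncthreads();

  // ... kernel body: every TMA operation reuses smem_barrier ...

  __syncthreads();
  if (elect_sync()) {
    // Kernel epilogue: invalidate each mbarrier exactly once.
    mbarrier_inval(&smem_barrier);
  }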

2. Implement mbarrier with parity bit.

  • A single thread can arrive at an mbarrier and set the expected transaction count while the other threads wait at the mbarrier for the memory transaction to complete.
  • Currently, we use the mbarrier token style: all threads that arrive at the mbarrier get a token, and every thread must wait at the mbarrier with its token. This forces all threads to operate in lockstep.
  • Only a single thread launches the TMA operation and sets the expected transaction count.
  • The remaining threads must still arrive at the mbarrier, but with an expected transaction count of 0.

Motivation

  • Required for warp specialization.
  • Simplifies the mbarrier arrive-and-wait pattern.

mbarrier token

__mbarrier_token_t token;

if (elect_sync()) {
  // Initiate TMA bulk tensor copy.
  cp_async_bulk_global_to_shared_tensor_2d(&smem_barrier, ...);
  token = barrier_arrive1_tx(&smem_barrier, expected_transaction_count);
} else {
  // Other threads arrive with arrival count of 1 and expected transaction count of 0.
  token = barrier_arrive1_tx(&smem_barrier, 0);
}

while (!barrier_try_wait_token(&smem_barrier, token)) { }

// compute

mbarrier parity

int parity = 0;
if (elect_sync()) {
  // Initiate TMA bulk tensor copy.
  cp_async_bulk_global_to_shared_tensor_2d(&smem_barrier, ...);
  barrier_arrive1_tx(&smem_barrier, expected_transaction_count);
} 

while (!barrier_try_wait_parity(&smem_barrier, parity)) { }

// compute

// update parity bit
parity ^= 1;
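
When the kernel loops, the same mbarrier is reused across phases and the parity bit alternates each phase; a sketch in the same style, where `num_iterations` is illustrative:

  int parity = 0;
  for (int i = 0; i < num_iterations; ++i) {
    if (elect_sync()) {
      // Initiate the TMA load for this iteration.
      cp_async_bulk_global_to_shared_tensor_2d(&smem_barrier, ...);
      barrier_arrive1_tx(&smem_barrier, expected_transaction_count);
    }
    while (!barrier_try_wait_parity(&smem_barrier, parity)) { }

    // compute on the tile loaded this iteration

    // The mbarrier phase flips each time it completes.
    parity ^= 1;
  }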

3. Pipelining (multiple mbarriers per TensorView)

  • Launch multiple TMA operations simultaneously, then process each stage as it becomes available.

Motivation

  • Overlap data movement with computation

Pseudo-code

 for each stage of producer TV:
   launch TMA operation for stage
 end for

 for each stage of consumer:
   wait for the corresponding TMA stage to become available
   compute with stage
 end for
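
A slightly more concrete sketch of the producer/consumer structure, with one mbarrier per stage and the helper names borrowed from the snippets above (`num_stages` is illustrative):

  // Producer: launch all TMA loads up front, one mbarrier per stage.
  for (int s = 0; s < num_stages; ++s) {
    if (elect_sync()) {
      cp_async_bulk_global_to_shared_tensor_2d(&smem_barrier[s], ...);
      barrier_arrive1_tx(&smem_barrier[s], expected_transaction_count);
    }
  }

  // Consumer: process each stage as its TMA load completes.
  for (int s = 0; s < num_stages; ++s) {
    while (!barrier_try_wait_parity(&smem_barrier[s], /*parity=*/0)) { }
    // compute on stage s while later stages are still in flight
  }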

4. Combining mbarriers (multiple TensorViews per mbarrier)

  • Currently, we create an mbarrier for each TensorView, but TensorViews can share an mbarrier if they synchronize at the same point.
  • Use syncthreads analysis to identify the placement of each mbarrier_wait.
  • Merge mbarriers at the same sync position.
  • Create a single mbarrier and combine the expected transaction counts.

Motivation

  • We can launch independent TMA load operations and wait for all of their results at the same time.
  • Minimizes register pressure caused by mbarrier overhead.

Pseudo-code

  int parity = 0;
  if (elect_sync()) {
      // Two independent TMA loads tracked by one mbarrier.
      cp_async_bulk_global_to_shared_tensor_2d(&mbarrier, ...);
      cp_async_bulk_global_to_shared_tensor_2d(&mbarrier, ...);
      // A single arrive with the combined expected transaction count.
      barrier_arrive1_tx(&mbarrier, num_operands * tma_tile_size);
  }

  while (!barrier_try_wait_parity(&mbarrier, parity)) { }

  // compute

5. Implement the ublkcp TMA operator (1D version of TMA)

  • Does not require a TMA descriptor.
  • No 256-element limit per dimension, so you can request whatever size is necessary.
  • Essentially extended vectorization.
  • Lacks support for bank-conflict avoidance (swizzling), tile striding, etc.
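
A sketch of what the 1D path could look like; `cp_async_bulk_global_to_shared_1d` is a hypothetical wrapper over the PTX `cp.async.bulk` instruction, which takes a raw global pointer and a byte count instead of a TMA descriptor:

  int parity = 0;
  if (elect_sync()) {
    // No TMA descriptor: just a destination, a source pointer, and a size
    // in bytes (cp.async.bulk requires the size to be a multiple of 16).
    cp_async_bulk_global_to_shared_1d(smem_ptr, gmem_ptr, num_bytes, &smem_barrier);
    barrier_arrive1_tx(&smem_barrier, num_bytes);
  }
  while (!barrier_try_wait_parity(&smem_barrier, parity)) { }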
rdspring1 added the TMA label May 2, 2024