Improve scheduling of prefetched dot #3972

Merged (1 commit) on Jun 2, 2024

Conversation

bertmaher (Collaborator)

While analyzing the performance of TF32 GEMM on A100, I found a surprising number of stalls on ldmatrix. The TTGIR schedule:

local_load
tt.dot
tt.dot
async_copy, etc...
local_load

suggested that the prefetching wasn't really helping to hide latency.

Instead, we want to delay the second dot until after the second local_load:

local_load
tt.dot
async_copy, etc...
local_load
tt.dot

With this scheduling tweak, TF32 performance is more or less on par with cuBLAS:

matmul-performance-fp32 (TFLOP/s):
         M       N       K      cublas  triton prev    triton now
0    256.0   256.0   256.0    2.502568     3.057073      2.953735
1    384.0   384.0   384.0    7.643507     8.406043      8.759763
2    512.0   512.0   512.0   18.078896    14.413416     15.477137
3    640.0   640.0   640.0   30.284659    23.883382     25.801575
4    768.0   768.0   768.0   39.050414    34.317033     35.345259
5    896.0   896.0   896.0   56.550561    48.973525     49.787038
6   1024.0  1024.0  1024.0   59.440979    50.043897     56.631955
7   1152.0  1152.0  1152.0   75.148633    63.828650     72.995792
8   1280.0  1280.0  1280.0   95.533530    79.776017     88.383007
9   1408.0  1408.0  1408.0   72.298729    67.645144     72.750969
10  1536.0  1536.0  1536.0   87.787755    81.943708     87.280315
11  1664.0  1664.0  1664.0   96.892725    88.768552     96.600192
12  1792.0  1792.0  1792.0  112.464534   104.189331    112.148915
13  1920.0  1920.0  1920.0   88.553303    81.844219     89.061408
14  2048.0  2048.0  2048.0   99.715994    93.279631    100.745155
15  2176.0  2176.0  2176.0  112.031443   101.714853    113.193314
16  2304.0  2304.0  2304.0  124.659473   114.142437    127.253524
17  2432.0  2432.0  2432.0  114.685912   102.933691    112.230552
18  2560.0  2560.0  2560.0  126.976992   114.236412    124.386234
19  2688.0  2688.0  2688.0  108.215904    98.208559    108.758878
20  2816.0  2816.0  2816.0  118.406265   107.830851    119.501210
21  2944.0  2944.0  2944.0  129.108891   117.898422    130.610400
22  3072.0  3072.0  3072.0  125.246376   112.788009    122.885002
23  3200.0  3200.0  3200.0  133.159950   119.121699    129.833900
24  3328.0  3328.0  3328.0  124.957773   110.181092    121.741877
25  3456.0  3456.0  3456.0  134.411287   118.765808    130.945595
26  3584.0  3584.0  3584.0  129.718790   116.785078    127.570664
27  3712.0  3712.0  3712.0  135.099099   121.469581    133.193687
28  3840.0  3840.0  3840.0  123.986410   115.455567    127.355116
29  3968.0  3968.0  3968.0  128.467381   123.054516    135.546859
30  4096.0  4096.0  4096.0  136.270295   122.945194    134.175797

The implementation is ... mildly hacky, so I would definitely like some feedback on it. But I also figure that this pass is limited to A100 (we don't use prefetch on cc > 80), so it has a short shelf life and might not be worth generalizing too far. (I do want it for some critical workloads that will stay important for a while yet, though.)
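
The core of the tweak, roughly (a minimal sketch pieced together from the diff and review comments below; builder, prevDot, and kRem are the prefetcher's existing state, and the exact body of the new if is truncated in the diff shown later):

// After emitting the last decomposed sub-dot (kRem == 0), move the
// insertion point to just before it, so the remaining loop-body ops
// (async_copy and the next local_load) are emitted ahead of the final
// dot instead of after it.
if (kRem == 0)
  builder.setInsertionPoint(prevDot);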

bertmaher requested review from htyu and manman-ren on May 22, 2024 16:58
Jokeren (Contributor) commented May 22, 2024

> But also I figure that this pass is limited to A100 (we don't use prefetch on cc > 80)

@ThomasRaoux Is there any reason why we don't prefetch on cc > 80? I think it would still be beneficial for cc < 90.

Jokeren (Contributor) commented May 22, 2024

Also, if the dot op generated on H100 is not AsyncDot, we are not able to read directly from shared memory, so prefetch still seems to help.

ThomasRaoux (Collaborator)

> But also I figure that this pass is limited to A100 (we don't use prefetch on cc > 80)
>
> @ThomasRaoux Is there any reason why we don't prefetch on cc > 80? I think it would still be beneficial for cc < 90.

For sm90 it is because there is no longer a load from shared memory, and it is harder to break up the matmul. But we should be checking the MMA version instead of the SM version.
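
A hedged sketch of what gating on the MMA version might look like (NvidiaMmaEncodingAttr and getVersionMajor are the TritonGPU names; dotOp and where exactly the check would live in the pass are assumptions):

// Skip prefetching when the dot lowers to MMA v3 (Hopper wgmma),
// which reads its operands directly from shared memory.
// dotOp is the tt.dot under consideration.
auto mma = dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(
    cast<RankedTensorType>(dotOp.getType()).getEncoding());
if (mma && mma.getVersionMajor() >= 3)
  return; // nothing to prefetch for this dot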

bertmaher (Collaborator, Author)

To help with reviewing, here are some IR dumps from a basic matmul example.

And just to be extra clear, here are before and after TTGIR dumps from the end of the pipeline; the IR gets cleaned up a good bit after the prefetch pass, so this makes the end effect easier to see.

Comment on lines 294 to 307
// If we find anything that depends on a dot op, stop inserting before the
// dot, and add to the end of the block instead.
for (auto operand : op.getOperands()) {
  if (auto def = operand.getDefiningOp()) {
    auto dot = dyn_cast<triton::DotOp>(def);
    if (dot && dots.contains(dot)) {
      builder.setInsertionPointToEnd(newForOp.getBody());
    }
  }
}
Collaborator:

Ops could have nested regions with uses of tt.dot?

something like:

%d = tt.dot
scf.if
  %user = user_op %d

manman-ren (Collaborator):

Good catch. I guess we can stop when we hit a region op. So at the end of each dot op, we set the insertion point to the last decomposed small dot. When going through the later ops in the loop, we can advance the insertion point to the end of the new forOp when needed.

htyu (Collaborator) commented May 24, 2024

Yeah, the current implementation may not work well if the dot is in an if-block or there is other control flow. Maybe just bail out when there is any control flow in the loop body?

Also, IIUC, the code following the original dot is cut in half by the dot-dependency op. The first half will be placed before the last dot. This seems a bit arbitrary and I'm not sure it's always beneficial. For example, there could be expensive operations that originally didn't block the dot but now do, such as an async_wait for unrelated loads for other operations. Correct me if this is not valid.

bertmaher (Collaborator, Author):

🤦 Good catch, I forgot entirely about control flow. I've updated the PR to stop sinking the dot if any control flow is encountered. (I guess I could recurse into the control-flow region to see whether the dot is used inside... but I'm not sure it's worth trying to optimize that case.)
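
A plausible sketch of the bail-out, assuming an up-front scan of the loop body (the real code may instead stop mid-walk when it reaches such an op; forOp is the loop being prefetched):

// Bail out of dot-sinking if the loop body contains any op with
// nested regions (scf.if, scf.for, ...), since the dot's result
// could be used inside them.
bool hasControlFlow = llvm::any_of(
    forOp.getBody()->without_terminator(),
    [](Operation &op) { return op.getNumRegions() > 0; });
if (hasControlFlow)
  return; // leave this loop's schedule unchanged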

manman-ren (Collaborator) left a comment:

Unrelated to this PR: I noticed a TODO in this file:
// TODO: segfault (original for still has uses)
// when used in flash attention that has 2 dots in the loop

Is this something worth fixing?

@@ -329,6 +339,11 @@ scf::ForOp Prefetcher::createNewForOp() {
   prevDot = newOp;
   kOff += kShape;
   kRem -= kShape;
+  if (kRem == 0) {
+    // We want to delay issuing the last dot as long as possible, ideally
+    // until after the prefetch.
manman-ren (Collaborator):

Maybe add: "Move the insertion point to before the last decomposed dot so later instructions will be added prior to it."

htyu (Collaborator) commented May 24, 2024

I guess if there are more than two dots, the intermediate dots may still not be overlapped well? Maybe we want to handle that separately.

bertmaher (Collaborator, Author):

@manman-ren thanks, I attempted to clarify the comments to this effect!

Collaborator:

> I guess if there are more than two dots, the intermediate dots may still not be overlapped well? Maybe we want to handle that separately.

If the original dot is decomposed into 3 small dots, we will have:

  prefetch 2nd dot operands
  run 1st dot
  prefetch 3rd dot operands
  run 2nd dot
  ...
  run 3rd dot
  ...
  prefetch 1st dot operands

This is more like a scheduling problem now. May make sense to do it in the scheduler :]

manman-ren (Collaborator)

> But also I figure that this pass is limited to A100 (we don't use prefetch on cc > 80)
>
> @ThomasRaoux Is there any reason why we don't prefetch on cc > 80? I think it would still be beneficial for cc < 90.
>
> For sm90 it is because there is no longer a load from shared memory, and it is harder to break up the matmul. But we should be checking the MMA version instead of the SM version.

So we can always add this pass in make_ttgir, and inside the pass we skip MMA v3? We can fix that.

manman-ren (Collaborator) left a comment:

This looks good to me now. Thanks for adding the comments!

bertmaher marked this pull request as ready for review on May 29, 2024 17:26
bertmaher requested a review from ptillet as a code owner on May 29, 2024 17:26
Jokeren merged commit 384fd6a into triton-lang:main on Jun 2, 2024
5 checks passed