-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve scheduling of prefetched dot #3972
Conversation
@ThomasRaoux Is there any reason why we don't prefetch on cc > 80? I think it would still be beneficial for cc < 90 |
Also if the dot op generated on H100 is not AsyncDot, we are not able to read directly from the shared memory. Then prefetch still seems to help |
for sm90 it is because there is no more load from shared memory and it is harder to breakup the matmul. But we should be checking the mma version instead of the sm version |
To help with reviewing, here are some IR dumps from a basic matmul example.
And just to be extra clear, here are before and after TTGIR dumps from the end of the pipeline; the IR gets cleaned up a good bit after the prefetch pass so this makes it a bit easier to see the end effect. |
// If we find anything that depends on a dot op, stop inserting before the | ||
// dot, and add to the end of the block instead. | ||
for (auto operand : op.getOperands()) { | ||
if (auto def = operand.getDefiningOp()) { | ||
auto dot = dyn_cast<triton::DotOp>(def); | ||
if (dot && dots.contains(dot)) { | ||
builder.setInsertionPointToEnd(newForOp.getBody()); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ops could have nested regions with use of tt.dot?
something like:
%d = tt.dot
scf.if
%user = user_op %d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. I guess we can stop when hit a region op. So at end of each dot op, we set the insertion point to the last decomposed small dot. When going through the later ops in the loop, we can advance the insertion point to the end of the new forOp when needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, the current implementation may not work well if the dot is in a if-block or there are other control flows. Maybe just bail out when there is any control flow in the loop body?
Also IIUC, the code following the original dot is cut in half by the dot dependency op. The first half will be placed before the last dot. This seems a bit random and I'm not sure it's always beneficial. For example, there could be expansive operations that originally don't block the dot but now do, such as a async_wait for unrelated loads for other operations. Correct me if this is not valid.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤦 good catch, I forgot entirely about control flow. I've updated the PR to stop sinking the dot if any control flow is encountered. (I guess I could recurse into the control flow region to see if the dot is used inside... but I'm not sure it's worth trying to optimize that case)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated to this this PR, noticed a TODO in this file
// TODO: segfault (original for still has uses)
// when used in flash attention that has 2 dots in the loop
Is this something worth fixing?
// If we find anything that depends on a dot op, stop inserting before the | ||
// dot, and add to the end of the block instead. | ||
for (auto operand : op.getOperands()) { | ||
if (auto def = operand.getDefiningOp()) { | ||
auto dot = dyn_cast<triton::DotOp>(def); | ||
if (dot && dots.contains(dot)) { | ||
builder.setInsertionPointToEnd(newForOp.getBody()); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. I guess we can stop when hit a region op. So at end of each dot op, we set the insertion point to the last decomposed small dot. When going through the later ops in the loop, we can advance the insertion point to the end of the new forOp when needed.
@@ -329,6 +339,11 @@ scf::ForOp Prefetcher::createNewForOp() { | |||
prevDot = newOp; | |||
kOff += kShape; | |||
kRem -= kShape; | |||
if (kRem == 0) { | |||
// We want to delay issuing the last dot as long as possible, ideally | |||
// until after the prefetch. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add: Move the insertion point to before the last decomposed dot so later instructions will be added prior to it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess if there are more than two dots, the intermediate dots may still not be overlapped well? Maybe we want to handle that separately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@manman-ren thanks, I attempted to clarify the comments to this effect!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess if there are more than two dots, the intermediate dots may still not be overlapped well? Maybe we want to handle that separately.
If the original dot is decomposed into 3 small dots, we will have
prefetch 2nd dot operands
run 1st dot
prefetch 3rd dot operands
run 2nd dot
...
run 3rd dot
...
prefetch 1st dot operands
This is more like a scheduling problem now. May make sense to do it in the scheduler :]
So we can always add this pass in make_ttgir and inside the pass we skip mma v3? We can fix that. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me now. Thanks for adding the comments!
While analyzing performance of tf32 gemm on A100, I found a surprising number of stalls on ldmatrix. Looking at the ttgir:
suggested that the prefetching wasn't really helping to hide latency
Instead, we want to delay the second
dot
until after the secondlocal_load
:With this scheduling tweak tf32 perf is more or less on par with cuBLAS:
The implementation is ... mildly hacky, so would definitely like some feedback on it. But also I figure that this pass is limited to A100 (we don't use prefetch on cc > 80), so it has a short shelf life and might not be worth making too general. (I do want it for some critical workloads that are going to be important for a while longer yet though)