[BACKEND] Add support to convert INT8 MMAV2 accumulator layout to dot_operand layout #3595

Open · wants to merge 11 commits into base: main

Conversation

@tongyuantongyu (Contributor) commented on Apr 7, 2024

Partial fix for #3580. Resolves the INT8 layoutC -> INT8 layoutA case.

  • Port MMAV3's register shuffling to support the MMAV2 layout.
  • Simplify the shuffling logic.

@ThomasRaoux (Collaborator):

I haven't reviewed in detail since it is failing the tests. I'm not sure I understand why the code sequence for mma to dot_operand(fp8) has changed.

> The FP16 case is a bit tough. convert_layout only knows that the tensor is in #mma layout, but has no idea which MMA exactly. #mma layouts of different MMAs are different. Guessing from the element type is not reliable, since the user may (and for INT8/FP8 MMA, has to) cast between types.

I don't understand. Could you give an example of how the MMA format differs based on the type?

Comment on lines -189 to -190:

    // CHECK: prmt.b32
    // CHECK: prmt.b32

Collaborator:

Why are those removed?

Contributor Author:

These two correspond to the prmt instructions using selectorEx0 and selectorEx1. Those two patterns simply select one of the two values, so a select is sufficient and no prmt is required here.
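For reference, a minimal Python sketch (hypothetical, not code from this PR; the concrete selectorEx0/selectorEx1 values are not shown in this thread) of prmt.b32 in its default mode, illustrating why a selector that draws all four bytes from a single source register degenerates into a plain select:

```python
# Hypothetical emulation of PTX prmt.b32 (default mode, ignoring the msb-replication bit).
def prmt_b32(a: int, b: int, selector: int) -> int:
    # The byte pool is a's 4 bytes (indices 0-3) followed by b's 4 bytes (indices 4-7);
    # each selector nibble picks one pool byte for the corresponding output byte.
    pool = [(a >> (8 * i)) & 0xFF for i in range(4)] + [(b >> (8 * i)) & 0xFF for i in range(4)]
    out = 0
    for i in range(4):
        idx = (selector >> (4 * i)) & 0x7
        out |= pool[idx] << (8 * i)
    return out

a, b = 0x11223344, 0xAABBCCDD
assert prmt_b32(a, b, 0x3210) == a  # every byte taken from a
assert prmt_b32(a, b, 0x7654) == b  # every byte taken from b
# When a selector has this degenerate form, the prmt just picks one whole register,
# so a select between the two values is enough and no byte shuffle is needed.
```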

    rewriter.replaceOp(op, result);
    }

    void convert8BitsMMAV2To16BitsDotOperand(
Collaborator:

Let's not add it if it is not used and tested.

Contributor Author:

Done

@tongyuantongyu (Contributor Author):

> I don't understand. Could you give an example of how the MMA format differs based on the type?

You're indeed right. I was confused because there turns out to be another issue here (details reported in #3580 (comment)). Loading INT8 inputs and doing both MMAs in FP16 also gives a wrong result. convert8BitsMMAV2To16BitsDotOperand fixed the wrong order, which made me think the issue was the MMAs having different layouts.
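For context, a hypothetical Triton sketch (not taken from this PR or from #3580; the kernel name, shapes, and pointers are placeholders) of the chained-dot pattern under discussion, where the int32 accumulator of an INT8 tl.dot is cast and reused as a dot operand, which is what triggers the #mma -> #dot_operand layout conversion:

```python
import triton
import triton.language as tl

@triton.jit
def chained_int8_dot(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    a = tl.load(a_ptr + idx)      # int8 tile
    b = tl.load(b_ptr + idx)      # int8 tile
    c = tl.load(c_ptr + idx)      # int8 tile
    acc = tl.dot(a, b)            # INT8 MMA: int32 accumulator in #mma layout
    acc8 = acc.to(tl.int8)        # cast the accumulator back down ...
    out = tl.dot(acc8, c)         # ... and feed it as operand A (#dot_operand layout)
    tl.store(out_ptr + idx, out)  # int32 result
```

Launched with (BLOCK, BLOCK) int8 CUDA tensors (BLOCK at least 16) and an int32 output buffer, this exercises the layoutC -> layoutA path this PR targets; the FP16 variant mentioned above would instead cast the int8 tiles to tl.float16 before both dots.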

@tongyuantongyu (Contributor Author):

    FAILED hopper/test_gemm.py::test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float32-False-3] - AssertionError: Tensor-likes are not close!

    Mismatched elements: 1 / 4096 (0.0%)
    Greatest absolute difference: 2.0 at index (289, 0) (up to 0.001 allowed)
    Greatest relative difference: 2.0 at index (289, 0) (up to 0.01 allowed)

This failure seems flaky. I don't see a torch.manual_seed call in test_gemm.py, so maybe this is numerical instability with specific input values.
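If that is the cause, a minimal sketch of the usual remedy (hypothetical helper, not code from test_gemm.py; shapes and dtypes are placeholders) is to pin the RNG before generating inputs so any failing case is at least reproducible:

```python
import torch

def make_gemm_inputs(M, N, K, dtype=torch.float16, seed=0):
    # Seed once before creating the random operands so every run sees identical data.
    torch.manual_seed(seed)
    a = torch.randn((M, K), device="cuda", dtype=dtype)
    b = torch.randn((K, N), device="cuda", dtype=dtype)
    return a, b
```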

@ThomasRaoux (Collaborator) left a comment:

I see a big performance regression in the fp8 flash attention tutorial with this patch (run python tutorials/06-fused-attention.py on an H100 to see it). I'm not sure where it is coming from. Please fix it, and I can then take a deeper look at the changes.

@@ -3073,7 +3074,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     z_tri = torch.as_strided(z_tri, (M, N), [1, M])

     if out_dtype == 'int8':
-        out_dtype = tl.int8
+        out_dtype = tl.int32
Collaborator:

That doesn't make sense; it will change the behavior of existing tests. If we want to test the i32 dtype, it can be set in the config?

Contributor Author:

The output type for INT8 MMA is always INT32 (https://github.com/openai/triton/blob/main/python/triton/language/semantic.py#L1358), and out_dtype is simply ignored, unlike for FP MMA (https://github.com/openai/triton/blob/main/python/triton/language/semantic.py#L1367).

I changed it to tl.int32 here just for correctness. If you feel it is necessary, I can add a check that out_dtype must be tl.int32 for tl.int8 inputs, in this or a separate PR.
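For illustration, a minimal sketch of what such a check could look like (hypothetical helper, not part of this PR or of semantic.py; it only assumes triton.language dtype equality comparisons):

```python
import triton.language as tl

def check_int8_dot_out_dtype(lhs_dtype, rhs_dtype, out_dtype):
    # int8 x int8 MMA always accumulates in int32; any other out_dtype cannot be honored,
    # so reject it instead of silently ignoring it.
    if lhs_dtype == tl.int8 and rhs_dtype == tl.int8 and out_dtype != tl.int32:
        raise ValueError(f"tl.dot with int8 inputs accumulates in int32; got out_dtype={out_dtype}")

check_int8_dot_out_dtype(tl.int8, tl.int8, tl.int32)      # passes silently
# check_int8_dot_out_dtype(tl.int8, tl.int8, tl.float16)  # would raise ValueError
```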

@tongyuantongyu (Contributor Author):

> I see a big performance regression in the fp8 flash attention tutorial with this patch

Sorry, I don't have access to an H100. I made an attempt to fix it; could you test whether it fixes the regression? If it's still there, I'll try to revert all changes to the MMAV3 part.

@ThomasRaoux (Collaborator):

> > I see a big performance regression in the fp8 flash attention tutorial with this patch
>
> Sorry, I don't have access to an H100. I made an attempt to fix it; could you test whether it fixes the regression? If it's still there, I'll try to revert all changes to the MMAV3 part.

Thanks, that seems to fix it. I'll look more at the code sequence in a little bit to understand it.
