Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKEND] Add support to convert INT8 MMAV2 accumulator layout to dot_operand layout #3595

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
auto dstLayout = dstTy.getEncoding();
auto mmaLayout = srcLayout.cast<NvidiaMmaEncodingAttr>();
auto dotOperandLayout = dstLayout.cast<DotOperandEncodingAttr>();
int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
unsigned elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
Expand All @@ -628,12 +628,13 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto mmaLayout = srcLayout.cast<NvidiaMmaEncodingAttr>();
unsigned elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
auto dotOperandLayout = dstLayout.cast<DotOperandEncodingAttr>();
return mmaLayout.getVersionMajor() == 2 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout &&
!srcTy.getElementType().isF32();
(elementTypeSize == 16 || elementTypeSize == 8);
}

namespace {
Expand Down
16 changes: 12 additions & 4 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2968,7 +2968,8 @@ def convert_fp8_to_fp32(x, device, dtype_str):
for col_a in [True, False]
for col_b in [True, False]] + [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32')] +
[(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32')
for float8_type in ["float8e5", "float8e4nv"]])
for float8_type in ["float8e5", "float8e4nv"]] +
[(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', 'int8', 'int8')])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, num_ctas, device):
if is_interpreter():
Expand Down Expand Up @@ -3073,7 +3074,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
z_tri = torch.as_strided(z_tri, (M, N), [1, M])

if out_dtype == 'int8':
out_dtype = tl.int8
out_dtype = tl.int32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that doesn't make sense. That will change behavior of existing tests. If we want to tests i32 dtype it can be set in the config?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Output type for INT8 MMA is always INT32 (https://github.com/openai/triton/blob/main/python/triton/language/semantic.py#L1358), and out_dtype is simply ignored, as opposed to FP MMA (https://github.com/openai/triton/blob/main/python/triton/language/semantic.py#L1367)

I changed it to tl.int32 here just for correctness. If you feel necessary, I can add a check that out_dtype must be tl.int32 for tl.int8 inputs, in this or a separate PR.

elif out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
Expand Down Expand Up @@ -3106,7 +3107,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
assert "bar.sync" not in red_code
# torch result
if in_dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32)
z_ref = np.matmul(x, y, dtype=np.int32)
elif 'float8' in in_dtype:
x = convert_fp8_to_fp32(x, device, in_dtype)
y = convert_fp8_to_fp32(y, device, in_dtype)
Expand All @@ -3125,15 +3126,22 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
compute_dtype = np.float32
if 'float8' in in_dtype:
w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype))
z_ref = np.matmul(z_ref, w)
if 'int8' in in_dtype:
# Truncating int32 to int8
z_ref = z_ref.astype(np.int8)
compute_dtype = np.int32
z_ref = np.matmul(z_ref, w, dtype=compute_dtype)
# compare
if in_dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
elif out_dtype == tl.float16 or in_dtype == 'bfloat16':
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
elif out_dtype == tl.int32:
np.testing.assert_equal(z_ref, to_numpy(z_tri))
else:
# added atol, to loose precision for float16xfloat16->float32 case
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
Expand Down
32 changes: 32 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1577,3 +1577,35 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
tt.return
}
}

// -----

#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: cvt_mma_to_dot_int8
// CHECK: nvvm.shfl.sync
// CHECK: nvvm.shfl.sync
// CHECK: prmt.b32
// CHECK: prmt.b32
tt.func @cvt_mma_to_dot_int8(%a: tensor<128x64xi8, #mma>) {
%opA = triton_gpu.convert_layout %a : tensor<128x64xi8, #mma> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
tt.return
}
}

// -----

#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.compute-capability" = 89 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: cvt_mma_to_dot_fp8
// CHECK: nvvm.shfl.sync
// CHECK: nvvm.shfl.sync
// CHECK: prmt.b32
// CHECK: prmt.b32
tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) {
%opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
tt.return
}
}
2 changes: 0 additions & 2 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: cvt_mma_to_dot_fp8
// CHECK: prmt.b32
// CHECK: prmt.b32
Comment on lines -189 to -190
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are those removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two corresponding to the prmt-s using selectorEx0 and selectorEx1. These two patterns are simply selecting one Value of the two, so a select is sufficient and no prmt is required here.

// CHECK: nvvm.shfl.sync
// CHECK: nvvm.shfl.sync
// CHECK: prmt.b32
Expand Down
Loading
Loading