diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index fe99e71950cc..795396109992 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -612,7 +612,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   auto dstLayout = dstTy.getEncoding();
   auto mmaLayout = srcLayout.cast<NvidiaMmaEncodingAttr>();
   auto dotOperandLayout = dstLayout.cast<DotOperandEncodingAttr>();
-  int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
+  unsigned elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
   auto ans = mmaLayout.getVersionMajor() == 3 &&
              dotOperandLayout.getOpIdx() == 0 &&
              isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
@@ -628,12 +628,13 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto srcLayout = srcTy.getEncoding();
   auto dstLayout = dstTy.getEncoding();
   auto mmaLayout = srcLayout.cast<NvidiaMmaEncodingAttr>();
+  unsigned elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
   auto dotOperandLayout = dstLayout.cast<DotOperandEncodingAttr>();
   return mmaLayout.getVersionMajor() == 2 &&
          mmaLayout.getWarpsPerCTA()[1] == 1 &&
          dotOperandLayout.getOpIdx() == 0 &&
          dotOperandLayout.getParent() == mmaLayout &&
-         !srcTy.getElementType().isF32();
+         (elementTypeSize == 16 || elementTypeSize == 8);
 }
 
 namespace {
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 4f6f158e2c7b..26fe2d126ed9 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -2968,7 +2968,8 @@ def convert_fp8_to_fp32(x, device, dtype_str):
                            for col_a in [True, False]
                            for col_b in [True, False]] +
                           [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32')] +
                           [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32')
-                           for float8_type in ["float8e5", "float8e4nv"]])
+                           for float8_type in ["float8e5", "float8e4nv"]] +
+                          [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', 'int8', 'int8')])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, num_ctas, device):
     if is_interpreter():
@@ -3073,7 +3074,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
         z_tri = torch.as_strided(z_tri, (M, N), [1, M])
 
     if out_dtype == 'int8':
-        out_dtype = tl.int8
+        out_dtype = tl.int32
     elif out_dtype == 'float16' and epilogue != 'softmax':
         # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
         # fail with the following error: 'llvm.fmul' op requires the same type
@@ -3106,7 +3107,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
         assert "bar.sync" not in red_code
     # torch result
     if in_dtype == 'int8':
-        z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32)
+        z_ref = np.matmul(x, y, dtype=np.int32)
     elif 'float8' in in_dtype:
         x = convert_fp8_to_fp32(x, device, in_dtype)
         y = convert_fp8_to_fp32(y, device, in_dtype)
@@ -3125,15 +3126,22 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
             denom = np.sum(num, axis=-1, keepdims=True)
             z_ref = num / denom
         if epilogue == 'chain-dot':
+            compute_dtype = np.float32
             if 'float8' in in_dtype:
                 w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype))
-            z_ref = np.matmul(z_ref, w)
+            if 'int8' in in_dtype:
+                # Truncate the int32 accumulator to int8 before the second dot.
+                z_ref = z_ref.astype(np.int8)
+                compute_dtype = np.int32
+            z_ref = np.matmul(z_ref, w, dtype=compute_dtype)
     # compare
     if in_dtype == 'float32':
         # XXX: Somehow there's a larger difference when we use float32
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
     elif out_dtype == tl.float16 or in_dtype == 'bfloat16':
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
+    elif out_dtype == tl.int32:
+        np.testing.assert_equal(z_ref, to_numpy(z_tri))
     else:
         # added atol, to loose precision for float16xfloat16->float32 case
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index c32640a67b59..4883bea697db 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -1577,3 +1577,35 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
     tt.return
   }
 }
+
+// -----
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+#shared = #triton_gpu.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+// CHECK-LABEL: cvt_mma_to_dot_int8
+// CHECK: nvvm.shfl.sync
+// CHECK: nvvm.shfl.sync
+// CHECK: prmt.b32
+// CHECK: prmt.b32
+  tt.func @cvt_mma_to_dot_int8(%a: tensor<128x64xi8, #mma>) {
+    %opA = triton_gpu.convert_layout %a : tensor<128x64xi8, #mma> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
+    tt.return
+  }
+}
+
+// -----
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+#shared = #triton_gpu.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
+module attributes {"triton_gpu.compute-capability" = 89 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+// CHECK-LABEL: cvt_mma_to_dot_fp8
+// CHECK: nvvm.shfl.sync
+// CHECK: nvvm.shfl.sync
+// CHECK: prmt.b32
+// CHECK: prmt.b32
+  tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) {
+    %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
+    tt.return
+  }
+}
diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir
index 3e8ecbb4a8d0..5613e8ce34ca 100644
--- a/test/Conversion/tritongpu_to_llvm_hopper.mlir
+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -186,8 +186,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
 // CHECK-LABEL: cvt_mma_to_dot_fp8
-// CHECK: prmt.b32
-// CHECK: prmt.b32
 // CHECK: nvvm.shfl.sync
 // CHECK: nvvm.shfl.sync
 // CHECK: prmt.b32
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index bc3e7eed9c30..0e355c37954d 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -586,9 +586,56 @@ struct ConvertLayoutOpConversion
     return success();
   }
 
+  Value pack4xB8ToI32(Location loc, const SmallVector<Value> &vals,
+                      unsigned start,
+                      ConversionPatternRewriter &rewriter) const {
+    Value pack = undef(vec_ty(i8_ty, 4));
+    pack = insert_element(vec_ty(i8_ty, 4), pack,
+                          bitcast(vals[start + 0], i8_ty), i32_val(0));
+    pack = insert_element(vec_ty(i8_ty, 4), pack,
+                          bitcast(vals[start + 1], i8_ty), i32_val(1));
+    pack = insert_element(vec_ty(i8_ty, 4), pack,
+                          bitcast(vals[start + 2], i8_ty), i32_val(2));
+    pack = insert_element(vec_ty(i8_ty, 4), pack,
+                          bitcast(vals[start + 3], i8_ty), i32_val(3));
+    return bitcast(pack, i32_ty);
+  }
+
   // Convert from accumulator MMA layout to 8bit dot operand layout.
   // The conversion logic is taken from:
   // https://github.com/ColfaxResearch/cutlass-kernels/blob/a9de6446c1c0415c926025cea284210c799b11f8/src/fmha-pipeline/reg2reg.h#L45
+  void shuffle8BitsMMAToDotOperand(Location loc, Value &upper, Value &lower,
+                                   ConversionPatternRewriter &rewriter) const {
+    Value threadIdx = getThreadId(rewriter, loc);
+    // ((threadIdx + 1) & 2) == 0 selects lanes 0 and 3 of each group of 4.
+    Value cnd =
+        icmp_eq(and_(add(threadIdx, i32_val(1)), i32_val(0b10)), i32_val(0));
+
+    // High bits are ignored by shfl; lane pattern 0 2 0 2.
+    Value shflIdx = shl(threadIdx, i32_val(1));
+    Value shflIdxAlt = add(shflIdx, i32_val(1)); // 1 3 1 3
+
+    Value upperIdx = select(cnd, shflIdx, shflIdxAlt); // 0 3 1 2
+    Value lowerIdx = select(cnd, shflIdxAlt, shflIdx); // 1 2 0 3
+
+    Value upper0 = select(cnd, upper, lower);
+    Value lower0 = select(cnd, lower, upper);
+    Value mask = i32_val(0xFFFFFFFF);
+    // Set clamp to shuffle only within groups of 4 lanes.
+    Value clamp = i32_val(0x1C1F);
+    upper0 =
+        rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, upper0, upperIdx,
+                                      clamp, NVVM::ShflKind::idx, UnitAttr());
+    lower0 =
+        rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, lower0, lowerIdx,
+                                      clamp, NVVM::ShflKind::idx, UnitAttr());
+
+    Value selectorEx4 = select(cnd, i32_val(0x5410), i32_val(0x1054));
+    Value selectorEx5 = select(cnd, i32_val(0x7632), i32_val(0x3276));
+    upper = LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx4);
+    lower = LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx5);
+  }
+
   void
   convertMMAV3To8BitsDotOperand(triton::gpu::ConvertLayoutOp op,
                                 OpAdaptor adaptor,
@@ -598,76 +645,120 @@ struct ConvertLayoutOpConversion
     auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
     SmallVector<Value> retVals;
     for (int i = 0; i < vals.size(); i += 8) {
-      Value upper = undef(vec_ty(i8_ty, 4));
+      Value upper = pack4xB8ToI32(loc, vals, i, rewriter);
+      Value lower = pack4xB8ToI32(loc, vals, i + 4, rewriter);
+
+      shuffle8BitsMMAToDotOperand(loc, upper, lower, rewriter);
+
+      Value vecVal = bitcast(upper, vec_ty(i8_ty, 4));
       for (int j = 0; j < 4; j++) {
-        upper =
-            insert_element(vec_ty(i8_ty, 4), upper, vals[i + j], i32_val(j));
+        retVals.push_back(extract_element(i8_ty, vecVal, i32_val(j)));
       }
-      upper = bitcast(upper, i32_ty);
-      Value lower = undef(vec_ty(i8_ty, 4));
+      vecVal = bitcast(lower, vec_ty(i8_ty, 4));
       for (int j = 0; j < 4; j++) {
-        lower = insert_element(vec_ty(i8_ty, 4), lower, vals[i + 4 + j],
-                               i32_val(j));
-      }
-      lower = bitcast(lower, i32_ty);
-
-      Value threadIdMod4 = urem(getThreadId(rewriter, loc), i32_val(4));
-      Value cnd = or_(icmp_eq(threadIdMod4, i32_val(0)),
-                      icmp_eq(threadIdMod4, i32_val(3)));
-      Value selectorEx0 = select(cnd, i32_val(0x3210), i32_val(0x7654));
-      Value selectorEx1 = select(cnd, i32_val(0x7654), i32_val(0x3210));
-      Value selectorEx4 = select(cnd, i32_val(0x5410), i32_val(0x1054));
-      Value selectorEx5 = select(cnd, i32_val(0x7632), i32_val(0x3276));
-
-      Value isOne = icmp_eq(threadIdMod4, i32_val(1));
-      Value isTwo = icmp_eq(threadIdMod4, i32_val(2));
-      Value isThree = icmp_eq(threadIdMod4, i32_val(3));
-      Value upperIdx = i32_val(0);
-      upperIdx = select(isOne, i32_val(3), upperIdx);
-      upperIdx = select(isTwo, i32_val(1), upperIdx);
-      upperIdx = select(isThree, i32_val(2), upperIdx);
-
-      Value lowerIdx = i32_val(1);
-      lowerIdx = select(isOne, i32_val(2), lowerIdx);
-      lowerIdx = select(isTwo, i32_val(0), lowerIdx);
-      lowerIdx = select(isThree, i32_val(3), lowerIdx);
-
-      Value upper0 =
-          LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx0);
-      Value lower0 =
-          LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx1);
-      Value mask = i32_val(0xFFFFFFFF);
-      // Set clamp tp shuffle only within 4 lanes.
-      Value clamp = i32_val(0x1C1F);
-      upper0 =
-          rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, upper0, upperIdx,
-                                        clamp, NVVM::ShflKind::idx, UnitAttr());
-      lower0 =
-          rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, lower0, lowerIdx,
-                                        clamp, NVVM::ShflKind::idx, UnitAttr());
-      Value upper1 =
-          LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx4);
-      Value vecVal = bitcast(upper1, vec_ty(i8_ty, 4));
-      for (int i = 0; i < 4; i++) {
-        retVals.push_back(extract_element(i8_ty, vecVal, i32_val(i)));
+        retVals.push_back(extract_element(i8_ty, vecVal, i32_val(j)));
       }
-      Value lower1 =
-          LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx5);
-      vecVal = bitcast(lower1, vec_ty(i8_ty, 4));
-      for (int i = 0; i < 4; i++) {
-        retVals.push_back(extract_element(i8_ty, vecVal, i32_val(i)));
+    }
+    Value result =
+        packLLElements(loc, getTypeConverter(), retVals, rewriter, dstTy);
+    rewriter.replaceOp(op, result);
+  }
+
+  void
+  convertMMAV2To8BitsDotOperand(triton::gpu::ConvertLayoutOp op,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    auto dstTy = op.getType();
+    auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    SmallVector<Value> retVals;
+    assert(vals.size() % 16 == 0 && "Unsupported MMA output size");
+    unsigned srcBits = vals[0].getType().getIntOrFloatBitWidth();
+    assert(srcBits == 8 && "Unsupported src element size");
+    for (int i = 0; i < vals.size(); i += 8) {
+      Value upper = pack4xB8ToI32(loc, vals, i, rewriter);
+      Value lower = pack4xB8ToI32(loc, vals, i + 4, rewriter);
+
+      shuffle8BitsMMAToDotOperand(loc, upper, lower, rewriter);
+
+      if (i % 16 != 0) {
+        std::swap(retVals.back(), upper);
       }
+
+      retVals.push_back(upper);
+      retVals.push_back(lower);
     }
     Value result =
         packLLElements(loc, getTypeConverter(), retVals, rewriter, dstTy);
     rewriter.replaceOp(op, result);
   }
 
+  void
+  convertMMAV2To16BitsDotOperand(triton::gpu::ConvertLayoutOp op,
+                                 OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    auto srcTy = op.getSrc().getType();
+    auto dstTy = op.getType();
+    // get source values
+    auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    unsigned elems = getTotalElemsPerThread(srcTy);
+    Type elemTy = this->getTypeConverter()->convertType(srcTy.getElementType());
+    // for the destination type, we need to pack values together
+    // so they can be consumed by tensor core operations
+    SmallVector<Value> vecVals;
+    SmallVector<Type> types;
+    // For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer
+    // instructions to pack & unpack sub-word integers. A workaround is to
+    // store the results of ldmatrix in i32.
+    auto elemSize = elemTy.getIntOrFloatBitWidth();
+    if (auto intTy = elemTy.dyn_cast<IntegerType>() && elemSize <= 16) {
+      auto fold = 32 / elemSize;
+      for (unsigned i = 0; i < elems; i += fold) {
+        Value val = i32_val(0);
+        for (unsigned j = 0; j < fold; j++) {
+          auto ext =
+              shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
+          val = or_(i32_ty, val, ext);
+        }
+        vecVals.push_back(val);
+      }
+      elems = elems / (32 / elemSize);
+      types = SmallVector<Type>(elems, i32_ty);
+    } else {
+      unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
+      Type vecTy = vec_ty(elemTy, vecSize);
+      types = SmallVector<Type>(elems / vecSize, vecTy);
+      for (unsigned i = 0; i < elems; i += vecSize) {
+        Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
+        for (unsigned j = 0; j < vecSize; j++)
+          packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
+        vecVals.push_back(packed);
+      }
+    }
+
+    // This needs to be ordered the same way that
+    // ldmatrix.x4 would order it.
+    // TODO: this needs to be refactored so we don't
+    // implicitly depend on how emitOffsetsForMMAV2
+    // is implemented.
+    SmallVector<Value> reorderedVals;
+    for (unsigned i = 0; i < vecVals.size(); i += 4) {
+      reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
+      reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
+      reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
+      reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
+    }
+
+    Value view =
+        packLLElements(loc, getTypeConverter(), reorderedVals, rewriter, dstTy);
+    rewriter.replaceOp(op, view);
+  }
+
   // mma -> dot_operand
   LogicalResult lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op,
                                      OpAdaptor adaptor,
                                      ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();
     if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
@@ -682,60 +773,12 @@ struct ConvertLayoutOpConversion
     }
     if (isMmaToDotShortcut(srcTy, dstTy)) {
-      // get source values
-      auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-      unsigned elems = getTotalElemsPerThread(srcTy);
-      Type elemTy =
-          this->getTypeConverter()->convertType(srcTy.getElementType());
-      // for the destination type, we need to pack values together
-      // so they can be consumed by tensor core operations
-      SmallVector<Value> vecVals;
-      SmallVector<Type> types;
-      // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
-      // instructions to pack & unpack sub-word integers. A workaround is to
-      // store the results of ldmatrix in i32
-      auto elemSize = elemTy.getIntOrFloatBitWidth();
-      if (auto intTy = elemTy.dyn_cast<IntegerType>() && elemSize <= 16) {
-        auto fold = 32 / elemSize;
-        for (unsigned i = 0; i < elems; i += fold) {
-          Value val = i32_val(0);
-          for (unsigned j = 0; j < fold; j++) {
-            auto ext =
-                shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
-            val = or_(i32_ty, val, ext);
-          }
-          vecVals.push_back(val);
-        }
-        elems = elems / (32 / elemSize);
-        types = SmallVector<Type>(elems, i32_ty);
-      } else {
-        unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
-        Type vecTy = vec_ty(elemTy, vecSize);
-        types = SmallVector<Type>(elems / vecSize, vecTy);
-        for (unsigned i = 0; i < elems; i += vecSize) {
-          Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-          for (unsigned j = 0; j < vecSize; j++)
-            packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
-          vecVals.push_back(packed);
-        }
-      }
-
-      // This needs to be ordered the same way that
-      // ldmatrix.x4 would order it
-      // TODO: this needs to be refactor so we don't
-      // implicitly depends on how emitOffsetsForMMAV2
-      // is implemented
-      SmallVector<Value> reorderedVals;
-      for (unsigned i = 0; i < vecVals.size(); i += 4) {
-        reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
-        reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
-        reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
-        reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
+      if (srcTy.getElementType().getIntOrFloatBitWidth() == 8) {
+        convertMMAV2To8BitsDotOperand(op, adaptor, rewriter);
+        return success();
       }
-      Value view = packLLElements(loc, getTypeConverter(), reorderedVals,
-                                  rewriter, dstTy);
-      rewriter.replaceOp(op, view);
+      convertMMAV2To16BitsDotOperand(op, adaptor, rewriter);
       return success();
     }
     return failure();
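
Illustrative sketch (not part of the patch): a minimal Triton reproduction, under stated assumptions, of the int8 chain-dot pattern the new test_dot parametrization covers. The first tl.dot accumulates in int32; truncating that accumulator to int8 and feeding it as operand 0 of a second tl.dot is the mma -> dot_operand conversion that isMmaToDotShortcut now accepts for 8-bit element types, so it can be lowered with nvvm.shfl.sync + prmt.b32 instead of a shared-memory round trip. The kernel name chain_dot_int8, the run driver, and the 128x128x64 shape are hypothetical; whether the register shortcut actually fires also depends on the warp layout the compiler picks (warpsPerCTA[1] must be 1).

import numpy as np
import torch
import triton
import triton.language as tl


@triton.jit
def chain_dot_int8(X, Y, W, Z, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # Single-program kernel: one block covers the whole M x N tile.
    off_m = tl.arange(0, M)
    off_n = tl.arange(0, N)
    off_k = tl.arange(0, K)
    x = tl.load(X + off_m[:, None] * K + off_k[None, :])  # int8, M x K
    y = tl.load(Y + off_k[:, None] * N + off_n[None, :])  # int8, K x N
    w = tl.load(W + off_n[:, None] * N + off_n[None, :])  # int8, N x N
    z = tl.dot(x, y, out_dtype=tl.int32)              # accumulator in #mma layout
    z = tl.dot(z.to(tl.int8), w, out_dtype=tl.int32)  # int8 mma -> dot_op conversion here
    tl.store(Z + off_m[:, None] * N + off_n[None, :], z)


def run(device="cuda"):
    M, N, K = 128, 128, 64  # same shape as the new test case
    x = torch.randint(-128, 127, (M, K), dtype=torch.int8, device=device)
    y = torch.randint(-128, 127, (K, N), dtype=torch.int8, device=device)
    w = torch.randint(-128, 127, (N, N), dtype=torch.int8, device=device)
    z = torch.empty((M, N), dtype=torch.int32, device=device)
    chain_dot_int8[(1,)](x, y, w, z, M=M, N=N, K=K, num_warps=4)
    # Reference mirrors the updated z_ref computation: int32 matmul, truncate to
    # int8, then a second int32 matmul; integer results must match exactly.
    xn, yn, wn = (t.cpu().numpy() for t in (x, y, w))
    ref = np.matmul(np.matmul(xn, yn, dtype=np.int32).astype(np.int8), wn, dtype=np.int32)
    np.testing.assert_equal(ref, z.cpu().numpy())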