-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Transform][Tiling] Add deep tile support for matmul #90
base: main
Are you sure you want to change the base?
Conversation
7c8cfbb
to
927322a
Compare
ea02416
to
f261c3c
Compare
5ed4fc1
to
22d86d4
Compare
Support using linalgx.batch_reduce_matmul_vnni (bf16xbf16->f32) and fuse the cast (f32->bf16) into the last loop over the K axis. func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
%cst_0 = arith.constant 0.000000e+00 : bf16
%0 = tensor.empty() : tensor<128x128x32x32xbf16>
%1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
%2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
return %2 : tensor<128x128x32x32xbf16>
} will be transformed into #map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * 64 + s0 + s1)>
module {
func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%c2 = arith.constant 2 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%0 = tensor.empty() : tensor<128x128x32x32xbf16>
%1 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %0) -> (tensor<128x128x32x32xbf16>) {
%2 = affine.apply #map(%arg2)
%3 = affine.apply #map(%arg3)
%extracted_slice = tensor.extract_slice %arg4[%2, %3, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<64x64x32x32xbf16>
%4 = scf.for %arg5 = %c0 to %c64 step %c2 iter_args(%arg6 = %extracted_slice) -> (tensor<64x64x32x32xbf16>) {
%extracted_slice_0 = tensor.extract_slice %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> to tensor<2x64x32x32xbf16>
%7 = scf.for %arg7 = %c0 to %c64 step %c2 iter_args(%arg8 = %extracted_slice_0) -> (tensor<2x64x32x32xbf16>) {
%extracted_slice_1 = tensor.extract_slice %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> to tensor<2x2x32x32xbf16>
%8 = tensor.empty() : tensor<2x2x32x32xf32>
%9 = scf.for %arg9 = %c0 to %c128 step %c2 iter_args(%arg10 = %8) -> (tensor<2x2x32x32xf32>) {
%11 = scf.for %arg11 = %c0 to %c2 step %c1 iter_args(%arg12 = %arg10) -> (tensor<2x2x32x32xf32>) {
%extracted_slice_3 = tensor.extract_slice %arg12[%arg11, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xf32> to tensor<1x2x32x32xf32>
%12 = scf.for %arg13 = %c0 to %c2 step %c1 iter_args(%arg14 = %extracted_slice_3) -> (tensor<1x2x32x32xf32>) {
%13 = affine.apply #map1(%arg2)[%arg11, %arg5]
%extracted_slice_5 = tensor.extract_slice %arg0[%13, %arg9, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<2x32x32xbf16>
%14 = affine.apply #map1(%arg3)[%arg13, %arg7]
%extracted_slice_6 = tensor.extract_slice %arg1[%14, %arg9, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<128x128x16x32x2xbf16> to tensor<2x16x32x2xbf16>
%extracted_slice_7 = tensor.extract_slice %arg14[0, %arg13, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> to tensor<32x32xf32>
%15 = arith.cmpi eq, %arg9, %c0 : index
%16 = scf.if %15 -> (tensor<32x32xf32>) {
%17 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_7 : tensor<32x32xf32>) -> tensor<32x32xf32>
%18 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_5, %extracted_slice_6 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%17 : tensor<32x32xf32>) -> tensor<32x32xf32>
scf.yield %18 : tensor<32x32xf32>
} else {
%17 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_5, %extracted_slice_6 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%extracted_slice_7 : tensor<32x32xf32>) -> tensor<32x32xf32>
scf.yield %17 : tensor<32x32xf32>
}
%inserted_slice_8 = tensor.insert_slice %16 into %arg14[0, %arg13, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<1x2x32x32xf32>
scf.yield %inserted_slice_8 : tensor<1x2x32x32xf32>
}
%inserted_slice_4 = tensor.insert_slice %12 into %arg12[%arg11, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> into tensor<2x2x32x32xf32>
scf.yield %inserted_slice_4 : tensor<2x2x32x32xf32>
}
scf.yield %11 : tensor<2x2x32x32xf32>
}
%10 = linalg.copy ins(%9 : tensor<2x2x32x32xf32>) outs(%extracted_slice_1 : tensor<2x2x32x32xbf16>) -> tensor<2x2x32x32xbf16>
%inserted_slice_2 = tensor.insert_slice %10 into %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> into tensor<2x64x32x32xbf16>
scf.yield %inserted_slice_2 : tensor<2x64x32x32xbf16>
}
%inserted_slice = tensor.insert_slice %7 into %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> into tensor<64x64x32x32xbf16>
scf.yield %inserted_slice : tensor<64x64x32x32xbf16>
}
%5 = affine.apply #map(%arg2)
%6 = affine.apply #map(%arg3)
scf.forall.in_parallel {
tensor.parallel_insert_slice %4 into %arg4[%5, %6, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> into tensor<128x128x32x32xbf16>
}
}
return %1 : tensor<128x128x32x32xbf16>
}
} |
return idxList; | ||
} | ||
|
||
MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you expose this method in the future? It would be used in the layout inference logic of the global layout analysis pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will expose this method later.
|
||
namespace { | ||
|
||
struct SystemDesc { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this struct be replaced with target description in future?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is just a mock class here and will be replaced with target description when it is ready
22d86d4
to
8577250
Compare
Update: Fuse the cast (f32->bf16) into the innermost loop. func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
%cst_0 = arith.constant 0.000000e+00 : bf16
%0 = tensor.empty() : tensor<128x128x32x32xbf16>
%1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
%2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
return %2 : tensor<128x128x32x32xbf16>
} will be transformed into #map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * 64 + s0 + s1)>
module {
func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%c2 = arith.constant 2 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%0 = tensor.empty() : tensor<128x128x32x32xbf16>
%1 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %0) -> (tensor<128x128x32x32xbf16>) {
%2 = affine.apply #map(%arg2)
%3 = affine.apply #map(%arg3)
%extracted_slice = tensor.extract_slice %arg4[%2, %3, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<64x64x32x32xbf16>
%4 = scf.for %arg5 = %c0 to %c64 step %c2 iter_args(%arg6 = %extracted_slice) -> (tensor<64x64x32x32xbf16>) {
%extracted_slice_0 = tensor.extract_slice %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> to tensor<2x64x32x32xbf16>
%7 = scf.for %arg7 = %c0 to %c64 step %c2 iter_args(%arg8 = %extracted_slice_0) -> (tensor<2x64x32x32xbf16>) {
%extracted_slice_1 = tensor.extract_slice %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> to tensor<2x2x32x32xbf16>
%8 = tensor.empty() : tensor<2x2x32x32xf32>
%9:2 = scf.for %arg9 = %c0 to %c128 step %c2 iter_args(%arg10 = %8, %arg11 = %extracted_slice_1) -> (tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>) {
%10:2 = scf.for %arg12 = %c0 to %c2 step %c1 iter_args(%arg13 = %arg10, %arg14 = %arg11) -> (tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>) {
%extracted_slice_3 = tensor.extract_slice %arg13[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xf32> to tensor<1x2x32x32xf32>
%extracted_slice_4 = tensor.extract_slice %arg14[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> to tensor<1x2x32x32xbf16>
%11:2 = scf.for %arg15 = %c0 to %c2 step %c1 iter_args(%arg16 = %extracted_slice_3, %arg17 = %extracted_slice_4) -> (tensor<1x2x32x32xf32>, tensor<1x2x32x32xbf16>) {
%12 = affine.apply #map1(%arg2)[%arg12, %arg5]
%extracted_slice_7 = tensor.extract_slice %arg0[%12, %arg9, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<2x32x32xbf16>
%13 = affine.apply #map1(%arg3)[%arg15, %arg7]
%extracted_slice_8 = tensor.extract_slice %arg1[%13, %arg9, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<128x128x16x32x2xbf16> to tensor<2x16x32x2xbf16>
%extracted_slice_9 = tensor.extract_slice %arg16[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> to tensor<32x32xf32>
%extracted_slice_10 = tensor.extract_slice %arg17[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xbf16> to tensor<32x32xbf16>
%14 = arith.cmpi eq, %arg9, %c0 : index
%15 = scf.if %14 -> (tensor<32x32xf32>) {
%18 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_9 : tensor<32x32xf32>) -> tensor<32x32xf32>
%19 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_7, %extracted_slice_8 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%18 : tensor<32x32xf32>) -> tensor<32x32xf32>
scf.yield %19 : tensor<32x32xf32>
} else {
%18 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_7, %extracted_slice_8 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%extracted_slice_9 : tensor<32x32xf32>) -> tensor<32x32xf32>
scf.yield %18 : tensor<32x32xf32>
}
%16 = arith.cmpi eq, %arg9, %c0 : index
%17 = scf.if %16 -> (tensor<32x32xbf16>) {
%18 = linalg.copy ins(%15 : tensor<32x32xf32>) outs(%extracted_slice_10 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
scf.yield %18 : tensor<32x32xbf16>
} else {
scf.yield %extracted_slice_10 : tensor<32x32xbf16>
}
%inserted_slice_11 = tensor.insert_slice %15 into %arg16[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<1x2x32x32xf32>
%inserted_slice_12 = tensor.insert_slice %17 into %arg17[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xbf16> into tensor<1x2x32x32xbf16>
scf.yield %inserted_slice_11, %inserted_slice_12 : tensor<1x2x32x32xf32>, tensor<1x2x32x32xbf16>
}
%inserted_slice_5 = tensor.insert_slice %11#0 into %arg13[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> into tensor<2x2x32x32xf32>
%inserted_slice_6 = tensor.insert_slice %11#1 into %arg14[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xbf16> into tensor<2x2x32x32xbf16>
scf.yield %inserted_slice_5, %inserted_slice_6 : tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>
}
scf.yield %10#0, %10#1 : tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>
}
%inserted_slice_2 = tensor.insert_slice %9#1 into %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> into tensor<2x64x32x32xbf16>
scf.yield %inserted_slice_2 : tensor<2x64x32x32xbf16>
}
%inserted_slice = tensor.insert_slice %7 into %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> into tensor<64x64x32x32xbf16>
scf.yield %inserted_slice : tensor<64x64x32x32xbf16>
}
%5 = affine.apply #map(%arg2)
%6 = affine.apply #map(%arg3)
scf.forall.in_parallel {
tensor.parallel_insert_slice %4 into %arg4[%5, %6, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> into tensor<128x128x32x32xbf16>
}
}
return %1 : tensor<128x128x32x32xbf16>
}
} |
8577250
to
206fead
Compare
|
||
} // namespace impl | ||
|
||
impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing an `inline` here.
} | ||
|
||
#define DEF_EASYBUILD_CMP_OPERATOR(OP, OPCLASS, TYPE, PRED) \ | ||
EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing an `inline` here.
206fead
to
65dfab8
Compare
Tracking #53
TODO:
The inner-loop generation part depends on the easy-builder support (#62).