
[Transform][Tiling] Add deep tile support for matmul #90

Open · zhczhong wants to merge 9 commits into main from zhicong/deep_tile_matmul
Conversation


@zhczhong commented May 20, 2024

Tracking #53

TODO:

  • the nested outer loop generation
  • partial reduction support
    • enhance the PartialReductionOpInterface to let the user control where the new parallel dims are inserted
    • erase the redundant linalg.fillOp in partial reduction
  • merge all parallel iterators into a single scf.forall until nested parallelism is ready
  • fuse the linalg.fillOp into the innermost loop body
  • replace all genericOps with linalg named ops
  • support 4Dx4D/5D->4D, 2Dx2D->2D, 2Dx4D/5D->2D
  • dtype support (f32, bf16)
  • fuse the f32->bf16 cast into the last loop along the K axis
  • support batch matmul
  • Balance211 support (see the sketch after this list)
  • tune a general matmul config based on the cost model
  • fuse the linalg.copy into the innermost loop
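
For reference, a minimal sketch of the Balance211 partitioning named in the list above: the even-split scheme (as used in oneDNN) that gives each of the nthr workers either floor(n/nthr) or floor(n/nthr)+1 jobs. The signature and names are illustrative, not this PR's code.

#include <algorithm>
#include <cstdint>

// Splits `n` jobs across `nthr` workers so that every worker gets either
// floor(n/nthr) or floor(n/nthr)+1 jobs; worker `ithr` receives [start, end).
void balance211(int64_t n, int64_t nthr, int64_t ithr, int64_t &start,
                int64_t &end) {
  int64_t base = n / nthr; // minimum jobs per worker
  int64_t rem = n % nthr;  // the first `rem` workers take one extra job
  start = ithr * base + std::min(ithr, rem);
  end = start + base + (ithr < rem ? 1 : 0);
}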

The inner-loop generation part depends on the easy-builder support (#62).
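
For context, this is roughly what generating one level of the nested tile loops looks like with the plain OpBuilder API; a hedged sketch only, with assumed names, that the easy-builder helpers from #62 are meant to condense.

#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Wraps `initTensor` in an scf.for over [lb, ub) with the given step and
// threads the tensor through as the loop-carried iter_arg, mirroring the
// scf.for nest in the IR dumps below.
static scf::ForOp createTileLoop(OpBuilder &b, Location loc, Value lb,
                                 Value ub, Value step, Value initTensor) {
  return b.create<scf::ForOp>(
      loc, lb, ub, step, ValueRange{initTensor},
      [](OpBuilder &nested, Location nestedLoc, Value iv,
         ValueRange iterArgs) {
        // The tensor.extract_slice, tiled computation, and matching
        // tensor.insert_slice go here; yield the carried tensor.
        nested.create<scf::YieldOp>(nestedLoc, iterArgs);
      });
}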

@zhczhong added the WIP (work in progress) label May 20, 2024
@zhczhong force-pushed the zhicong/deep_tile_matmul branch 3 times, most recently from 7c8cfbb to 927322a on May 23, 2024
@zhczhong force-pushed the zhicong/deep_tile_matmul branch 6 times, most recently from ea02416 to f261c3c on June 3, 2024
@zhczhong force-pushed the zhicong/deep_tile_matmul branch 5 times, most recently from 5ed4fc1 to 22d86d4 on June 5, 2024
@zhczhong (Author) commented Jun 5, 2024

Support using linalgx.batch_reduce_matmul_vnni (bf16xbf16->f32) and fuse the cast (f32->bf16) into the last loop along the K axis:

func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
    %cst_0 = arith.constant 0.000000e+00 : bf16
    %0 = tensor.empty() : tensor<128x128x32x32xbf16>
    %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
    %2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>)  -> tensor<128x128x32x32xbf16>
    return %2 : tensor<128x128x32x32xbf16>
}

will be transformed into

#map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * 64 + s0 + s1)>
module {
  func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
    %c1 = arith.constant 1 : index
    %c128 = arith.constant 128 : index
    %c2 = arith.constant 2 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : bf16
    %0 = tensor.empty() : tensor<128x128x32x32xbf16>
    %1 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %0) -> (tensor<128x128x32x32xbf16>) {
      %2 = affine.apply #map(%arg2)
      %3 = affine.apply #map(%arg3)
      %extracted_slice = tensor.extract_slice %arg4[%2, %3, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<64x64x32x32xbf16>
      %4 = scf.for %arg5 = %c0 to %c64 step %c2 iter_args(%arg6 = %extracted_slice) -> (tensor<64x64x32x32xbf16>) {
        %extracted_slice_0 = tensor.extract_slice %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> to tensor<2x64x32x32xbf16>
        %7 = scf.for %arg7 = %c0 to %c64 step %c2 iter_args(%arg8 = %extracted_slice_0) -> (tensor<2x64x32x32xbf16>) {
          %extracted_slice_1 = tensor.extract_slice %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> to tensor<2x2x32x32xbf16>
          %8 = tensor.empty() : tensor<2x2x32x32xf32>
          %9 = scf.for %arg9 = %c0 to %c128 step %c2 iter_args(%arg10 = %8) -> (tensor<2x2x32x32xf32>) {
            %11 = scf.for %arg11 = %c0 to %c2 step %c1 iter_args(%arg12 = %arg10) -> (tensor<2x2x32x32xf32>) {
              %extracted_slice_3 = tensor.extract_slice %arg12[%arg11, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xf32> to tensor<1x2x32x32xf32>
              %12 = scf.for %arg13 = %c0 to %c2 step %c1 iter_args(%arg14 = %extracted_slice_3) -> (tensor<1x2x32x32xf32>) {
                %13 = affine.apply #map1(%arg2)[%arg11, %arg5]
                %extracted_slice_5 = tensor.extract_slice %arg0[%13, %arg9, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<2x32x32xbf16>
                %14 = affine.apply #map1(%arg3)[%arg13, %arg7]
                %extracted_slice_6 = tensor.extract_slice %arg1[%14, %arg9, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<128x128x16x32x2xbf16> to tensor<2x16x32x2xbf16>
                %extracted_slice_7 = tensor.extract_slice %arg14[0, %arg13, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> to tensor<32x32xf32>
                %15 = arith.cmpi eq, %arg9, %c0 : index
                %16 = scf.if %15 -> (tensor<32x32xf32>) {
                  %17 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_7 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  %18 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_5, %extracted_slice_6 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%17 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  scf.yield %18 : tensor<32x32xf32>
                } else {
                  %17 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_5, %extracted_slice_6 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%extracted_slice_7 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  scf.yield %17 : tensor<32x32xf32>
                }
                %inserted_slice_8 = tensor.insert_slice %16 into %arg14[0, %arg13, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<1x2x32x32xf32>
                scf.yield %inserted_slice_8 : tensor<1x2x32x32xf32>
              }
              %inserted_slice_4 = tensor.insert_slice %12 into %arg12[%arg11, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> into tensor<2x2x32x32xf32>
              scf.yield %inserted_slice_4 : tensor<2x2x32x32xf32>
            }
            scf.yield %11 : tensor<2x2x32x32xf32>
          }
          %10 = linalg.copy ins(%9 : tensor<2x2x32x32xf32>) outs(%extracted_slice_1 : tensor<2x2x32x32xbf16>) -> tensor<2x2x32x32xbf16>
          %inserted_slice_2 = tensor.insert_slice %10 into %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> into tensor<2x64x32x32xbf16>
          scf.yield %inserted_slice_2 : tensor<2x64x32x32xbf16>
        }
        %inserted_slice = tensor.insert_slice %7 into %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> into tensor<64x64x32x32xbf16>
        scf.yield %inserted_slice : tensor<64x64x32x32xbf16>
      }
      %5 = affine.apply #map(%arg2)
      %6 = affine.apply #map(%arg3)
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg4[%5, %6, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> into tensor<128x128x32x32xbf16>
      }
    }
    return %1 : tensor<128x128x32x32xbf16>
  }
}

return idxList;
}

MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) {
@yifeizh2 commented Jun 13, 2024

Can you expose this method in the future? It would be used in the layout-inference logic of the global layout analysis pass.

@zhczhong (Author)
OK, I will expose this method later.
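
For illustration, exposing it could be as simple as moving a declaration into a shared header. The header shape and struct fields below are assumptions, not the actual definition:

#include <cstdint>

#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Hypothetical public header content; field names are illustrative.
struct MatmulConfig {
  uint32_t MBlock, NBlock, KBlock;       // cache-level block sizes
  uint32_t MThreads, NThreads, KThreads; // thread split per dimension
};

// Declared in a shared header so other passes (e.g. the global layout
// analysis mentioned above) can query the default config for a matmul.
MatmulConfig getDefaultMatmulConfig(mlir::linalg::LinalgOp &linalgOp);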


namespace {

struct SystemDesc {


Will this struct be replaced with the target description in the future?

@zhczhong (Author) commented Jun 13, 2024

Yes, this is just a mock class and will be replaced with the target description when it is ready.
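
For reference, a minimal sketch of what such a mock could hold until the target description lands; the fields, methods, and default values are assumptions:

#include <cstddef>
#include <cstdint>

// Mock hardware description; fields and defaults are placeholders.
struct SystemDesc {
  // Threads available for the outermost parallel loop.
  uint32_t getNumThreads() const { return numThreads; }
  // Capacity in bytes of the given cache level (1-based); the cost model
  // sizes the tile at each loop level against this.
  size_t getCacheSize(uint8_t cacheLevel) const {
    return cacheSizes[cacheLevel - 1];
  }

  uint32_t numThreads = 1;
  size_t cacheSizes[3] = {32 * 1024, 1024 * 1024, 32 * 1024 * 1024};
};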

@zhczhong (Author) commented

Update: fuse the cast (f32->bf16) into the innermost loop.

The same matmul_4Dx4D_bf16 input shown above will now be transformed to

#map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * 64 + s0 + s1)>
module {
  func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
    %c1 = arith.constant 1 : index
    %c128 = arith.constant 128 : index
    %c2 = arith.constant 2 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : bf16
    %0 = tensor.empty() : tensor<128x128x32x32xbf16>
    %1 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %0) -> (tensor<128x128x32x32xbf16>) {
      %2 = affine.apply #map(%arg2)
      %3 = affine.apply #map(%arg3)
      %extracted_slice = tensor.extract_slice %arg4[%2, %3, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<64x64x32x32xbf16>
      %4 = scf.for %arg5 = %c0 to %c64 step %c2 iter_args(%arg6 = %extracted_slice) -> (tensor<64x64x32x32xbf16>) {
        %extracted_slice_0 = tensor.extract_slice %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> to tensor<2x64x32x32xbf16>
        %7 = scf.for %arg7 = %c0 to %c64 step %c2 iter_args(%arg8 = %extracted_slice_0) -> (tensor<2x64x32x32xbf16>) {
          %extracted_slice_1 = tensor.extract_slice %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> to tensor<2x2x32x32xbf16>
          %8 = tensor.empty() : tensor<2x2x32x32xf32>
          %9:2 = scf.for %arg9 = %c0 to %c128 step %c2 iter_args(%arg10 = %8, %arg11 = %extracted_slice_1) -> (tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>) {
            %10:2 = scf.for %arg12 = %c0 to %c2 step %c1 iter_args(%arg13 = %arg10, %arg14 = %arg11) -> (tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>) {
              %extracted_slice_3 = tensor.extract_slice %arg13[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xf32> to tensor<1x2x32x32xf32>
              %extracted_slice_4 = tensor.extract_slice %arg14[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> to tensor<1x2x32x32xbf16>
              %11:2 = scf.for %arg15 = %c0 to %c2 step %c1 iter_args(%arg16 = %extracted_slice_3, %arg17 = %extracted_slice_4) -> (tensor<1x2x32x32xf32>, tensor<1x2x32x32xbf16>) {
                %12 = affine.apply #map1(%arg2)[%arg12, %arg5]
                %extracted_slice_7 = tensor.extract_slice %arg0[%12, %arg9, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<2x32x32xbf16>
                %13 = affine.apply #map1(%arg3)[%arg15, %arg7]
                %extracted_slice_8 = tensor.extract_slice %arg1[%13, %arg9, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<128x128x16x32x2xbf16> to tensor<2x16x32x2xbf16>
                %extracted_slice_9 = tensor.extract_slice %arg16[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> to tensor<32x32xf32>
                %extracted_slice_10 = tensor.extract_slice %arg17[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xbf16> to tensor<32x32xbf16>
                %14 = arith.cmpi eq, %arg9, %c0 : index
                %15 = scf.if %14 -> (tensor<32x32xf32>) {
                  %18 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_9 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  %19 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_7, %extracted_slice_8 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%18 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  scf.yield %19 : tensor<32x32xf32>
                } else {
                  %18 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_7, %extracted_slice_8 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%extracted_slice_9 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  scf.yield %18 : tensor<32x32xf32>
                }
                %16 = arith.cmpi eq, %arg9, %c0 : index
                %17 = scf.if %16 -> (tensor<32x32xbf16>) {
                  %18 = linalg.copy ins(%15 : tensor<32x32xf32>) outs(%extracted_slice_10 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
                  scf.yield %18 : tensor<32x32xbf16>
                } else {
                  scf.yield %extracted_slice_10 : tensor<32x32xbf16>
                }
                %inserted_slice_11 = tensor.insert_slice %15 into %arg16[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<1x2x32x32xf32>
                %inserted_slice_12 = tensor.insert_slice %17 into %arg17[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xbf16> into tensor<1x2x32x32xbf16>
                scf.yield %inserted_slice_11, %inserted_slice_12 : tensor<1x2x32x32xf32>, tensor<1x2x32x32xbf16>
              }
              %inserted_slice_5 = tensor.insert_slice %11#0 into %arg13[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> into tensor<2x2x32x32xf32>
              %inserted_slice_6 = tensor.insert_slice %11#1 into %arg14[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xbf16> into tensor<2x2x32x32xbf16>
              scf.yield %inserted_slice_5, %inserted_slice_6 : tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>
            }
            scf.yield %10#0, %10#1 : tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>
          }
          %inserted_slice_2 = tensor.insert_slice %9#1 into %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> into tensor<2x64x32x32xbf16>
          scf.yield %inserted_slice_2 : tensor<2x64x32x32xbf16>
        }
        %inserted_slice = tensor.insert_slice %7 into %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> into tensor<64x64x32x32xbf16>
        scf.yield %inserted_slice : tensor<64x64x32x32xbf16>
      }
      %5 = affine.apply #map(%arg2)
      %6 = affine.apply #map(%arg3)
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg4[%5, %6, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> into tensor<128x128x32x32xbf16>
      }
    }
    return %1 : tensor<128x128x32x32xbf16>
  }
}


} // namespace impl

impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) {


Missing an inline here; without it, this header-defined function will trigger multiple-definition (ODR) errors once the header is included in more than one translation unit.

}

#define DEF_EASYBUILD_CMP_OPERATOR(OP, OPCLASS, TYPE, PRED) \
EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \


Missing an inline here as well; the macro expands to a function definition in a header, so it needs inline for the same reason.

Labels
WIP work in progress
Projects
None yet
Development

None yet

2 participants