
[Dialect] [Linalgx] Add linalgx ops: 3 vnni matmuls and multi_batch_matmul #89

Merged
19 commits merged into main on Jun 3, 2024

Conversation

@LongshengDu (Contributor) commented May 20, 2024

Added mm2d_vnni, mm4d_vnni, and batch_reduce_matmul_vnni with custom verifiers, static indexing_maps and iterator_types, using the vnni dims to derive the constant symbols for the indexing_maps.

Added multi_batch_matmul with LinalgContractionOpInterface and dynamic indexing_maps and iterator_types.

Tracking: #14

@LongshengDu LongshengDu changed the title [WIP] [Linalgx Dialect] Add linalgx ops: mmt2d_vnni, mmt4d_vnni, multi_batch_matmul [WIP] [Dialect] [Linalgx] Add linalgx ops: mmt2d_vnni, mmt4d_vnni, multi_batch_matmul May 20, 2024
@LongshengDu LongshengDu added the WIP work in progress label May 20, 2024
@LongshengDu LongshengDu changed the title [WIP] [Dialect] [Linalgx] Add linalgx ops: mmt2d_vnni, mmt4d_vnni, multi_batch_matmul [Dialect] [Linalgx] Add linalgx ops: mmt2d_vnni, mmt4d_vnni, multi_batch_matmul May 20, 2024
include/gc/Dialect/Linalgx/LinalgxStructuredOps.td
lib/gc/Dialect/Linalgx/LinalgxOps.cpp
bool matchK =
shapeA.getDimSize(1) ==
(shapeB.getDimSize(1) * shapeB.getDimSize(2) * shapeB.getDimSize(4));
bool matchVnni = (shapeB.getDimSize(4) == 1) || (shapeB.getDimSize(4) == 2) ||
                 (shapeB.getDimSize(4) == 4);


Same as above, why do we consider 1 as vnni format here?

LongshengDu (Contributor, Author):

input: MKmk/MK
weight: NKkn4k/NKkn2k/NKkn/KN
output: MNmn/MN

For NKkn format, can we treat it as NKkn1k so we can reuse this op?


Please leave a note here if you'd like to reuse it for F32 datatypes, as vnni always refers to low precision with blk_size 2 or 4.

LongshengDu (Contributor, Author):

Is there a name for NKkn format (differentiating from mmt4d's Nknk)?

Contributor:

> Is there a name for NKkn format (differentiating from mmt4d's Nknk)?

The name is NKkn. For op naming, I think it's mm4d_vnni (no transpose).

@@ -24,3 +24,84 @@ func.func @generalize_sigmoid(%arg0: tensor<4x256x64xbf16>, %arg1: tensor<4x256x
// CHECK-NEXT: linalg.yield %[[DIV]] : bf16

// -----

func.func @generalize_mmt2d_vnni(%arg0: tensor<256x64xf32>, %arg1: tensor<16x2x8x32x4xf32>,


Shall we also add a failing case for checking?

lib/gc/Dialect/Linalgx/LinalgxOps.cpp
}];
}

def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
Contributor:

Also define multi_batch_matmul_4d and multi_batch_matmul_4d_vnni?

LongshengDu (Contributor, Author):

Is this definition correct for multi_batch_matmul_4d and multi_batch_matmul_4d_vnni? Do we also want a transposed weight?

input: BMKmk
weight: BNKkn4k/BNKkn2k/BNKkn1k
output: BMNmn

Contributor:

Correct. We don't use transposed weight.

@zhczhong:

Do we need to add batch_reduce_matmul_vnni so that matmul_vnni could be lowered to a brgemm_vnni named op?

input: BMK
weight: BKN4k/BKN2k
output: MN

@LongshengDu (Contributor, Author):

> Do we need to add batch_reduce_matmul_vnni so that matmul_vnni could be lowered to a brgemm_vnni named op?
>
> input: BMK
> weight: BKN4k/BKN2k
> output: MN

Yeah, will add this in the future.

@LongshengDu LongshengDu changed the title [Dialect] [Linalgx] Add linalgx ops: mmt2d_vnni, mmt4d_vnni, multi_batch_matmul [Dialect] [Linalgx] Add linalgx ops: mm2d_vnni, mm4d_vnni, batch_reduce_matmul_vnni, multi_batch_matmul May 22, 2024
@LongshengDu LongshengDu changed the title [Dialect] [Linalgx] Add linalgx ops: mm2d_vnni, mm4d_vnni, batch_reduce_matmul_vnni, multi_batch_matmul [Dialect] [Linalgx] Add linalgx ops: 3 vnni matmuls and multi_batch_matmul May 22, 2024
bool matchK =
shapeA.getDimSize(1) ==
(shapeB.getDimSize(1) * shapeB.getDimSize(2) * shapeB.getDimSize(4));
bool matchVnni = (shapeB.getDimSize(4) == 2) || (shapeB.getDimSize(4) == 4);


Shall we also check the vnni dim value based on dtype? (e.g. restrict vnni to 2 for bf16 and 4 for u8/s8)

LongshengDu (Contributor, Author):

Added.

Base automatically changed from longsheng/add_linalgx to main May 29, 2024 04:55
@LongshengDu LongshengDu removed the WIP work in progress label Jun 3, 2024
@LongshengDu LongshengDu merged commit cecc53c into main Jun 3, 2024
4 checks passed
6 participants