[Dialect] [Linalgx] Add linalgx ops: 3 vnni matmuls and multi_batch_matmul #89
Conversation
bool matchK =
    shapeA.getDimSize(1) ==
    (shapeB.getDimSize(1) * shapeB.getDimSize(2) * shapeB.getDimSize(4));
bool matchVnni = (shapeB.getDimSize(4) == 1) || (shapeB.getDimSize(4) == 2) ||
                 (shapeB.getDimSize(4) == 4);
Same as above: why do we consider 1 as a vnni format here?
input: MKmk/MK
weight: NKkn4k/NKkn2k/NKkn/KN
output: MNmn/MN
For NKkn format, can we treat it as NKkn1k so we can reuse this op?
Please leave a note here if you'd like to reuse it for F32 datatypes, as vnni always refers to low precision with blk_size 2 or 4.
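If NKkn were folded into this op as NKkn1k, the reinterpretation is just an appended unit dimension. A minimal sketch of such a view, using standard MLIR types; the helper name is hypothetical and not part of this PR:

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper (not in this PR): view a 4-D NKkn weight as a 5-D
// NKkn1k weight by appending a unit vnni dimension, so a plain f32 weight
// could reuse the vnni verifier with blk_size 1.
static RankedTensorType asNKkn1k(RankedTensorType nkkn) {
  llvm::SmallVector<int64_t> shape(nkkn.getShape().begin(),
                                   nkkn.getShape().end());
  shape.push_back(1); // unit vnni block for full-precision weights
  return RankedTensorType::get(shape, nkkn.getElementType());
}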
Is there a name for the NKkn format (differentiating it from mmt4d's NKnk)?
The name is NKkn. For op naming, I think it's mm4d_vnni (no transpose).
@@ -24,3 +24,84 @@ func.func @generalize_sigmoid(%arg0: tensor<4x256x64xbf16>, %arg1: tensor<4x256x
// CHECK-NEXT: linalg.yield %[[DIV]] : bf16

// -----

func.func @generalize_mmt2d_vnni(%arg0: tensor<256x64xf32>, %arg1: tensor<16x2x8x32x4xf32>,
Shall we also add a failing case for checking?
}];
}

def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
Should we also define multi_batch_matmul_4d and multi_batch_matmul_4d_vnni?
Is this definition correct for multi_batch_matmul_4d and multi_batch_matmul_4d_vnni? Do we also want transposed weight?
input: BMKmk
weight: BNKkn4k/BNKkn2k/BNKkn1k
output: BMNmn
Correct. We don't use transposed weight.
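For the batched 4-D variants, the same style of shape checks would gain a leading batch dimension. A rough sketch under the layouts listed above; the dim indices and variable names mirror the earlier verifier snippet and are my assumption, not code from this PR:

// Sketch only, with assumed dim order: A is BMKmk (rank 5) and
// B is BNKkn{v}k (rank 6) with the vnni block at dim 5.
bool matchBatch = shapeA.getDimSize(0) == shapeB.getDimSize(0);
// Outer K blocks must agree; A's inner k equals B's k times the vnni block.
bool matchK = shapeA.getDimSize(2) == shapeB.getDimSize(2) &&
              shapeA.getDimSize(4) ==
                  shapeB.getDimSize(3) * shapeB.getDimSize(5);
bool matchVnni = (shapeB.getDimSize(5) == 1) || (shapeB.getDimSize(5) == 2) ||
                 (shapeB.getDimSize(5) == 4);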
Do we need to add …

Yeah, will add this in the future.
bool matchK =
    shapeA.getDimSize(1) ==
    (shapeB.getDimSize(1) * shapeB.getDimSize(2) * shapeB.getDimSize(4));
bool matchVnni = (shapeB.getDimSize(4) == 2) || (shapeB.getDimSize(4) == 4);
Shall we also check the vnni dim value based on dtype? (e.g. we restrict vnni to be 2 for bf16 and 4 for u8/s8).
Added.
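For reference, a minimal sketch of what such a dtype-gated check could look like, using standard MLIR type queries; this is my own illustration, not necessarily the code that landed:

// Sketch: tie the trailing vnni dim to the element type; bf16 packs 2
// values per vnni block, u8/s8 pack 4.
Type elemType = shapeB.getElementType();
int64_t vnniDim = shapeB.getDimSize(4);
bool matchVnni = (elemType.isBF16() && vnniDim == 2) ||
                 (elemType.isInteger(8) && vnniDim == 4);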
Added mm2d_vnni, mm4d_vnni, batch_reduce_matmul_vnni with custom verifiers, static indexing_maps and iterator_types, using the vnni dims to get constant symbols for the indexing_maps.

Added multi_batch_matmul with LinalgContractionOpInterface and dynamic indexing_maps and iterator_types.

Tracking: #14