support 2Dx4D/5D case
zhczhong committed Jun 13, 2024
1 parent 9af3f96 commit 14f4918
Showing 3 changed files with 159 additions and 111 deletions.
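
The change generalizes the deep-tile matmul lowering so that a plain 2D A operand can pair with a blocked 4D (or VNNI 5D) B operand: per-operand dim counts (MDimNum, NDimNum) replace the single useBlockedLayout flag, tensorViewRankedTensor gains an optional permutation, and the brgemm operand views get explicit shapes for the plain case. As a standalone sketch of the resulting view shapes on the plain-A path (the cfg values below are illustrative; only KBlock = 64 and the thread counts are visible in this diff):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical config mirroring MatmulConfig in the diff; only KBlock = 64,
// MThreads/NThreads = 2 and KThreads = 1 are visible in this commit, the
// rest are illustrative values.
struct Cfg {
  int64_t KBlock = 64;
  int64_t innerMostMBlock = 32, innerMostKBlock = 32, innerMostNBlock = 32;
};

int main() {
  Cfg cfg;
  // A plain 2D A tile is viewed as [innerMostMBlock, KBlock / innerMostKBlock,
  // innerMostKBlock] and then permuted by {1, 0, 2}, so brgemm sees
  // [KBlock / innerMostKBlock, innerMostMBlock, innerMostKBlock].
  std::vector<int64_t> aView{cfg.innerMostMBlock,
                             cfg.KBlock / cfg.innerMostKBlock,
                             cfg.innerMostKBlock};
  std::vector<int64_t> perm{1, 0, 2}, aBrgemm;
  for (int64_t p : perm)
    aBrgemm.push_back(aView[p]);
  // A plain 2D B tile is viewed as [KBlock / innerMostKBlock,
  // innerMostKBlock, innerMostNBlock]; no permutation is applied.
  std::vector<int64_t> bBrgemm{cfg.KBlock / cfg.innerMostKBlock,
                               cfg.innerMostKBlock, cfg.innerMostNBlock};
  for (int64_t d : aBrgemm) std::cout << d << ' '; // 2 32 32
  std::cout << '\n';
  for (int64_t d : bBrgemm) std::cout << d << ' '; // 2 32 32
  std::cout << '\n';
}
```

With those values the brgemm A view is [2, 32, 32] after the {1, 0, 2} permutation, matching the [batch, M, K] operand order a batch-reduce matmul expects.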
lib/gc/Transforms/DeepTileContractionNamedOp.cpp (201 changes: 127 additions & 74 deletions)
@@ -136,13 +136,14 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) {
   cfg.KBlock = 64;
   cfg.MThreads = 2;
   cfg.NThreads = 2;
-  cfg.KThreads = 2;
+  cfg.KThreads = 1;
   return cfg;
 }
 
-static Value tensorViewRankedTensor(RewriterBase &rewriter,
-                                    RankedTensorType outTensorType,
-                                    Value value) {
+static Value
+tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType,
+                       Value value,
+                       ArrayRef<int64_t> permutation = SmallVector<int64_t>{}) {
   // TODO: add support for plain layout transpose
   Value result, currentValue = value;
   auto loc = currentValue.getLoc();
@@ -175,33 +176,57 @@ static Value tensorViewRankedTensor(RewriterBase &rewriter,
 
   if (outShape.size() < inShape.size()) {
     SmallVector<ReassociationIndices> reassocIndices;
-    ReassociationIndices firstEntry;
-    for (auto i = 0UL; i < inShape.size() - outShape.size() + 1; i++) {
-      firstEntry.push_back(i);
-    }
-    reassocIndices.push_back(firstEntry);
-    for (auto i = inShape.size() - outShape.size() + 1UL; i < inShape.size();
-         i++) {
-      reassocIndices.push_back({(int)i});
-    }
+    uint64_t outIdx = 0UL, inIdx = 0UL;
+    while (inIdx < inShape.size() && outIdx < outShape.size()) {
+      ReassociationIndices firstEntry;
+      auto remaining = outShape[outIdx++];
+      if (remaining == 1) {
+        firstEntry.push_back(inIdx++);
+        reassocIndices.push_back(firstEntry);
+        continue;
+      }
+      while (remaining > 1) {
+        remaining /= inShape[inIdx];
+        firstEntry.push_back(inIdx++);
+      }
+      reassocIndices.push_back(firstEntry);
+    }
     result = rewriter.create<tensor::CollapseShapeOp>(
         loc, outTensorType, currentValue, reassocIndices);
   } else if (outShape.size() > inShape.size()) {
     SmallVector<ReassociationIndices> reassocIndices;
-    ReassociationIndices firstEntry;
-    for (auto i = 0UL; i < outShape.size() - inShape.size() + 1; i++) {
-      firstEntry.push_back((int)i);
-    }
-    reassocIndices.push_back(firstEntry);
-    for (auto i = outShape.size() - inShape.size() + 1UL; i < outShape.size();
-         i++) {
-      reassocIndices.push_back({(int)i});
-    }
+    uint64_t outIdx = 0UL, inIdx = 0UL;
+    while (outIdx < outShape.size() && inIdx < inShape.size()) {
+      ReassociationIndices firstEntry;
+      auto remaining = inShape[inIdx++];
+      if (remaining == 1) {
+        firstEntry.push_back(outIdx++);
+        reassocIndices.push_back(firstEntry);
+        continue;
+      }
+      while (remaining > 1) {
+        remaining /= outShape[outIdx];
+        firstEntry.push_back(outIdx++);
+      }
+      reassocIndices.push_back(firstEntry);
+    }
     result = rewriter.create<tensor::ExpandShapeOp>(
         loc, outTensorType, currentValue, reassocIndices);
   } else {
     result = rewriter.create<tensor::CastOp>(loc, outTensorType, currentValue);
   }
 
+  if (!permutation.empty()) {
+    SmallVector<int64_t> transposeShape;
+    for (auto idx : permutation) {
+      transposeShape.push_back(outShape[idx]);
+    }
+    auto initOp = rewriter.create<tensor::EmptyOp>(loc, transposeShape,
+                                                   tensorElementType);
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, result, initOp->getResult(0), permutation);
+    result = transposeOp->getResult(0);
+  }
   return result;
 }
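
The old logic always collapsed or expanded the leading dimensions as one group; the new walk matches products of dims, so a shape like 2x16x32x32 reassociates into 32x32x32 even though the grouped dims carry different extents. A minimal standalone rendering of the collapse direction (plain C++; assumes static, compatible shapes and no error handling):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Greedy grouping as in the new while-loop: each output dim of size 1 takes
// exactly one input dim; otherwise input dims are consumed until their
// product reaches the output dim. Assumes the shapes are compatible.
std::vector<std::vector<int64_t>> groupReassociation(
    const std::vector<int64_t> &inShape, const std::vector<int64_t> &outShape) {
  std::vector<std::vector<int64_t>> reassoc;
  size_t inIdx = 0, outIdx = 0;
  while (inIdx < inShape.size() && outIdx < outShape.size()) {
    std::vector<int64_t> entry;
    int64_t remaining = outShape[outIdx++];
    if (remaining == 1) {
      entry.push_back(inIdx++);
      reassoc.push_back(entry);
      continue;
    }
    while (remaining > 1) {
      remaining /= inShape[inIdx];
      entry.push_back(inIdx++);
    }
    reassoc.push_back(entry);
  }
  return reassoc;
}

int main() {
  // Collapse [2, 16, 32, 32] into [32, 32, 32]: 2*16 = 32 covers the first
  // output dim, so the groups are {0, 1}, {2}, {3}.
  for (const auto &g : groupReassociation({2, 16, 32, 32}, {32, 32, 32})) {
    for (int64_t i : g) std::cout << i << ' ';
    std::cout << '\n';
  }
}
```

The expand direction (the `tensor::ExpandShapeOp` path) is the same walk with the roles of `inShape` and `outShape` swapped.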

@@ -345,6 +370,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
     return b.notifyMatchFailure(
         linalgOp, "currentOp should not has pure buffer semantics");
   linalg::LinalgOp currentOp = linalgOp;
+
   for (auto loopTypeIter : llvm::enumerate(loopType)) {
     auto [i, loopType] = loopTypeIter;
     auto currentDim = loopDim[i];
@@ -486,6 +512,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
                                                     bool isExtract,
                                                     SmallVector<int64_t> size,
                                                     int shrinDimNum = 0) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
   if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
     SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes();
@@ -514,6 +542,8 @@ static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
 static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
                                                    Operation *op, Value source,
                                                    SmallVector<int64_t> size) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
     SmallVector<OpFoldResult> mixedOffsets = insertSlice.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = insertSlice.getMixedSizes();
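
Both slice helpers now pin the insertion point to the op being rewritten, and `OpBuilder::InsertionGuard` restores the previous point when the helper returns. A tiny runnable analogue of that RAII save/restore pattern (the `Builder` type below is a mock, not the MLIR class):

```cpp
#include <cassert>

// Minimal stand-in for OpBuilder: just an integer insertion point.
struct Builder {
  int insertionPoint = 0;
};

// RAII guard like OpBuilder::InsertionGuard: remembers the insertion point
// at construction and restores it on destruction.
struct InsertionGuard {
  Builder &b;
  int saved;
  explicit InsertionGuard(Builder &b) : b(b), saved(b.insertionPoint) {}
  ~InsertionGuard() { b.insertionPoint = saved; }
};

void rewriteAt(Builder &b, int opPos) {
  InsertionGuard guard(b);
  b.insertionPoint = opPos; // like rewriter.setInsertionPoint(op)
  // ... create replacement ops at opPos ...
} // guard restores the caller's insertion point here

int main() {
  Builder b;
  b.insertionPoint = 7;
  rewriteAt(b, 42);
  assert(b.insertionPoint == 7); // caller's position is untouched
}
```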
@@ -575,35 +605,34 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     linalgOp.getReductionDims(KDimPos);
     getMatmulParallelDims(linalgOp, 0, MDimPos);
     getMatmulParallelDims(linalgOp, 1, NDimPos);
-    bool useBlockedLayout = KDimPos.size() > 1;
 
     OuterLoopGenerationOption option;
     auto iteratorTypes = linalgOp.getIteratorTypesArray();
     auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
     auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
     auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
     auto KParallelBlockSize =
-        useBlockedLayout
+        KDimPos.size() > 1
             ? divAndCeil(KFirstDim, cfg.KThreads)
             : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
                   cfg.KBlock;
     auto MParallelBlockSize =
-        useBlockedLayout
+        MDimPos.size() > 1
             ? divAndCeil(MFirstDim, cfg.MThreads)
             : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
                   cfg.MBlock;
     auto NParallelBlockSize =
-        useBlockedLayout
+        NDimPos.size() > 1
             ? divAndCeil(NFirstDim, cfg.NThreads)
             : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
                   cfg.NBlock;
-    auto KOuterBlockSize = useBlockedLayout
+    auto KOuterBlockSize = KDimPos.size() > 1
                                ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
                                : cfg.KBlock;
-    auto MOuterBlockSize = useBlockedLayout
+    auto MOuterBlockSize = MDimPos.size() > 1
                                ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
                                : cfg.MBlock;
-    auto NOuterBlockSize = useBlockedLayout
+    auto NOuterBlockSize = NDimPos.size() > 1
                                ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
                                : cfg.NBlock;
     // Outer
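
Each dimension now consults its own loop-position count instead of the shared useBlockedLayout flag, so a matmul with plain M but blocked N and K picks per-dimension block sizes. The arithmetic itself is unchanged; here is a runnable sketch, assuming `divAndCeil(a, b)` is ceiling division (the name comes from the diff, the sample extents are made up):

```cpp
#include <cstdint>
#include <iostream>

int64_t divAndCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  int64_t MFirstDim = 4; // blocked layout: outer M dim, e.g. M/MBlock blocks
  int64_t MThreads = 2, MBlock = 32;

  // Blocked layout (MDimPos.size() > 1): split the outer-block dim directly.
  int64_t blocked = divAndCeil(MFirstDim, MThreads); // 2 outer blocks/thread
  // Plain layout: MFirstDim is the full M extent, e.g. 100; round the
  // per-thread chunk up to a multiple of MBlock.
  int64_t plainM = 100;
  int64_t plain =
      divAndCeil(divAndCeil(plainM, MBlock), MThreads) * MBlock; // 64
  std::cout << blocked << ' ' << plain << '\n';
}
```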
@@ -631,11 +660,23 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       option.loopDim.emplace_back(SmallVector<int>{dim});
     }
     // Inner
-    if (!useBlockedLayout) {
+    if (KDimPos.size() == 1) {
       option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
       option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
       option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
     }
+    if (MDimPos.size() == 1) {
+      option.nestedTileSizes.emplace_back(
+          SmallVector<int>{cfg.innerMostMBlock});
+      option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+      option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
+    }
+    if (NDimPos.size() == 1) {
+      option.nestedTileSizes.emplace_back(
+          SmallVector<int>{cfg.innerMostNBlock});
+      option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+      option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
+    }
     for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
       if (dim != MDimPos.back() && dim != NDimPos.back() &&
           iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
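
Each plain (single-position) dimension now contributes its own inner ForOp level, where previously the extra levels were emitted only when the whole matmul was plain. A simplified mock of how the option vectors fill in for a plain-M, blocked-N/K case (the structs and loop positions below are stand-ins, not the pass's types):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-ins for the OuterLoopGenerationOption fields.
struct Options {
  std::vector<int> nestedTileSizes;
  std::vector<int> loopDims;
};

int main() {
  size_t KDims = 2, MDims = 1, NDims = 2; // plain M, blocked N and K
  int KBlock = 64, innerMostMBlock = 32, innerMostNBlock = 32;
  int KPos = 1, MPos = 0, NPos = 2;       // hypothetical loop positions

  Options option;
  if (KDims == 1) { // blocked here, so no extra K loop
    option.nestedTileSizes.push_back(KBlock);
    option.loopDims.push_back(KPos);
  }
  if (MDims == 1) { // plain M: tile it by innerMostMBlock
    option.nestedTileSizes.push_back(innerMostMBlock);
    option.loopDims.push_back(MPos);
  }
  if (NDims == 1) { // blocked here, so no extra N loop
    option.nestedTileSizes.push_back(innerMostNBlock);
    option.loopDims.push_back(NPos);
  }
  // Prints "32@0": one extra inner loop, over M only.
  for (size_t i = 0; i < option.nestedTileSizes.size(); ++i)
    std::cout << option.nestedTileSizes[i] << '@' << option.loopDims[i] << '\n';
}
```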
@@ -658,17 +699,24 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
                                 linalg::LinalgOp originOp,
                                 linalg::LinalgOp currentOp,
                                 innerBodyGenerationOption &option) const {
+
     mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()};
+    auto operandDimTypes = getOprandDimType(originOp);
     MatmulConfig cfg = getDefaultMatmulConfig(originOp);
     auto AShape = originOp.getShape(originOp.getDpsInputOperand(0));
     auto BShape = originOp.getShape(originOp.getDpsInputOperand(1));
     auto CShape = originOp.getShape(originOp.getDpsInitOperand(0));
-    bool useBlockedLayout = BShape.size() > 2;
 
+    auto MDimNum = std::count_if((*operandDimTypes)[0].begin(),
+                                 (*operandDimTypes)[0].end(),
+                                 [](DimType d) { return d == DimType::M; });
+    auto NDimNum = std::count_if((*operandDimTypes)[1].begin(),
+                                 (*operandDimTypes)[1].end(),
+                                 [](DimType d) { return d == DimType::N; });
     // TODO: support plain in/block out format
     SmallVector<int64_t> AInnermostDims, BInnermostDims, CInnermostDims;
-    if (useBlockedLayout) {
-      bool firstM = true, firstK = true, firstN = true;
+    bool firstM = true, firstK = true, firstN = true;
+    if (MDimNum > 1) {
       for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
         if (iter == DimType::M && firstM) {
           AInnermostDims.push_back(1);
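
`MDimNum` and `NDimNum` count how often the M or N dim type occurs among an operand's dimensions (as classified by `getOprandDimType` in the diff), which is what now distinguishes a blocked operand from a plain one. A self-contained illustration, with a hypothetical `DimType` enum mirroring the names used in the diff:

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

enum class DimType { Batch, M, N, K };

int main() {
  // Plain 2D A: [M, K] -> one M dim.
  std::vector<DimType> plainA{DimType::M, DimType::K};
  // Blocked 4D A: [M/mb, K/kb, mb, kb] -> two M dims.
  std::vector<DimType> blockedA{DimType::M, DimType::K, DimType::M,
                                DimType::K};

  auto countM = [](const std::vector<DimType> &dims) {
    return std::count_if(dims.begin(), dims.end(),
                         [](DimType d) { return d == DimType::M; });
  };
  std::cout << countM(plainA) << ' ' << countM(blockedA) << '\n'; // 1 2
}
```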
@@ -682,21 +730,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
           AInnermostDims.push_back(AShape[idx]);
         }
       }
-      firstN = true;
-      firstK = true;
-      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
-        if (iter == DimType::N && firstN) {
-          BInnermostDims.push_back(1);
-          firstN = false;
-        } else if (iter == DimType::Batch) {
-          BInnermostDims.push_back(1);
-        } else if (iter == DimType::K && firstK) {
-          BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
-          firstK = false;
-        } else {
-          BInnermostDims.push_back(BShape[idx]);
-        }
-      }
       firstM = true;
       firstN = true;
       for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
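
The B-operand walk deleted here is not removed for good; it reappears in the next hunk guarded by `NDimNum > 1`. For reference, a standalone sketch of what it computes for a blocked 4D B (the dim-type order and extents below are hypothetical):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

enum class DimType { Batch, M, N, K };

int main() {
  // Blocked 4D B: [N/nb, K/kb, kb, nb] with dim types {N, K, K, N} and a
  // hypothetical shape of [8, 2, 32, 32]; KBlock = 64, innerMostKBlock = 32.
  std::vector<DimType> dims{DimType::N, DimType::K, DimType::K, DimType::N};
  std::vector<int64_t> BShape{8, 2, 32, 32};
  int64_t KBlock = 64, innerMostKBlock = 32;

  std::vector<int64_t> BInnermostDims;
  bool firstN = true, firstK = true;
  for (size_t idx = 0; idx < dims.size(); ++idx) {
    DimType iter = dims[idx];
    if (iter == DimType::N && firstN) {
      BInnermostDims.push_back(1); // outermost N: one block per tile
      firstN = false;
    } else if (iter == DimType::Batch) {
      BInnermostDims.push_back(1);
    } else if (iter == DimType::K && firstK) {
      BInnermostDims.push_back(KBlock / innerMostKBlock); // brgemm batch
      firstK = false;
    } else {
      BInnermostDims.push_back(BShape[idx]); // keep inner block extents
    }
  }
  for (int64_t d : BInnermostDims) std::cout << d << ' '; // 1 2 32 32
  std::cout << '\n';
}
```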
@@ -716,74 +749,94 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       AInnermostDims = SmallVector<int64_t>{cfg.innerMostMBlock,
                                             cfg.KBlock / cfg.innerMostKBlock *
                                                 cfg.innerMostKBlock};
+      CInnermostDims =
+          SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
+    }
+    if (NDimNum > 1) {
+      firstN = true;
+      firstK = true;
+      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
+        if (iter == DimType::N && firstN) {
+          BInnermostDims.push_back(1);
+          firstN = false;
+        } else if (iter == DimType::Batch) {
+          BInnermostDims.push_back(1);
+        } else if (iter == DimType::K && firstK) {
+          BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
+          firstK = false;
+        } else {
+          BInnermostDims.push_back(BShape[idx]);
+        }
+      }
     } else {
       BInnermostDims = SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock *
                                                 cfg.innerMostKBlock,
                                             cfg.innerMostNBlock};
-      CInnermostDims =
-          SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
     }
 
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(currentOp);
     auto dataType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType());
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType())
+            .getElementType();
     auto weightType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType());
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType())
+            .getElementType();
     auto resultType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType());
-    // use shrink layout when it is able to be converted to brgemm
-    bool useShrinkedLayout = (BInnermostDims.size() == 4);
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType())
+            .getElementType();
 
     // update the extractSlice to static size, replace it with
    // useBlockedLayout when
     if (failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true,
-            CInnermostDims, useShrinkedLayout ? 2 : 0)) ||
+            CInnermostDims, MDimNum > 1 ? 2 : 0)) ||
         failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true,
-            BInnermostDims, useShrinkedLayout)) ||
+            BInnermostDims, NDimNum > 1)) ||
         failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true,
-            AInnermostDims, useShrinkedLayout))) {
+            AInnermostDims, MDimNum > 1))) {
       return failure();
     }
 
     // View the tensor to brgemm required format
     Value dataOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            useBlockedLayout
-                ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
-                                       AInnermostDims.end())
-                : SmallVector<int64_t>{1, AInnermostDims[0], AInnermostDims[1]},
-            dataType.getElementType()),
-        currentOp.getDpsInputs()[0]);
+            MDimNum > 1 ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
+                                               AInnermostDims.end())
+                        : SmallVector<int64_t>{cfg.innerMostMBlock,
+                                               cfg.KBlock / cfg.innerMostKBlock,
+                                               cfg.innerMostKBlock},
+            dataType),
+        currentOp.getDpsInputs()[0],
+        MDimNum == 1 ? SmallVector<int64_t>{1, 0, 2} : SmallVector<int64_t>{});
     Value weightOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            useBlockedLayout
-                ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
-                                       BInnermostDims.end())
-                : SmallVector<int64_t>{1, BInnermostDims[0], BInnermostDims[1]},
-            weightType.getElementType()),
+            NDimNum > 1 ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
+                                               BInnermostDims.end())
+                        : SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock,
+                                               cfg.innerMostKBlock,
+                                               cfg.innerMostNBlock},
+            weightType),
         currentOp.getDpsInputs()[1]);
     Value resultOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            SmallVector<int64_t>(CInnermostDims.begin() +
-                                     (useBlockedLayout ? 2 : 0),
+            SmallVector<int64_t>(CInnermostDims.begin() + (MDimNum > 1 ? 2 : 0),
                                  CInnermostDims.end()),
-            resultType.getElementType()),
+            resultType),
         currentOp.getDpsInits()[0]);
 
     // Create the brgemm op and replace the origin linalg op
     linalg::LinalgOp matmul;
-    if (BInnermostDims.size() == 4 || BInnermostDims.size() == 2) {
+    if (dyn_cast<mlir::RankedTensorType>(weightOprand.getType())
+            .getShape()
+            .size() == 3) {
       matmul = rewriter.create<linalg::BatchReduceMatmulOp>(
           resultOprand.getLoc(), resultOprand.getType(),
          ValueRange{dataOprand, weightOprand}, resultOprand);
     } else {
       IRMapping mapping;
       matmul = rewriter.create<linalgx::BatchReduceMatmulVnniOp>(
           resultOprand.getLoc(), resultOprand.getType(),
           ValueRange{dataOprand, weightOprand}, resultOprand);
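
The dispatch between the two brgemm ops now keys on the rank of the weight view rather than on `BInnermostDims.size()`: a rank-3 weight view lowers to `linalg::BatchReduceMatmulOp`, while the rank-4 case lowers to `linalgx::BatchReduceMatmulVnniOp`. A minimal mock of that selection (plain C++, not the MLIR builders; the VNNI layout in the comment is an assumption):

```cpp
#include <cstddef>
#include <iostream>

enum class BrgemmKind { Regular, Vnni };

// Mirrors the new condition: rank-3 weight views ([batch, K, N]) lower to
// linalg::BatchReduceMatmulOp; rank-4 views (a VNNI-packed layout such as
// [batch, K/vnni, N, vnni]) lower to linalgx::BatchReduceMatmulVnniOp.
BrgemmKind selectBrgemm(size_t weightRank) {
  return weightRank == 3 ? BrgemmKind::Regular : BrgemmKind::Vnni;
}

int main() {
  std::cout << (selectBrgemm(3) == BrgemmKind::Regular) << '\n'; // 1
  std::cout << (selectBrgemm(4) == BrgemmKind::Vnni) << '\n';    // 1
}
```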