diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index f73f213f..eabc434c 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -30,7 +31,6 @@ #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include #include "gc/Transforms/Passes.h" @@ -230,39 +230,44 @@ tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, return result; } -struct OuterLoopGenerationOption { - enum LoopType { ForOp, ForallOp }; - SmallVector> nestedTileSizes; - SmallVector loopType; - SmallVector> loopDim; - bool hasFillOp; -}; - -struct OuterLoopGenerationResult { - /// Tiled operations that are generated during tiling. The order does not - /// matter except the last op. The replacements are expected to be the results - /// of the last op. - SmallVector tiledOps; - /// The `scf.for` operations that iterate over the tiles. - SmallVector loops; - SmallVector reductionLoops; -}; +bool isDummyLoop(LoopLikeOpInterface loop) { + std::optional tripCount = mlir::constantTripCount( + *loop.getSingleLowerBound(), *loop.getSingleUpperBound(), + *loop.getSingleStep()); + if (tripCount) { + return *tripCount == 1; + } + return false; +} -static void buildLinalgRegion(Operation *op) { +static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) { SmallVector argTypes; SmallVector argLocs; for (const Value &opOperand : op->getOperands()) { argTypes.push_back(getElementTypeOrSelf(opOperand.getType())); argLocs.push_back(opOperand.getLoc()); } + auto initSize = op->getResults().size(); ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); - auto *dialect = static_cast(op->getDialect()); - linalg::LinalgDialect::RegionBuilderFunType fun = - dialect->getRegionBuilder("linalg.matmul"); - fun(b, *body, op->getAttrs()); + if (createTemporaryOp) { + auto argNum = body->getNumArguments(); + SmallVector vals; + for (auto i = initSize; i > 0; i--) { + vals.push_back(body->getArgument(argNum - i)); + } + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToEnd(body); + Location loc = b.getUnknownLoc(); + b.create(loc, ValueRange(vals)); + } else { + auto *dialect = static_cast(op->getDialect()); + linalg::LinalgDialect::RegionBuilderFunType fun = + dialect->getRegionBuilder("linalg.matmul"); + fun(b, *body, op->getAttrs()); + } } struct DtypeLegalizeResult { @@ -270,25 +275,29 @@ struct DtypeLegalizeResult { Operation *castOp = nullptr; }; +bool needToLegalizeDtype(linalg::LinalgOp linalgOp) { + auto dataType = + dyn_cast(linalgOp.getDpsInputs()[0].getType()) + .getElementType(); + auto resultType = + dyn_cast(linalgOp.getDpsInits()[0].getType()) + .getElementType(); + return (dataType.isBF16() || dataType.isF16()) && dataType == resultType; +} + // Split a low precision matmul(bf16xbf16->bf16) to a combination // matmul(bf16xbf16->f32) + cast(f32->bf16) +// if needFurtherFuse=true, a middle temporary linalgOp(bf16xbf16->(f32,bf16)) +// will be created static FailureOr matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, - bool needCopyInit = true) { - + bool needCopyInit = true, bool needFurtherFuse = false) { auto linalgOp = dyn_cast(op); DtypeLegalizeResult result; if (!linalgOp) return failure(); - auto dataType = - dyn_cast(linalgOp.getDpsInputs()[0].getType()) - .getElementType(); - auto resultType = - dyn_cast(linalgOp.getDpsInits()[0].getType()) - .getElementType(); - - if ((dataType.isBF16() || dataType.isF16()) && dataType == resultType) { + if (needToLegalizeDtype(linalgOp)) { rewriter.setInsertionPoint(linalgOp); IRMapping mapping; auto initOp = linalgOp.getDpsInits()[0].getDefiningOp(); @@ -315,14 +324,30 @@ matmulDtypeLegalize(RewriterBase &rewriter, Operation *op, linalgOp.getLoc(), initOp->getResult(0), currentOp->getResult(0)); } SmallVector newOperands = linalgOp->getOperands(); + auto oldInit = newOperands.back(); newOperands.back() = currentOp->getResult(0); + + auto indexingMaps = linalgOp.getIndexingMapsArray(); + indexingMaps.push_back(indexingMaps.back()); + SmallVector attrs(linalgOp->getAttrs()); + SmallVector types = {currentOp->getResult(0).getType()}; + if (needFurtherFuse) { + auto segmentSize = rewriter.getNamedAttr( + "operandSegmentSizes", rewriter.getDenseI32ArrayAttr({2, 2})); + for (auto &attr : attrs) { + if (attr.getName() == "indexing_maps") + attr.setValue(rewriter.getAffineMapArrayAttr(indexingMaps)); + if (attr.getName() == "operandSegmentSizes") + attr.setValue(segmentSize.getValue()); + } + types.push_back(oldInit.getType()); + newOperands.push_back(oldInit); + } OperationState state(linalgOp->getLoc(), linalgOp->getName(), newOperands, - currentOp->getResult(0).getType(), - linalgOp->getAttrs()); + types, attrs); state.addRegion(); currentOp = rewriter.create(state); - buildLinalgRegion(currentOp); - + buildLinalgRegion(currentOp, needFurtherFuse); auto castOp = rewriter.create( linalgOp.getLoc(), currentOp->getResult(0), initOp->getResult(0)); result.linalgOp = currentOp; @@ -348,6 +373,129 @@ static Operation *findParentFillOp(Value val) { return nullptr; } +[[maybe_unused]] static LogicalResult +indexRolling(RewriterBase &b, Block *insertBlock, Location loc, Value v, + Value rollingIdx, Value maximumRange, Value step) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(insertBlock); + mlir::easybuild::EasyBuilder eb{b, loc}; + auto vWraped = eb.wrap(v); + auto rollingIdxWraped = eb.wrap(rollingIdx); + auto stepWraped = eb.wrap(step); + auto maximumRangeWraped = eb.wrap(step); + auto newV = (vWraped + rollingIdxWraped) * stepWraped % + (maximumRangeWraped / stepWraped * stepWraped); + v.replaceAllUsesWith(newV); + return failure(); +} + +static void getMatmulParallelDims(linalg::LinalgOp linalgOp, + unsigned operandIdx, + SmallVectorImpl &dims) { + AffineMap map = + linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(operandIdx)); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + ArrayRef results = map.getResults(); + for (auto dim : results) { + auto dimExpr = dyn_cast(dim); + if (dimExpr && iteratorTypes[dimExpr.getPosition()] == + mlir::utils::IteratorType::parallel) { + dims.push_back(dimExpr.getPosition()); + } + } +} + +static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos, + unsigned operandIdx) { + Value Operand; + unsigned dimPos; + [[maybe_unused]] auto result = + linalgOp.mapIterationSpaceDimToOperandDim(iteratorPos, Operand, dimPos); + return linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos]; +} + +static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, + Operation *op, + bool isExtract, + SmallVector size, + int shrinDimNum = 0) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto extractSlice = dyn_cast(op)) { + SmallVector mixedOffsets = extractSlice.getMixedOffsets(); + SmallVector mixedSizes = extractSlice.getMixedSizes(); + SmallVector mixedStrides = extractSlice.getMixedStrides(); + for (auto i = 0UL; i < mixedSizes.size(); i++) { + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); + } + if (shrinDimNum > 0) { + rewriter.replaceOpWithNewOp( + extractSlice, + mlir::RankedTensorType::get( + SmallVector(size.begin() + shrinDimNum, size.end()), + extractSlice.getResult().getType().getElementType()), + extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); + } else { + rewriter.replaceOpWithNewOp( + extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes, + mixedStrides); + } + } else { + return failure(); + } + return mlir::success(); +} + +static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, + Operation *op, Value source, + SmallVector size) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto insertSlice = dyn_cast(op)) { + SmallVector mixedOffsets = insertSlice.getMixedOffsets(); + SmallVector mixedSizes = insertSlice.getMixedSizes(); + SmallVector mixedStrides = insertSlice.getMixedStrides(); + for (auto i = 0UL; i < mixedSizes.size(); i++) { + mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); + } + rewriter.replaceOpWithNewOp( + insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes, + mixedStrides); + } else { + return failure(); + } + return success(); +} + +using InnermostFullResultCallBackFn = std::function( + RewriterBase &rewriter, Location loc, linalg::LinalgOp linalgop)>; + +using FinalReduceCallBackFn = std::function( + RewriterBase &rewriter, Location loc, + linalg::ForallReductionTilingResult result)>; + +struct OuterLoopGenerationOption { + enum LoopType { ForOp, ForallOp }; + SmallVector> nestedTileSizes; + SmallVector loopType; + SmallVector> loopDim; + SmallVector innermostFullResultCallBacks; + SmallVector finalReduceCallBacks; + bool isPartialResult = false; +}; + +struct OuterLoopGenerationResult { + /// Tiled operations that are generated during tiling. The order does not + /// matter except the last op. The replacements are expected to be the results + /// of the last op. + SmallVector tiledOps; + /// The `scf.for` operations that iterate over the tiles. + SmallVector loops; + SmallVector reductionLoops; +}; + static FailureOr generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, const OuterLoopGenerationOption &option) { @@ -371,6 +519,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, linalgOp, "currentOp should not has pure buffer semantics"); linalg::LinalgOp currentOp = linalgOp; + bool hasFullResult = !option.isPartialResult; for (auto loopTypeIter : llvm::enumerate(loopType)) { auto [i, loopType] = loopTypeIter; auto currentDim = loopDim[i]; @@ -385,26 +534,29 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, tileOption.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(currentOp); - // TODO: refactor here to use a callback function if (iteratorTypes[d] == mlir::utils::IteratorType::reduction && - tile != 0) { - auto result = matmulDtypeLegalize(b, currentOp.getOperation(), - !option.hasFillOp); - if (result->castOp && result->linalgOp) { - b.replaceOp(currentOp, result->castOp); - currentOp = dyn_cast(result->linalgOp); + tile != 0 && hasFullResult) { + for (const auto &fn : option.innermostFullResultCallBacks) { + auto result = fn(b, currentOp.getLoc(), currentOp); + if (succeeded(result)) { + currentOp = *result; + } } + hasFullResult = false; } auto tilingResult = scf::tileUsingSCF( b, cast(currentOp.getOperation()), tileOption); if (failed(tilingResult)) return failure(); - b.replaceOp(currentOp, tilingResult->replacements); - currentOp = dyn_cast(tilingResult->tiledOps.back()); - if (iteratorTypes[d] == mlir::utils::IteratorType::reduction) { - result.reductionLoops.push_back(tilingResult->loops.back()); + + if (!isDummyLoop(tilingResult->loops.back())) { + b.replaceOp(currentOp, tilingResult->replacements); + currentOp = dyn_cast(tilingResult->tiledOps.back()); + if (iteratorTypes[d] == mlir::utils::IteratorType::reduction) { + result.reductionLoops.push_back(tilingResult->loops.back()); + } + result.loops.push_back(tilingResult->loops.back()); } - result.loops.push_back(tilingResult->loops.back()); } } else if (loopType == OuterLoopGenerationOption::LoopType::ForallOp) { SmallVector tileSizes( @@ -442,11 +594,12 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, if (failed(tilingResult)) return failure(); currentOp = dyn_cast(tilingResult->parallelTiledOp); - if (option.hasFillOp && tilingResult->mergeOp) { - auto fillOp = findParentFillOp(tilingResult->loops.getDpsInits()[0]); - if (fillOp) { - b.replaceOp(fillOp, dyn_cast(*fillOp) - .getDpsInits()[0]); + if (tilingResult->mergeOp) { + for (const auto &fn : option.finalReduceCallBacks) { + auto result = fn(b, currentOp.getLoc(), *tilingResult); + if (succeeded(result)) { + currentOp = *result; + } } } } else if (auto tilingInterface = @@ -464,102 +617,6 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, return result; } -[[maybe_unused]] static LogicalResult -indexRolling(RewriterBase &b, Block *insertBlock, Location loc, Value v, - Value rollingIdx, Value maximumRange, Value step) { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(insertBlock); - mlir::easybuild::EasyBuilder eb{b, loc}; - auto vWraped = eb.wrap(v); - auto rollingIdxWraped = eb.wrap(rollingIdx); - auto stepWraped = eb.wrap(step); - auto maximumRangeWraped = eb.wrap(step); - auto newV = (vWraped + rollingIdxWraped) * stepWraped % - (maximumRangeWraped / stepWraped * stepWraped); - v.replaceAllUsesWith(newV); - return failure(); -} - -static void getMatmulParallelDims(linalg::LinalgOp linalgOp, - unsigned operandIdx, - SmallVectorImpl &dims) { - AffineMap map = - linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(operandIdx)); - SmallVector iteratorTypes = - linalgOp.getIteratorTypesArray(); - - ArrayRef results = map.getResults(); - for (auto dim : results) { - auto dimExpr = dyn_cast(dim); - if (dimExpr && iteratorTypes[dimExpr.getPosition()] == - mlir::utils::IteratorType::parallel) { - dims.push_back(dimExpr.getPosition()); - } - } -} - -static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos, - unsigned operandIdx) { - Value Operand; - unsigned dimPos; - [[maybe_unused]] auto result = - linalgOp.mapIterationSpaceDimToOperandDim(iteratorPos, Operand, dimPos); - return linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos]; -} - -static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter, - Operation *op, - bool isExtract, - SmallVector size, - int shrinDimNum = 0) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(op); - if (auto extractSlice = dyn_cast(op)) { - SmallVector mixedOffsets = extractSlice.getMixedOffsets(); - SmallVector mixedSizes = extractSlice.getMixedSizes(); - SmallVector mixedStrides = extractSlice.getMixedStrides(); - for (auto i = 0UL; i < mixedSizes.size(); i++) { - mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); - } - if (shrinDimNum > 0) { - rewriter.replaceOpWithNewOp( - extractSlice, - mlir::RankedTensorType::get( - SmallVector(size.begin() + shrinDimNum, size.end()), - extractSlice.getResult().getType().getElementType()), - extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides); - } else { - rewriter.replaceOpWithNewOp( - extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes, - mixedStrides); - } - } else { - return failure(); - } - return mlir::success(); -} - -static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter, - Operation *op, Value source, - SmallVector size) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(op); - if (auto insertSlice = dyn_cast(op)) { - SmallVector mixedOffsets = insertSlice.getMixedOffsets(); - SmallVector mixedSizes = insertSlice.getMixedSizes(); - SmallVector mixedStrides = insertSlice.getMixedStrides(); - for (auto i = 0UL; i < mixedSizes.size(); i++) { - mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), size[i]); - } - rewriter.replaceOpWithNewOp( - insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes, - mixedStrides); - } else { - return failure(); - } - return success(); -} - /* matmul(A, B) -> C ----------------> @@ -636,14 +693,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1 : cfg.NBlock; // Outer - if (cfg.KThreads > 1) { - auto result = - matmulDtypeLegalize(rewriter, linalgOp.getOperation(), !hasFillOp); - if (result->castOp && result->linalgOp) { - rewriter.replaceOp(linalgOp, result->castOp); - linalgOp = dyn_cast(result->linalgOp); - } - } option.nestedTileSizes.emplace_back(SmallVector{ MParallelBlockSize, NParallelBlockSize, KParallelBlockSize}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp); @@ -686,12 +735,45 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { option.loopDim.emplace_back(SmallVector{(int)dim}); } } - option.hasFillOp = hasFillOp; + + auto lowPrecisionCast = + [&](RewriterBase &rewriter, Location loc, + linalg::LinalgOp linalgop) -> FailureOr { + auto legalizedResult = matmulDtypeLegalize( + rewriter, linalgop.getOperation(), !hasFillOp, true); + if (legalizedResult->castOp && legalizedResult->linalgOp) { + auto linalgOp = legalizedResult->linalgOp; + rewriter.replaceOp(linalgop, + linalgOp->getResult(linalgOp->getNumResults() - 1)); + return dyn_cast(linalgOp); + } + return failure(); + }; + option.innermostFullResultCallBacks.push_back(lowPrecisionCast); + + if (hasFillOp) { + auto removeReduncantFill = + [&](RewriterBase &rewriter, Location loc, + const linalg::ForallReductionTilingResult &result) + -> FailureOr { + auto initValue = result.initialValues; + if (initValue.size() == 1 && + isa(initValue[0].getDefiningOp())) { + rewriter.replaceOp(initValue[0].getDefiningOp(), + dyn_cast( + initValue[0].getDefiningOp()) + .getDpsInits()[0]); + } + return dyn_cast(result.parallelTiledOp); + }; + option.finalReduceCallBacks.push_back(removeReduncantFill); + } return generateOuterLoop(rewriter, linalgOp, option); } struct innerBodyGenerationOption { Operation *fillOp; + bool needLowPrecisionCast; SmallVector KLoopHandles; }; @@ -796,7 +878,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { BInnermostDims, NDimNum > 1)) || failed(setStaticSizeForExtractSliceOp( rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true, - AInnermostDims, MDimNum > 1))) { + AInnermostDims, MDimNum > 1)) || + (currentOp.getDpsInits().size() > 1 && + failed(setStaticSizeForExtractSliceOp( + rewriter, currentOp.getDpsInits()[1].getDefiningOp(), true, + CInnermostDims, MDimNum > 1 ? 2 : 0)))) { return failure(); } // View the tensor to brgemm required format @@ -850,13 +936,47 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { return failure(); } } - rewriter.replaceOp(currentOp, matmul.getOperation()->getResult(0)); - currentOp = matmul; + if (option.needLowPrecisionCast) { + rewriter.setInsertionPointAfter(currentOp); + auto cond = eb(true); + for (auto loop : option.KLoopHandles) { + auto induceVar = + eb.wrap(*loop.getSingleInductionVar()); + auto upBound = + eb.wrap(*loop.getSingleUpperBound()); + auto step = eb.wrap(*loop.getSingleStep()); + auto currentCond = (induceVar + step) > upBound; + cond = cond & currentCond; + } + EB_scf_if(cond, {currentOp.getDpsInits().back().getType()}) { + auto castOp = rewriter.create( + matmul.getLoc(), matmul->getResult(0), + currentOp.getDpsInits().back()); + eb.yield(castOp->getResult(0)); + } + EB_else { eb.yield(currentOp.getDpsInits().back()); } + auto ifOp = eb.getLastOperaion(); + // set static size for the insertSliceOp of copyOp + for (Operation *user : currentOp->getResult(1).getUsers()) { + if (failed(setStaticSizeForInsertSliceOp( + rewriter, user, ifOp->getResult(0), CInnermostDims))) { + return failure(); + } + } + rewriter.replaceOp(currentOp, {matmul->getResult(0), ifOp->getResult(0)}); + } else { + rewriter.replaceOp(currentOp, matmul->getResult(0)); + } + currentOp = matmul; // Fuse the fill op to the innermost body if (auto fillOp = llvm::dyn_cast_or_null(option.fillOp)) { auto fillValue = fillOp.getDpsInputs()[0]; - rewriter.replaceOp(fillOp, fillOp.getDpsInits()[0]); + if (cfg.KThreads <= 1) { + // if use k slicing, the fill op is still need to be kept for the reduce + // init + rewriter.replaceOp(fillOp, fillOp.getDpsInits()[0]); + } rewriter.setInsertionPointAfter(currentOp); auto cond = eb(true); @@ -910,29 +1030,31 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]); - // Step 1. generate the outer loop + // Step 1. Split matmul(bf16xbf16->bf16) to matmul(bf16xbf16->f32) + + // cast(f32->bf16) if K slicing is needed MatmulConfig cfg = getDefaultMatmulConfig(linalgOp); + bool needLowPrecisionCast = needToLegalizeDtype(linalgOp); + if (cfg.KThreads > 1) { + auto result = matmulDtypeLegalize(rewriter, linalgOp.getOperation()); + if (result->castOp && result->linalgOp) { + rewriter.replaceOp(linalgOp, result->castOp); + linalgOp = dyn_cast(result->linalgOp); + } + needLowPrecisionCast = false; + } + + // Step 2. Outer loop generation auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg, isa(fillOp)); if (failed(outerLoopResult)) { return failure(); } linalgOp = dyn_cast(outerLoopResult->tiledOps.back()); - // Step 2 index rolling - // if (failed(indexRolling(rewriter, linalgOp.getLoc(), - // outerLoopResult->reductionLoops[0].getInductionVar(), - // linalgOp.getLoopRanges()[0].size, cfg.MBlock)) - // || - // failed(indexRolling(rewriter, linalgOp.getLoc(), - // linalgOp.getDpsInputOperand(1), - // linalgOp.getLoopRanges()[1].size, cfg.KBlock))) - // { - // return failure(); - // } // Step 3 generate inner loop body, convert the linalg.generic to brgemm - auto option = - innerBodyGenerationOption{fillOp, outerLoopResult->reductionLoops}; + auto option = innerBodyGenerationOption{fillOp, needLowPrecisionCast, + outerLoopResult->reductionLoops}; + if (failed(innerBodyGeneration(rewriter, originOp, linalgOp, option))) { return failure(); }