From 6f1445c8fdb8684fac05ba68432c9748d0501468 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Wed, 19 Jun 2024 19:41:56 -0700 Subject: [PATCH] enhance config --- CMakeLists.txt | 1 + include/gc/Analysis/MatmulConfigAnalysis.h | 123 ++++++ lib/gc/Analysis/CMakeLists.txt | 19 + lib/gc/Analysis/MatmulConfigAnalysis.cpp | 387 ++++++++++++++++++ lib/gc/CMakeLists.txt | 1 + .../Transforms/DeepTileContractionNamedOp.cpp | 162 ++------ src/gc-opt/CMakeLists.txt | 6 +- 7 files changed, 570 insertions(+), 129 deletions(-) create mode 100644 include/gc/Analysis/MatmulConfigAnalysis.h create mode 100644 lib/gc/Analysis/CMakeLists.txt create mode 100644 lib/gc/Analysis/MatmulConfigAnalysis.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 55661408..02619767 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,6 +96,7 @@ endif() set(GC_LIB_LINKED_LIBS GCPasses + GCAnalysis MLIROneDNNGraph ) add_library(graph_compiler SHARED ${GC_LIB_SOURCES}) diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h new file mode 100644 index 00000000..cbc25960 --- /dev/null +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -0,0 +1,123 @@ +//===-- MatmulConfigAnalysis.h - DESC ---------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H +#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H + +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" +#include +#include +#include + +namespace mlir { +namespace gc { + +using namespace mlir; + +struct SystemDesc { + // get runtime OMP_NUM_THREADS + uint32_t getNumThreads() { + char *numThreads = getenv("OMP_NUM_THREADS"); + if (numThreads) { + return std::stoi(numThreads); + } + return 1; + } + // get cache size by cacheLevel + size_t getCacheSize(uint8_t cacheLevel) { + if (cacheLevel == 1) { + char *cacheSize = getenv("L1_CACHE_SIZE"); + if (cacheSize) { + return std::stoi(cacheSize); + } + } else if (cacheLevel == 2) { + char *cacheSize = getenv("L2_CACHE_SIZE"); + if (cacheSize) { + return std::stoi(cacheSize); + } + } else if (cacheLevel == 3) { + char *cacheSize = getenv("L3_CACHE_SIZE"); + if (cacheSize) { + return std::stoi(cacheSize); + } + } + return 0; + } + + SmallVector getContractionOperationMaxVectorLength() { + return {512UL, 512UL}; + } +}; + +struct MatmulConfig { + uint32_t MBlock, NBlock, KBlock; + uint32_t MThreads, NThreads, KThreads; + uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock; + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const MatmulConfig &config); +}; + +enum DimType { Batch, M, N, K }; + +[[maybe_unused]] static SmallVector +extractDimTypeIdx(ArrayRef tyList, DimType ty) { + SmallVector idxList; + for (auto [idx, type] : llvm::enumerate(tyList)) { + if (type == ty) { + idxList.push_back(idx); + } + } + return idxList; +} + +static FailureOr>> +getOprandDimType(linalg::LinalgOp &linalgOp) { + if (isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::K, DimType::N}, + SmallVector{DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, + SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, + DimType::K}, + SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (llvm::isa(linalgOp)) { + return SmallVector>{ + SmallVector{DimType::Batch, DimType::M, DimType::K}, + SmallVector{DimType::Batch, DimType::K, DimType::N}, + SmallVector{DimType::Batch, DimType::M, DimType::N}}; + } + return failure(); +} + +struct MatmulConfigAnalysis { +public: + explicit MatmulConfigAnalysis(Operation *root); + MatmulConfig getConfig() { return config; } + +private: + MatmulConfig config; +}; + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt new file mode 100644 index 00000000..865adcb5 --- /dev/null +++ b/lib/gc/Analysis/CMakeLists.txt @@ -0,0 +1,19 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS + MLIRIR + MLIRSupport) + +add_mlir_library(GCAnalysis + MatmulConfigAnalysis.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include + + DEPENDS + GraphCompilerPassIncGen + + LINK_LIBS PUBLIC + ${mlir_dialect_libs} + ${MLIR_LINK_COMPONENTS} + ) + +set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCAnalysis) \ No newline at end of file diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp new file mode 100644 index 00000000..de206756 --- /dev/null +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -0,0 +1,387 @@ +//===-- MatmulConfigAnalysis.cpp - DESC -------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "gc/Analysis/MatmulConfigAnalysis.h" + +namespace mlir { +namespace gc { + +#define DEBUG_TYPE "matmul-config-analysis" + +#define MAX_THREADS (1024U * 1024U) + +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const MatmulConfig &config) { + + ss << "MBlock: " << config.MBlock << ", NBlock: " << config.NBlock + << ", KBlock: " << config.KBlock << ", MThreads: " << config.MThreads + << ", NThreads: " << config.NThreads << ", KThreads: " << config.KThreads + << ", innerMostMBlock: " << config.innerMostMBlock + << ", innerMostNBlock: " << config.innerMostNBlock + << ", innerMostKBlock: " << config.innerMostKBlock; + return ss; +} + +std::vector getCandidate(uint32_t num, uint32_t floor, + uint32_t ceil) { + std::vector candidates; + for (uint32_t i = 1; i <= num; i++) { + if (num % i == 0 && i <= ceil && i >= floor) { + candidates.push_back(i); + } + } + auto candidate = 1U; + while (candidate < num && candidate <= ceil && candidate >= floor) { + candidates.push_back(candidate); + candidate *= 2; + } + auto last = std::unique(candidates.begin(), candidates.end()); + candidates.erase(last, candidates.end()); + return candidates; +} + +bool isValidConfig(const MatmulConfig &config, SystemDesc &sysDesc, + ArrayRef shape) { + if (config.innerMostMBlock == 0 || config.innerMostNBlock == 0 || + config.innerMostKBlock == 0) { + return false; + } + if (config.MBlock % config.innerMostMBlock != 0 || + config.NBlock % config.innerMostNBlock != 0 || + config.KBlock % config.innerMostKBlock != 0) { + return false; + } + auto threads = sysDesc.getNumThreads(); + if (config.MThreads * config.NThreads * config.KThreads != threads) { + return false; + } + + if (shape[0] % config.innerMostMBlock != 0 || + shape[1] % config.innerMostNBlock != 0 || + shape[2] % config.innerMostKBlock != 0) { + return false; + } + + return true; +} + +double threadUtilizationCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, SystemDesc &sysDesc) { + auto threads = sysDesc.getNumThreads(); + auto actualThreads = + (float)(config.MThreads * config.NThreads * config.KThreads); + return threads >= actualThreads ? threads / actualThreads + : actualThreads / threads; +} + +double hardwareEfficiencyCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, SystemDesc &sysDesc) { + auto dtypeSize = DataLayout().getTypeSizeInBits( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + auto vectorLength = sysDesc.getContractionOperationMaxVectorLength(); + auto mMaxVectorLength = vectorLength[0] / dtypeSize; + auto kMaxVectorLength = + (vectorLength.size() > 1 ? vectorLength[1] : vectorLength[0]) / dtypeSize; + auto cost = (mMaxVectorLength - config.innerMostMBlock % mMaxVectorLength) % + mMaxVectorLength * 1.0 / config.innerMostMBlock + + (kMaxVectorLength - config.innerMostKBlock % kMaxVectorLength) % + kMaxVectorLength * 1.0 / config.innerMostKBlock + + (mMaxVectorLength - config.innerMostNBlock % mMaxVectorLength) % + mMaxVectorLength * 1.0 / config.innerMostNBlock; + return cost; +} + +double workloadBalancedCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, SystemDesc &sysDesc) { + return 1; +} + +double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + SystemDesc &sysDesc) { + auto M = shape[0], N = shape[1], K = shape[2]; + auto dtypeSize = DataLayout().getTypeSizeInBits( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + auto penalty = 2.0 * (dtypeSize / 8); + auto memoryConsumptionPerThread = + M * K * 1.0 / config.MThreads / config.KThreads + + K * N * 1.0 / config.KThreads / config.NThreads + + M * N * ((config.KThreads - 1) * penalty + 1.0) / config.MThreads / + config.NThreads; + return memoryConsumptionPerThread; +} + +double computationIntensityOnL1Cache(linalg::LinalgOp &linalgOp, + ArrayRef shape, + const MatmulConfig &config, + SystemDesc &sysDesc) { + auto L1Cache = sysDesc.getCacheSize(2); + auto dtypeSize = DataLayout().getTypeSizeInBits( + ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); + auto outOfCachePenalty = 1024; + double FLOPS = + 2.0 * config.innerMostMBlock * config.innerMostNBlock * config.KBlock; + double memoryConsumption = config.innerMostMBlock * config.innerMostNBlock + + config.innerMostNBlock * config.KBlock + + config.innerMostMBlock * config.KBlock; + double computationIntensity = FLOPS / memoryConsumption; + if (memoryConsumption * (dtypeSize / 8) > L1Cache) { + computationIntensity /= outOfCachePenalty; + } + return 1 / computationIntensity; +} + +using CostModelFn = + std::function shape, + MatmulConfig cfg, SystemDesc &sysDesc)>; + +std::vector +filterConfigByCostModel(std::vector configs, + linalg::LinalgOp &linalgOp, ArrayRef shape, + SystemDesc &sysDesc, const CostModelFn &costModel, + float eliminationRatio = 0.5, float threshold = -1) { + std::vector result; + std::vector costs; + std::vector idx; + for (auto [i, config] : llvm::enumerate(configs)) { + costs.push_back(costModel(linalgOp, shape, config, sysDesc)); + idx.push_back(i); + } + std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) { + return costs[i1] < costs[i2]; + }); + auto thresholdCost = costs[idx[(size_t)(eliminationRatio * configs.size())]]; + thresholdCost = + threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost; + for (size_t i = 0; i < configs.size(); i++) { + if (costs[idx[i]] <= thresholdCost) { + result.push_back(configs[idx[i]]); + } + } + llvm::outs() << "thresholdCost is: " << thresholdCost + << "\nbest with cost: " << costs[idx[0]] << "\n" + << configs[idx[0]] + << "\n worst with cost: " << costs[idx[configs.size() - 1]] + << "\n" + << configs[idx[configs.size() - 1]] << "\n"; + return !result.empty() ? result : configs; +} + +std::vector +prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, + ArrayRef shape, + ArrayRef givenInnermostBlock) { + std::vector configs; + auto threads = sysDesc.getNumThreads(); + auto MThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); + auto NThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); + auto KThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); + auto MBlockCandidates = + getCandidate((uint32_t)shape[0], 1U, (uint32_t)shape[0]); + auto NBlockCandidates = getCandidate((uint32_t)shape[1], 1U, shape[1]); + auto KBlockCandidates = getCandidate((uint32_t)shape[2], 1U, shape[2]); + auto innerMostMBlockCandidates = + getCandidate((uint32_t)shape[0], 1U, (uint32_t)shape[0]); + auto innerMostNBlockCandidates = + getCandidate((uint32_t)shape[1], 1U, (uint32_t)shape[1]); + auto innerMostKBlockCandidates = + getCandidate((uint32_t)shape[2], 1U, (uint32_t)shape[2]); + if (givenInnermostBlock.size() == 3) { + innerMostMBlockCandidates = + givenInnermostBlock[0] != 0 + ? std::vector{givenInnermostBlock[0]} + : innerMostMBlockCandidates; + innerMostNBlockCandidates = + givenInnermostBlock[1] != 0 + ? std::vector{givenInnermostBlock[1]} + : innerMostNBlockCandidates; + innerMostKBlockCandidates = + givenInnermostBlock[2] != 0 + ? std::vector{givenInnermostBlock[2]} + : innerMostKBlockCandidates; + } + llvm::outs() << "MThreadsCandidates size: " << MThreadsCandidates.size() + << "\n"; + llvm::outs() << "NThreadsCandidates size: " << NThreadsCandidates.size() + << "\n"; + llvm::outs() << "KThreadsCandidates size: " << KThreadsCandidates.size() + << "\n"; + llvm::outs() << "MBlockCandidates size: " << MBlockCandidates.size() << "\n"; + llvm::outs() << "NBlockCandidates size: " << NBlockCandidates.size() << "\n"; + llvm::outs() << "KBlockCandidates size: " << KBlockCandidates.size() << "\n"; + llvm::outs() << "innerMostMBlockCandidates size: " + << innerMostMBlockCandidates.size() << "\n"; + llvm::outs() << "innerMostNBlockCandidates size: " + << innerMostNBlockCandidates.size() << "\n"; + llvm::outs() << "innerMostKBlockCandidates size: " + << innerMostKBlockCandidates.size() << "\n"; + for (auto MThreads : MThreadsCandidates) { + for (auto NThreads : NThreadsCandidates) { + for (auto KThreads : KThreadsCandidates) { + for (auto MBlock : MBlockCandidates) { + for (auto NBlock : NBlockCandidates) { + for (auto KBlock : KBlockCandidates) { + for (auto innerMostMBlock : innerMostMBlockCandidates) { + for (auto innerMostNBlock : innerMostNBlockCandidates) { + for (auto innerMostKBlock : innerMostKBlockCandidates) { + MatmulConfig config{ + MBlock, NBlock, KBlock, + MThreads, NThreads, KThreads, + innerMostMBlock, innerMostNBlock, innerMostKBlock}; + + if (isValidConfig(config, sysDesc, shape)) { + configs.push_back(config); + } + } + } + } + } + } + } + } + } + } + return configs; +} + +/* +thread utilization +computation intensity +cache locality +memory requirements +computation unit efficiency +padding/pack cost +workload balance +communication +previous matmul +*/ +MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { + SystemDesc sysDesc; + if (auto linalgOp = dyn_cast(root)) { + // TODO: build a more complex heuristic to determine the best tiling + auto oprandDimType = *getOprandDimType(linalgOp); + // get the origin M,N,K size + auto MDimTypeIdx = extractDimTypeIdx(oprandDimType[0], DimType::M); + auto KDimTypeIdx = extractDimTypeIdx(oprandDimType[1], DimType::K); + auto NDimTypeIdx = extractDimTypeIdx(oprandDimType[1], DimType::N); + uint32_t M = 1U, N = 1U, K = 1U; + for (auto [s, dimType] : + llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)), + oprandDimType[0])) { + if (dimType == DimType::M) { + M *= s; + } + } + for (auto [s, dimType] : + llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)), + oprandDimType[1])) { + if (dimType == DimType::N) { + N *= s; + } else if (dimType == DimType::K) { + K *= s; + } + } + + // innermost Block, if the layout is blockied layout, the innermost block + // will derived from the layout directly + auto defaultBlock = 32; + config.innerMostMBlock = M % defaultBlock == 0 ? defaultBlock : M; + config.innerMostNBlock = N % defaultBlock == 0 ? defaultBlock : N; + config.innerMostKBlock = K % defaultBlock == 0 ? defaultBlock : K; + SmallVector givenInnermostBlock; + if (MDimTypeIdx.size() > 1) { + config.innerMostMBlock = 1; + for (auto i = 1UL; i < MDimTypeIdx.size(); i++) { + config.innerMostMBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(0))[MDimTypeIdx[i]]; + } + givenInnermostBlock.push_back(config.innerMostMBlock); + } else { + givenInnermostBlock.push_back(0); + } + if (NDimTypeIdx.size() > 1) { + config.innerMostNBlock = 1; + for (auto i = 1UL; i < NDimTypeIdx.size(); i++) { + config.innerMostNBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(1))[NDimTypeIdx[i]]; + } + givenInnermostBlock.push_back(config.innerMostNBlock); + } else { + givenInnermostBlock.push_back(0); + } + if (KDimTypeIdx.size() > 1) { + config.innerMostKBlock = 1; + for (auto i = 1UL; i < KDimTypeIdx.size(); i++) { + config.innerMostKBlock *= + linalgOp.getShape(linalgOp.getDpsInputOperand(1))[KDimTypeIdx[i]]; + } + givenInnermostBlock.push_back(config.innerMostKBlock); + } else { + givenInnermostBlock.push_back(0); + } + + // Number of block + auto MNumBlock = M / config.innerMostMBlock; + auto NNumBlock = N / config.innerMostNBlock; + auto KNumBlock = K / config.innerMostKBlock; + + // Threads + config.MThreads = 32; + config.NThreads = 1; + config.KThreads = 1; + + // Block + config.MBlock = (int)llvm::divideCeil(MNumBlock, config.MThreads) * + config.innerMostMBlock; + config.NBlock = (int)llvm::divideCeil(NNumBlock, config.NThreads) * + config.innerMostNBlock; + config.KBlock = (int)llvm::divideCeil(KNumBlock, config.KThreads) * + config.innerMostKBlock; + config.MBlock = 128; + config.NBlock = 128; + config.KBlock = 128; + config.MThreads = 2; + config.NThreads = 2; + config.KThreads = 1; + + llvm::outs() << "M: " << M << ", N: " << N << ", K: " << K << "\n"; + + SmallVector> costModelList = { + {threadUtilizationCost, "threadUtilizationCost"}, + {hardwareEfficiencyCost, "hardwareEfficiencyCost"}, + {workloadBalancedCost, "workloadBalancedCost"}, + {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost"}, + {computationIntensityOnL1Cache, "computationIntensityOnL1Cache"}}; + + auto configCandidates = + prepareConfigCandidates(root, sysDesc, {M, N, K}, givenInnermostBlock); + + for (auto [fn, name] : costModelList) { + llvm::outs() << name << "\n\n"; + configCandidates = filterConfigByCostModel(configCandidates, linalgOp, + {M, N, K}, sysDesc, fn, 0.5); + llvm::outs() << "ConfigCandidates size: " << configCandidates.size() + << "\n"; + } + + if (!configCandidates.empty()) { + config = configCandidates[0]; + } + + llvm::outs() << "Final config\nNumThreads: " << sysDesc.getNumThreads() + << ", MatmulConfig: " << config << "\n"; + } +} +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/CMakeLists.txt b/lib/gc/CMakeLists.txt index 03f7023b..02c20a01 100644 --- a/lib/gc/CMakeLists.txt +++ b/lib/gc/CMakeLists.txt @@ -4,6 +4,7 @@ endif() include(functions) +add_subdirectory(Analysis) add_subdirectory(CAPI) add_subdirectory(Dialect) add_subdirectory(Transforms) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index eabc434c..5b7f3441 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "./Tiling.hpp" +#include "gc/Analysis/MatmulConfigAnalysis.h" #include "gc/Dialect/Arith/Utils/EasyBuild.h" #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/IR/EasyBuild.h" @@ -45,101 +46,6 @@ namespace gc { namespace { -struct SystemDesc { - // get runtime OMP_NUM_THREADS - uint32_t getNumThreads(); - // get cache size by cacheLevel - size_t getCacheSize(uint8_t cacheLevel); -}; - -struct MatmulConfig { - int MBlock, NBlock, KBlock; - int MThreads, NThreads, KThreads; - int innerMostMBlock, innerMostNBlock, innerMostKBlock; -}; - -template inline T divAndCeil(T a, T b) { return (a - 1) / b + 1; } - -enum DimType { Batch, M, N, K }; - -static FailureOr>> -getOprandDimType(linalg::LinalgOp &linalgOp) { - if (isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::M, DimType::K}, - SmallVector{DimType::K, DimType::N}, - SmallVector{DimType::M, DimType::N}}; - } else if (llvm::isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::M, DimType::K}, - SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, - DimType::K}, - SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; - } else if (llvm::isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::M, DimType::K, DimType::M, DimType::K}, - SmallVector{DimType::N, DimType::K, DimType::K, DimType::N, - DimType::K}, - SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}}; - } else if (llvm::isa(linalgOp)) { - return SmallVector>{ - SmallVector{DimType::Batch, DimType::M, DimType::K}, - SmallVector{DimType::Batch, DimType::K, DimType::N}, - SmallVector{DimType::Batch, DimType::M, DimType::N}}; - } - return failure(); -} - -[[maybe_unused]] static SmallVector -extractDimTypeIdx(ArrayRef tyList, DimType ty) { - SmallVector idxList; - for (auto [idx, type] : llvm::enumerate(tyList)) { - if (type == ty) { - idxList.push_back(idx); - } - } - return idxList; -} - -MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) { - // TODO: build a more complex heuristic to determine the best tiling - auto M = linalgOp.getShape(linalgOp.getDpsInputOperand(0))[0]; - auto N = linalgOp.getShape(linalgOp.getDpsInputOperand(1))[1]; - auto K = linalgOp.getShape(linalgOp.getDpsInputOperand(1))[0]; - MatmulConfig cfg; - - // innermost Block - auto defaultBlock = 32; - cfg.innerMostMBlock = M % defaultBlock == 0 ? defaultBlock : M; - cfg.innerMostNBlock = N % defaultBlock == 0 ? defaultBlock : N; - cfg.innerMostKBlock = K % defaultBlock == 0 ? defaultBlock : K; - - // Number of block - auto MNumBlock = M / cfg.innerMostMBlock; - auto NNumBlock = N / cfg.innerMostNBlock; - auto KNumBlock = K / cfg.innerMostKBlock; - - // Threads - cfg.MThreads = 32; - cfg.NThreads = 1; - cfg.KThreads = 1; - - // Block - cfg.MBlock = divAndCeil((int)MNumBlock, cfg.MThreads) * cfg.innerMostMBlock; - cfg.NBlock = divAndCeil((int)NNumBlock, cfg.NThreads) * cfg.innerMostNBlock; - cfg.KBlock = divAndCeil((int)KNumBlock, cfg.KThreads) * cfg.innerMostKBlock; - cfg.innerMostMBlock = 32; - cfg.innerMostNBlock = 32; - cfg.innerMostKBlock = 32; - cfg.MBlock = 64; - cfg.NBlock = 64; - cfg.KBlock = 64; - cfg.MThreads = 2; - cfg.NThreads = 2; - cfg.KThreads = 1; - return cfg; -} - static Value tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType, Value value, @@ -478,9 +384,9 @@ using FinalReduceCallBackFn = std::function( struct OuterLoopGenerationOption { enum LoopType { ForOp, ForallOp }; - SmallVector> nestedTileSizes; + SmallVector> nestedTileSizes; SmallVector loopType; - SmallVector> loopDim; + SmallVector> loopDim; SmallVector innermostFullResultCallBacks; SmallVector finalReduceCallBacks; bool isPartialResult = false; @@ -657,7 +563,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { FailureOr outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp, - MatmulConfig cfg, bool hasFillOp) const { + gc::MatmulConfig cfg, bool hasFillOp) const { SmallVector KDimPos, MDimPos, NDimPos; linalgOp.getReductionDims(KDimPos); getMatmulParallelDims(linalgOp, 0, MDimPos); @@ -665,23 +571,26 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { OuterLoopGenerationOption option; auto iteratorTypes = linalgOp.getIteratorTypesArray(); - auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1); - auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0); - auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1); + auto KFirstDim = getOprandDim(linalgOp, KDimPos[0], 1); + auto MFirstDim = getOprandDim(linalgOp, MDimPos[0], 0); + auto NFirstDim = getOprandDim(linalgOp, NDimPos[0], 1); auto KParallelBlockSize = KDimPos.size() > 1 - ? divAndCeil(KFirstDim, cfg.KThreads) - : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) * + ? llvm::divideCeil(KFirstDim, cfg.KThreads) + : llvm::divideCeil(llvm::divideCeil(KFirstDim, cfg.KBlock), + cfg.KThreads) * cfg.KBlock; auto MParallelBlockSize = MDimPos.size() > 1 - ? divAndCeil(MFirstDim, cfg.MThreads) - : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) * + ? llvm::divideCeil(MFirstDim, cfg.MThreads) + : llvm::divideCeil(llvm::divideCeil(MFirstDim, cfg.MBlock), + cfg.MThreads) * cfg.MBlock; auto NParallelBlockSize = NDimPos.size() > 1 - ? divAndCeil(NFirstDim, cfg.NThreads) - : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) * + ? llvm::divideCeil(NFirstDim, cfg.NThreads) + : llvm::divideCeil(llvm::divideCeil(NFirstDim, cfg.NBlock), + cfg.NThreads) * cfg.NBlock; auto KOuterBlockSize = KDimPos.size() > 1 ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1 @@ -693,46 +602,45 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1 : cfg.NBlock; // Outer - option.nestedTileSizes.emplace_back(SmallVector{ + option.nestedTileSizes.emplace_back(SmallVector{ MParallelBlockSize, NParallelBlockSize, KParallelBlockSize}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp); option.loopDim.emplace_back( - SmallVector{(int)MDimPos[0], (int)NDimPos[0], (int)KDimPos[0]}); + SmallVector{MDimPos[0], NDimPos[0], KDimPos[0]}); // Middle for (auto [tile, dim] : - llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, - KOuterBlockSize}, - SmallVector{(int)MDimPos[0], (int)NDimPos[0], - (int)KDimPos[0]})) { - option.nestedTileSizes.emplace_back(SmallVector{tile}); + llvm::zip(SmallVector{MOuterBlockSize, NOuterBlockSize, + KOuterBlockSize}, + SmallVector{MDimPos[0], NDimPos[0], KDimPos[0]})) { + option.nestedTileSizes.emplace_back(SmallVector{tile}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{dim}); + option.loopDim.emplace_back(SmallVector{dim}); } // Inner if (KDimPos.size() == 1) { - option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); + option.nestedTileSizes.emplace_back(SmallVector{cfg.KBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{(int)KDimPos.back()}); + option.loopDim.emplace_back(SmallVector{KDimPos.back()}); } if (MDimPos.size() == 1) { option.nestedTileSizes.emplace_back( - SmallVector{cfg.innerMostMBlock}); + SmallVector{cfg.innerMostMBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{(int)MDimPos.back()}); + option.loopDim.emplace_back(SmallVector{MDimPos.back()}); } if (NDimPos.size() == 1) { option.nestedTileSizes.emplace_back( - SmallVector{cfg.innerMostNBlock}); + SmallVector{cfg.innerMostNBlock}); option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{(int)NDimPos.back()}); + option.loopDim.emplace_back(SmallVector{NDimPos.back()}); } for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) { if (dim != MDimPos.back() && dim != NDimPos.back() && iteratorTypes[dim] != mlir::utils::IteratorType::reduction) { - option.nestedTileSizes.emplace_back(SmallVector{1}); + option.nestedTileSizes.emplace_back(SmallVector{1}); option.loopType.emplace_back( OuterLoopGenerationOption::LoopType::ForOp); - option.loopDim.emplace_back(SmallVector{(int)dim}); + option.loopDim.emplace_back(SmallVector{dim}); } } @@ -784,7 +692,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()}; auto operandDimTypes = getOprandDimType(originOp); - MatmulConfig cfg = getDefaultMatmulConfig(originOp); + auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig(); auto AShape = originOp.getShape(originOp.getDpsInputOperand(0)); auto BShape = originOp.getShape(originOp.getDpsInputOperand(1)); auto CShape = originOp.getShape(originOp.getDpsInitOperand(0)); @@ -946,7 +854,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { auto upBound = eb.wrap(*loop.getSingleUpperBound()); auto step = eb.wrap(*loop.getSingleStep()); - auto currentCond = (induceVar + step) > upBound; + auto currentCond = (induceVar + step) >= upBound; cond = cond & currentCond; } EB_scf_if(cond, {currentOp.getDpsInits().back().getType()}) { @@ -1027,12 +935,12 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { rewriter.setInsertionPoint(linalgOp); linalg::LinalgOp originOp = dyn_cast(*rewriter.clone(*(linalgOp.getOperation()))); - linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]); // Step 1. Split matmul(bf16xbf16->bf16) to matmul(bf16xbf16->f32) + // cast(f32->bf16) if K slicing is needed - MatmulConfig cfg = getDefaultMatmulConfig(linalgOp); + auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig(); + linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp); bool needLowPrecisionCast = needToLegalizeDtype(linalgOp); if (cfg.KThreads > 1) { auto result = matmulDtypeLegalize(rewriter, linalgOp.getOperation()); diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt index 36ace684..1ce25d83 100644 --- a/src/gc-opt/CMakeLists.txt +++ b/src/gc-opt/CMakeLists.txt @@ -11,12 +11,14 @@ else() get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) endif() - +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) set(gc_opt_libs ${dialect_libs} ${conversion_libs} + ${extension_libs} ${MLIR_LINK_COMPONENTS} - GCPasses) + GCPasses + GCAnalysis) if(GC_MLIR_CXX_FLAGS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GC_MLIR_CXX_FLAGS}")