Skip to content

Commit

Permalink
generalize config
Browse files Browse the repository at this point in the history
  • Loading branch information
zhczhong committed Jun 25, 2024
1 parent ab1239f commit d2e3b51
Show file tree
Hide file tree
Showing 7 changed files with 571 additions and 129 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ endif()

set(GC_LIB_LINKED_LIBS
GCPasses
GCAnalysis
MLIROneDNNGraph
)
add_library(graph_compiler SHARED ${GC_LIB_SOURCES})
Expand Down
124 changes: 124 additions & 0 deletions include/gc/Analysis/MatmulConfigAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
//===- MatmulConfigAnalysis.h - Graph Compiler analysis pass ----------*- C++
//-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H

#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include <llvm/Support/Debug.h>
#include <memory>
#include <numeric>

namespace mlir {
namespace gc {

using namespace mlir;

struct SystemDesc {
// get runtime OMP_NUM_THREADS
uint32_t getNumThreads() {
char *numThreads = getenv("OMP_NUM_THREADS");
if (numThreads) {
return std::stoi(numThreads);
}
return 1;
}
// get cache size by cacheLevel
size_t getCacheSize(uint8_t cacheLevel) {
if (cacheLevel == 1) {
char *cacheSize = getenv("L1_CACHE_SIZE");
if (cacheSize) {
return std::stoi(cacheSize);
}
} else if (cacheLevel == 2) {
char *cacheSize = getenv("L2_CACHE_SIZE");
if (cacheSize) {
return std::stoi(cacheSize);
}
} else if (cacheLevel == 3) {
char *cacheSize = getenv("L3_CACHE_SIZE");
if (cacheSize) {
return std::stoi(cacheSize);
}
}
return 0;
}

SmallVector<size_t> getContractionOperationMaxVectorLength() {
return {512UL, 512UL};
}
};

struct MatmulConfig {
uint32_t MBlock, NBlock, KBlock;
uint32_t MThreads, NThreads, KThreads;
uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
const MatmulConfig &config);
};

enum DimType { Batch, M, N, K };

[[maybe_unused]] static SmallVector<unsigned>
extractDimTypeIdx(ArrayRef<DimType> tyList, DimType ty) {
SmallVector<unsigned> idxList;
for (auto [idx, type] : llvm::enumerate(tyList)) {
if (type == ty) {
idxList.push_back(idx);
}
}
return idxList;
}

static FailureOr<SmallVector<SmallVector<DimType>>>
getOprandDimType(linalg::LinalgOp &linalgOp) {
if (isa<linalg::MatmulOp>(linalgOp)) {
return SmallVector<SmallVector<DimType>>{
SmallVector<DimType>{DimType::M, DimType::K},
SmallVector<DimType>{DimType::K, DimType::N},
SmallVector<DimType>{DimType::M, DimType::N}};
} else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
return SmallVector<SmallVector<DimType>>{
SmallVector<DimType>{DimType::M, DimType::K},
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
DimType::K},
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
} else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
return SmallVector<SmallVector<DimType>>{
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
DimType::K},
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
} else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
return SmallVector<SmallVector<DimType>>{
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
}
return failure();
}

struct MatmulConfigAnalysis {
public:
explicit MatmulConfigAnalysis(Operation *root);
MatmulConfig getConfig() { return config; }

private:
MatmulConfig config;
};

} // namespace gc
} // namespace mlir

#endif
19 changes: 19 additions & 0 deletions lib/gc/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
MLIRIR
MLIRSupport)

add_mlir_library(GCAnalysis
MatmulConfigAnalysis.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include

DEPENDS
GraphCompilerPassIncGen

LINK_LIBS PUBLIC
${mlir_dialect_libs}
${MLIR_LINK_COMPONENTS}
)

set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCAnalysis)
Loading

0 comments on commit d2e3b51

Please sign in to comment.