[torch-frontend] support converting math ops to custom calls (#392)
so that it can avoid decomposing math ops.
qingyunqu committed Jul 5, 2024
1 parent e1546f4 commit 12f2fd6
Showing 5 changed files with 106 additions and 92 deletions.
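
For context, here is a hand-written sketch (not part of the commit, and mirroring the new lit test at the end of this diff) of the lowering the pass now performs when a math op such as aten.asin is listed in valid-custom-call-ops; torch/builtin tensor boundary casts are elided:

// Input, as in the new lit test:
func.func @torch.aten.asin(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.asin %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// After -convert-torch-to-custom-call (sketch, boundary casts elided): the op
// is emitted as a stablehlo.custom_call whose target name carries the math op,
// so downstream pipelines do not have to decompose it.
%1 = stablehlo.custom_call @math.asin(%0) {byteir_attrs = {}} : (tensor<?x?xf32>) -> tensor<?x?xf32>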
@@ -19,14 +19,23 @@
#define TORCH_FRONTEND_CONVERSION_CONVERTTORCHTOCUSTOMCALL_H

#include "mlir/Pass/Pass.h"
#include "llvm/ADT/StringSet.h"
#include <memory>
#include <string>

namespace mlir {
class ConversionTarget;
class RewritePatternSet;
class TypeConverter;
namespace func {
class FuncOp;
} // namespace func

void populateMathToCustomCallPattern(
ConversionTarget &target, TypeConverter &typeConverter,
RewritePatternSet &patterns,
const llvm::StringSet<> &validCustomCallOpsSet);

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToCustomCall(ArrayRef<std::string> validCustomCallOps);

@@ -56,91 +56,6 @@ llvm::SmallVector<NamedAttribute> getDefaultAttrs(PatternRewriter &rewriter) {
return attrs;
}

template <typename OP>
stablehlo::ConstantOp createInitialValueForReduceOp(PatternRewriter &rewriter,
Location loc,
Type elementTy);

template <>
stablehlo::ConstantOp
createInitialValueForReduceOp<stablehlo::MaxOp>(PatternRewriter &rewriter,
Location loc, Type elementTy) {
auto constType = RankedTensorType::get({}, elementTy);
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(loc, constType, constAttr);
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<stablehlo::ConstantOp>(loc, constType, constAttr);
}
assert(false && "unimplemented lowering in createInitialValueForReduceOp");
return nullptr;
}

template <>
stablehlo::ConstantOp
createInitialValueForReduceOp<stablehlo::AddOp>(PatternRewriter &rewriter,
Location loc, Type elementTy) {
auto constType = RankedTensorType::get({}, elementTy);
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
{APFloat::getZero(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(loc, constType, constAttr);
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<stablehlo::ConstantOp>(loc, constType, constAttr);
}
assert(false && "unimplemented lowering in createInitialValueForReduceOp");
return nullptr;
}

template <typename OP>
stablehlo::ReduceOp createSingleOpReduce(PatternRewriter &rewriter,
Location loc, Value input,
llvm::SmallVector<int64_t> dims) {
llvm::sort(dims.begin(), dims.end());
auto inputType = cast<RankedTensorType>(input.getType());
stablehlo::ConstantOp initValue = createInitialValueForReduceOp<OP>(
rewriter, loc, inputType.getElementType());

std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> outputShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
outputShape.push_back(inputType.getDimSize(i));
}
}
stablehlo::ReduceOp reduceOp = rewriter.create<stablehlo::ReduceOp>(
loc, RankedTensorType::get(outputShape, inputType.getElementType()),
input, initValue.getOutput(), rewriter.getDenseI64ArrayAttr(dims));

Block &block = reduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputType.getElementType());
block.addArgument(blockArgumentTy, loc);
block.addArgument(blockArgumentTy, loc);
auto firstArgument = *block.args_begin();
auto secondArgument = *block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result = rewriter.create<OP>(loc, blockArgumentTy, firstArgument,
secondArgument);
rewriter.create<stablehlo::ReturnOp>(loc, result);
}

return reduceOp;
}

Value promoteType(Location loc, Value input, TensorType desiredType,
PatternRewriter &rewriter) {
TensorType inType = dyn_cast<TensorType>(input.getType());
@@ -1137,15 +1052,13 @@ class ConvertAtenNonzeroOp : public OpConversionPattern<AtenNonzeroOp> {
matchAndRewrite(AtenNonzeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<Value> bufferArgs({input});
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
if (!resultType) {
return op.emitError("could not convert output types");
}

std::vector<NamedAttribute> byteir_attrs;

auto attrs = getDefaultAttrs(rewriter);
attrs.emplace_back(rewriter.getStringAttr("call_target_name"),
rewriter.getStringAttr(getNonZeroName()));
@@ -1161,6 +1074,46 @@ class ConvertAtenNonzeroOp : public OpConversionPattern<AtenNonzeroOp> {
};
} // namespace

// math ops
namespace {
template <typename AtenOpT>
class ConvertMathOp : public OpConversionPattern<AtenOpT> {
public:
ConvertMathOp(const TypeConverter &typeConverter, MLIRContext *context,
llvm::StringRef targetName)
: OpConversionPattern<AtenOpT>(typeConverter, context),
callTargetName(targetName) {}
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
Type resultType =
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getResult().getType());
if (!resultType) {
return op.emitError("could not convert output types");
}

std::vector<NamedAttribute> byteir_attrs;
auto attrs = getDefaultAttrs(rewriter);
attrs.emplace_back(rewriter.getStringAttr("call_target_name"),
rewriter.getStringAttr(this->callTargetName));
attrs.emplace_back(rewriter.getStringAttr(getCustomCallAttrName()),
rewriter.getDictionaryAttr(byteir_attrs));

auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
op->getLoc(), TypeRange{resultType}, ValueRange{input},
ArrayRef<NamedAttribute>(attrs));
rewriter.replaceOp(op, customCallOp->getResults());
return success();
}

private:
std::string callTargetName;
};
} // namespace

namespace {
class ConvertTorchToCustomCall
: public ConvertTorchToCustomCallBase<ConvertTorchToCustomCall> {
@@ -1268,9 +1221,13 @@ class ConvertTorchToCustomCall
target.addIllegalOp<AtenTopkOp>();
patterns.add<ConvertAtenTopkOp>(typeConverter, context);
}
if (validCustomCallOpsSet.contains("aten.nonzero")) {
target.addIllegalOp<AtenNonzeroOp>();
patterns.add<ConvertAtenNonzeroOp>(typeConverter, context);
}

target.addIllegalOp<AtenNonzeroOp>();
patterns.add<ConvertAtenNonzeroOp>(typeConverter, context);
populateMathToCustomCallPattern(target, typeConverter, patterns,
validCustomCallOpsSet);

target.addIllegalOp<CustomOp>();
patterns.add<ConvertDynamicPartitionCustomOp>(typeConverter, context);
@@ -1293,6 +1250,30 @@
};
} // namespace

void mlir::populateMathToCustomCallPattern(
ConversionTarget &target, TypeConverter &typeConverter,
RewritePatternSet &patterns,
const llvm::StringSet<> &validCustomCallOpsSet) {
// Keys in validCustomCallOpsSet omit the "torch." dialect prefix (e.g.
// "aten.asin"), so strip it from the op name before the lookup.
#define CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenOp, MathOpName)               \
if (validCustomCallOpsSet.contains(                                           \
        AtenOp::getOperationName().split('.').second)) {                      \
target.addIllegalOp<AtenOp>();                                                \
patterns.add<ConvertMathOp<AtenOp>>(typeConverter, patterns.getContext(),     \
MathOpName);                                                                   \
}

CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenAsinOp, "math.asin");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenAsinhOp, "math.asinh");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenSinhOp, "math.sinh");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenAtanOp, "math.atan");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenTanOp, "math.tan");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenAcosOp, "math.acos");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenAcoshOp, "math.acosh");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenCoshOp, "math.cosh");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenErfOp, "math.erf");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenTruncOp, "math.trunc");
#undef CONVERT_MATH_TO_CUSTOM_CALL_PATTERN
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertTorchToCustomCall(ArrayRef<std::string> validCustomCallOps) {
return std::make_unique<ConvertTorchToCustomCall>(validCustomCallOps);
@@ -1,6 +1,6 @@
from ._mlir_libs._torchFrontend import *

from .compile import DebugType, GENERIC_CUSTOM_OPS, BYTEIR_CUSTOM_OPS
from .compile import DebugType, GENERIC_CUSTOM_OPS, BYTEIR_CUSTOM_OPS, MATH_CUSTOM_OPS
from .compile import compile, compile_dynamo_model

from .fx_utils import list_decomposed_ops, preprocess_fx_graph, get_none_indices
@@ -30,6 +30,20 @@
"aten.min.dim",
"aten.one_hot",
"aten.topk",
"aten.nonzero",
]

MATH_CUSTOM_OPS = [
"aten.asin",
"aten.asinh",
"aten.sinh",
"aten.atan",
"aten.tan",
"aten.acos",
"aten.acosh",
"aten.cosh",
"aten.erf",
"aten.trunc",
]

BYTEIR_CUSTOM_OPS = [
@@ -1,6 +1,16 @@
// RUN: torch-frontend-opt %s -convert-torch-to-custom-call="valid-custom-call-ops=aten.native_layer_norm,aten.layer_norm,aten.native_group_norm,aten.group_norm,aten._softmax,aten.softmax.int,aten._log_softmax,aten.log_softmax.int,aten.nll_loss_forward,aten.nll_loss_backward,aten.gelu,aten.max.dim,aten.min.dim,aten.one_hot,aten.topk" --canonicalize-ext | FileCheck %s

// RUN: torch-frontend-opt %s -convert-torch-to-custom-call="valid-custom-call-ops=aten.native_layer_norm,aten.layer_norm,aten.native_group_norm,aten.group_norm,aten._softmax,aten.softmax.int,aten._log_softmax,aten.log_softmax.int,aten.nll_loss_forward,aten.nll_loss_backward,aten.gelu,aten.max.dim,aten.min.dim,aten.one_hot,aten.topk,aten.nonzero" --canonicalize-ext | FileCheck %s
// RUN: torch-frontend-opt %s -convert-torch-to-custom-call --canonicalize-ext | FileCheck %s --check-prefix NONE
// RUN: torch-frontend-opt %s -convert-torch-to-custom-call="valid-custom-call-ops=aten.asin" --canonicalize-ext | FileCheck %s --check-prefix NONE

func.func @torch.aten.asin(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.asin %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// MATH-LABEL: func.func @torch.aten.asin
// MATH: stablehlo.custom_call
// MATH-SAME: @math.asin
// MATH: byteir_attrs = {}
// MATH-NOT: torch.aten.asin

func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%str = torch.constant.str "tanh"
