[compiler] fix inf/nan convert on x86_64 arch
jianwenyyy committed Jun 27, 2024
1 parent dd35694 commit 8ff7445
Showing 8 changed files with 212 additions and 24 deletions.
6 changes: 5 additions & 1 deletion compiler/include/byteir/Conversion/Passes.td
@@ -63,7 +63,11 @@ def HloFusionToLinalg : Pass<"hlo-fusion-to-linalg", "func::FuncOp"> {
Option<"enablePrimitiveOps", "enable-primitive-ops", "bool",
/*default=*/"false",
"Lower to primitive Linalg ops (map, reduce and "
"transpose) when possible, instead of linalg.generic">
"transpose) when possible, instead of linalg.generic">,
Option<"target", "target", "std::string", /*default*/ "",
"Specificy the target">,
Option<"arch", "arch", "std::string", /*default*/ "",
"Specificy the target arch">
];
}

10 changes: 6 additions & 4 deletions compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h
@@ -41,11 +41,13 @@ void populateTensorToLinalgConversionPatterns(RewritePatternSet &patterns);
void populateLinalgExtToLinalgConversionPatterns(RewritePatternSet &patterns);

void populateHloToLinalgExtConversionPattern(TypeConverter &typeConverter,
RewritePatternSet &patterns);
RewritePatternSet &patterns,
const std::string &target = "",
const std::string &arch = "");

std::unique_ptr<OperationPass<func::FuncOp>>
createHloFusionToLinalgPass(llvm::StringRef anchorTag = "",
bool enablePrimitiveOps = false);
std::unique_ptr<OperationPass<func::FuncOp>> createHloFusionToLinalgPass(
llvm::StringRef anchorTag = "", bool enablePrimitiveOps = false,
const std::string &target = "", const std::string &arch = "");

std::unique_ptr<OperationPass<func::FuncOp>> createUnrealizedCastToLinalgPass();

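A minimal sketch (not part of this commit) of how the extended factory can be wired into a nested pass pipeline; the pipeline-builder function name here is hypothetical, everything else follows the declarations above:

#include "byteir/Conversion/ToLinalg/ToLinalg.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

// Forwarding target/arch selects the x86_64-specific inf/nan conversion
// pattern added in HloToLinalg.cpp below.
void buildExampleFusionPipeline(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createHloFusionToLinalgPass(
      /*anchorTag=*/"", /*enablePrimitiveOps=*/true,
      /*target=*/"cpu", /*arch=*/"x86_64"));
}
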
3 changes: 3 additions & 0 deletions compiler/include/byteir/Pipelines/LinalgTensorOpt.h
@@ -30,6 +30,9 @@ struct LinalgTensorOptPipelineOptions
*this, "target",
llvm::cl::desc("An optional attribute to speicify target."),
llvm::cl::init("")};
Option<std::string> arch{
*this, "arch", llvm::cl::desc("An optional attribute to speicify arch."),
llvm::cl::init("")};
};

void createLinalgTensorOptPipeline(
197 changes: 185 additions & 12 deletions compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp
@@ -1270,10 +1270,13 @@ class ByteirRepeatCustomCallConverter
struct HloFusionToLinalgPass
: public HloFusionToLinalgBase<HloFusionToLinalgPass> {

HloFusionToLinalgPass(StringRef tag, bool enablePrimitiveOps)
HloFusionToLinalgPass(StringRef tag, bool enablePrimitiveOps,
StringRef target, StringRef arch)
: HloFusionToLinalgBase() {
anchorTag = tag.str();
this->enablePrimitiveOps = enablePrimitiveOps;
this->target = target.str();
this->arch = arch.str();
}

void getDependentDialects(DialectRegistry &registry) const final {
@@ -1293,13 +1296,13 @@ struct HloFusionToLinalgPass

MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
ConversionTarget target(ctx);
target.addLegalDialect<
ConversionTarget conversionTarget(ctx);
conversionTarget.addLegalDialect<
arith::ArithDialect, cf::ControlFlowDialect, func::FuncDialect,
linalg::LinalgDialect, math::MathDialect, tensor::TensorDialect,
scf::SCFDialect, shape::ShapeDialect, linalg_ext::LinalgExtDialect>();

target.addLegalOp<UnrealizedConversionCastOp>();
conversionTarget.addLegalOp<UnrealizedConversionCastOp>();

auto typeConverter = createHloToLinalgTypeConverter();

@@ -1308,22 +1311,191 @@ struct HloFusionToLinalgPass
[](Operation *op) { return isInBodyOfLinalgOps(op); });
mhlo::populateHloToLinalgConversionPattern(&ctx, *typeConverter, &patterns,
enablePrimitiveOps);
populateHloToLinalgExtConversionPattern(*typeConverter, patterns);
populateHloToLinalgExtConversionPattern(*typeConverter, patterns,
this->target, this->arch);

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
if (failed(
applyPartialConversion(func, conversionTarget, frozenPatterns))) {
signalPassFailure();
}
}
};

/// The code below is copied from legalize_to_linalg.cc.
/// Remove it once upstream FPToSIOp handles inf/nan conversion.
Value coerceTensorShape(OpBuilder &builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType) {
return builder.createOrFold<tensor::CastOp>(
loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()),
value);
}

inline Value mapFPToSIConvertOpToStdScalarOp(Location loc,
ArrayRef<Type> targetTypes,
ArrayRef<Type> resultTypes,
ArrayRef<Type> argTypes,
ValueRange args, OpBuilder *b) {
assert(targetTypes.size() == 1 && "ConvertOp should return a single result");
assert(resultTypes.size() == 1 && "ConvertOp should return a single result");
assert(argTypes.size() == 1 && "ConvertOp should take a single argument");
assert(args.size() == 1 && "ConvertOp should take a single argument");

Type sourceType = getElementTypeOrSelf(argTypes.front());
Type targetType = getElementTypeOrSelf(targetTypes.front());
Type convertedSourceType = getElementTypeOrSelf(args.front());

if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType,
targetType)) {
Value infValue = b->create<mlir::arith::ConstantOp>(
loc,
b->getFloatAttr(
convertedSourceType,
APFloat::getInf(
dyn_cast<FloatType>(convertedSourceType).getFloatSemantics())));
Value isInf = b->create<mlir::arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, args.front(),
infValue); // TODO(yjw): verify that args.front() is the right operand here
Value isNan = b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
args.front(), args.front());
Value maxIntval = b->create<arith::ConstantOp>(
loc,
b->getIntegerAttr(targetType,
APInt::getSignedMaxValue(
dyn_cast<IntegerType>(targetType).getWidth())));
Value zeroIntval =
b->create<arith::ConstantOp>(loc, b->getZeroAttr(targetType));
return b->create<::mlir::arith::SelectOp>(
loc, isInf, maxIntval,
b->create<::mlir::arith::SelectOp>(
loc, isNan, zeroIntval,
b->create<mlir::arith::FPToSIOp>(loc, resultTypes, args,
std::nullopt)));
}
return nullptr;
}

class FPToSIConvertOpConverter : public OpConversionPattern<mhlo::ConvertOp> {
public:
using OpConversionPattern<mhlo::ConvertOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(mhlo::ConvertOp op, typename mhlo::ConvertOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// Apply only for FP -> 32-bit signed integer converts.
if (!mlir::arith::FPToSIOp::areCastCompatible(op.getOperand().getType(),
op.getType())) {
return failure();
}
auto targetType = op.getType().getElementType();
if (isa<IntegerType>(targetType) &&
cast<IntegerType>(targetType).getWidth() != 32) {
return failure();
}

auto loc = op.getLoc();
int64_t maxRank = getMaxRank(adaptor);
// Apply only if all operands are scalar or have the same rank.
if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
int64_t r = getRank(v);
return r == 0 || r == maxRank;
})) {
return rewriter.notifyMatchFailure(
op, "Operands must be of same rank or scalar.");
}
// Find result type, if on tensors.
std::optional<ShapedType> resultTy;
resultTy = this->typeConverter->convertType(op->getResultTypes().front())
.template dyn_cast<ShapedType>();

// Check result type compatibility.
if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != maxRank ||
!(resultTy->getElementType().isSignlessIntOrFloat() ||
resultTy->getElementType().isa<ComplexType>())) {
return rewriter.notifyMatchFailure(
op, "mismatched operand/result types or iterator count");
}
// All-scalar pointwise ops inside of linalg ops are processed by
// ScalarHloToArithmeticPattern.
if (maxRank == 0 && isInBodyOfLinalgOps(op))
return failure();

// Find input/output values and types.
Value emptyTensor =
getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands());
// Mapped inputs are cast to the same shape as the init tensor.
// Values from scalar inputs are extracted and used directly in the block.
SmallVector<Value> mappedInputs;
SmallVector<Value> scalarInputs;
for (Value input : adaptor.getOperands()) {
if (getRank(input) == maxRank) {
mappedInputs.push_back(coerceTensorShape(
rewriter, loc, cast<TypedValue<ShapedType>>(input),
cast<ShapedType>(emptyTensor.getType())));
scalarInputs.push_back(nullptr);
} else {
scalarInputs.push_back(rewriter.create<tensor::ExtractOp>(loc, input));
}
}

auto mapOp = rewriter.create<linalg::MapOp>(
loc, mappedInputs, emptyTensor,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value innerResult = mapFPToSIConvertOpToStdScalarOp(
op.getLoc(), op.getType(), getElementTypeOrSelf(emptyTensor),
llvm::to_vector(op->getOperandTypes()),
interleaveScalarAndBlockArgs(scalarInputs, args), &b);
b.create<linalg::YieldOp>(loc, innerResult);
},
linalg::getPrunedAttributeList(op));
rewriter.replaceOp(op, mapOp->getResults());
return success();
}

protected:
int64_t getRank(Value v) const {
return v.getType().cast<ShapedType>().getRank();
}

int64_t getMaxRank(typename mhlo::ConvertOp::Adaptor adaptor) const {
int64_t maxRank = 0;
for (auto operand : adaptor.getOperands()) {
maxRank = std::max(maxRank, getRank(operand));
}
return maxRank;
}

// Inserts block arguments in places where scalar inputs have a nullptr.
SmallVector<Value> interleaveScalarAndBlockArgs(ValueRange scalarInputs,
ValueRange blockArgs) const {
SmallVector<Value> result;
auto argsIter = blockArgs.begin();
for (Value scalarInput : scalarInputs) {
if (scalarInput) {
result.push_back(scalarInput);
} else {
result.push_back(*argsIter);
++argsIter;
}
}
return result;
}
};

} // namespace

void mlir::populateHloToLinalgExtConversionPattern(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
void mlir::populateHloToLinalgExtConversionPattern(TypeConverter &typeConverter,
RewritePatternSet &patterns,
const std::string &target,
const std::string &arch) {
auto ctx = patterns.getContext();
patterns.add<ReduceWindowOpConversion>(typeConverter, ctx, PatternBenefit(2));
patterns.add<DotGeneralLinalgExtBatchMatMulOpConversion>(typeConverter, ctx,
PatternBenefit(2));
if (target == "cpu" && arch == "x86_64") {
patterns.add<FPToSIConvertOpConverter>(typeConverter, ctx,
PatternBenefit(2));
}
patterns.add<SoftmaxCustomCallConverter>(ctx);
patterns.add<ScatterOpConversion>(ctx);
patterns.add<LayerNormCustomCallConverter>(ctx);
@@ -1333,8 +1505,9 @@ void mlir::populateHloToLinalgExtConversionPattern(
patterns.add<ByteirRepeatCustomCallConverter>(ctx);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createHloFusionToLinalgPass(llvm::StringRef anchorTag,
bool enablePrimitiveOps) {
return std::make_unique<HloFusionToLinalgPass>(anchorTag, enablePrimitiveOps);
std::unique_ptr<OperationPass<func::FuncOp>> mlir::createHloFusionToLinalgPass(
llvm::StringRef anchorTag, bool enablePrimitiveOps,
const std::string &target, const std::string &arch) {
return std::make_unique<HloFusionToLinalgPass>(anchorTag, enablePrimitiveOps,
target, arch);
}
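
For readers skimming the diff, the value mapping encoded by the select chain in mapFPToSIConvertOpToStdScalarOp can be sketched as plain scalar C++ (illustration only, assuming an f32 -> i32 convert; the generated IR evaluates all three operands and then selects, the early returns below are only for readability):

#include <cmath>
#include <cstdint>
#include <limits>

// select(isInf, maxIntval, select(isNan, zeroIntval, fptosi(x)))
int32_t convertSpecialValF32ToI32(float x) {
  if (x == std::numeric_limits<float>::infinity())
    return std::numeric_limits<int32_t>::max(); // isInf -> maxIntval
  if (std::isnan(x))
    return 0;                                   // isNan -> zeroIntval
  return static_cast<int32_t>(x);               // everything else: arith.fptosi
}
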
12 changes: 7 additions & 5 deletions compiler/lib/Pipelines/LinalgTensorOpt.cpp
@@ -228,9 +228,10 @@ void addGenericLinalgPasses(OpPassManager &pm) {
}
}

void addCPULinalgOptPasses(OpPassManager &pm) {
void addCPULinalgOptPasses(OpPassManager &pm, const std::string &target,
const std::string &arch) {
pm.addNestedPass<func::FuncOp>(createHloFusionToLinalgPass(
getByteIRHloAggressiveFusionAttrName(), true));
getByteIRHloAggressiveFusionAttrName(), true, target, arch));
pm.addNestedPass<func::FuncOp>(createUnrealizedCastToLinalgPass());
{
TileAndVectorizeTransposeOptions options;
@@ -248,9 +249,10 @@
}

void createLinalgTensorOptPipelineImpl(OpPassManager &pm,
const std::string &target) {
const std::string &target,
const std::string &arch) {
if (target == "cpu") {
addCPULinalgOptPasses(pm);
addCPULinalgOptPasses(pm, target, arch);
} else {
addGenericLinalgPasses(pm);
}
@@ -260,5 +262,5 @@ void createLinalgTensorOptPipelineImpl(OpPassManager &pm,
void mlir::createLinalgTensorOptPipeline(
OpPassManager &pm, const LinalgTensorOptPipelineOptions &options) {
invokeOpPassPipelineBuilder(createLinalgTensorOptPipelineImpl, pm,
options.target);
options.target, options.arch);
}
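
A hedged sketch (not in the commit) of driving the updated pipeline from C++ with the new arch option; it assumes the options struct and builder live in the mlir namespace as elsewhere in ByteIR, and is roughly the counterpart of the Python pipeline string used in compile.py below:

#include "byteir/Pipelines/LinalgTensorOpt.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

// Equivalent in spirit to "builtin.module(linalg-tensor-opt{target=cpu arch=x86_64})".
mlir::LogicalResult runLinalgTensorOpt(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::LinalgTensorOptPipelineOptions options;
  options.target = "cpu";
  options.arch = "x86_64"; // option introduced by this commit
  mlir::createLinalgTensorOptPipeline(pm, options);
  return pm.run(module);
}
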
3 changes: 2 additions & 1 deletion compiler/python/byteir/compile.py
@@ -286,14 +286,15 @@ def _compile_cpu(

entry_func_str = "entry-func={}".format(entry_func)
target_str = "target={}".format(target)
arch_str="arch={}".format(cpu_arch)
with context:
PassManager().parse("builtin.module(hlo-graph-opt{" + entry_func_str + " " + target_str + "})").run(module.operation)
_print_verbose(module, "// IR Dump After Hlo Graph Opt:") if verbose else ...
with context:
PassManager().parse("builtin.module(hlo-fusion-opt{" + entry_func_str + " " + target_str + " outline-single-elemwise-op})").run(module.operation)
_print_verbose(module, "// IR Dump After Hlo Fusion Opt:") if verbose else ...
with context:
PassManager.parse("builtin.module(linalg-tensor-opt{" + target_str + "})").run(module.operation)
PassManager.parse("builtin.module(linalg-tensor-opt{" + target_str + " " + arch_str + "})").run(module.operation)
_print_verbose(module, "// IR Dump After Linalg Tensor Opt:") if verbose else ...
with context:
PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types {}}})".format(entry_func_str)).run(module.operation)
4 changes: 3 additions & 1 deletion tests/numerical_test/execute.py
@@ -33,6 +33,9 @@
MLIR_TEST_SPECIAL_INPUTS = {
"cpu@log_plus_one.mlir": [
np.random.uniform(low=0.5, high=1.0, size=(256, 64)).astype(np.float16)
],
"cpu@convert_f32_i32_special_val.mlir": [
np.array([[np.inf, -np.inf, np.nan], [1., 999.999, -np.inf]], dtype=np.float32),
]
}

@@ -235,7 +238,6 @@ def compile_and_run_mlir(mhlo_file, target, verbose, mode="numerical", workdir="

interp = Interpreter.load_from_file(mhlo_file, is_stablehlo=True)
golden_outputs = interp.call_function(entry_func_name, np_inputs)

if unique_name is None:
unique_name = os.path.basename(mhlo_file).split(".")[0] + "." + target
# byteir compile
1 change: 1 addition & 0 deletions tests/numerical_test/reporting.py
@@ -28,6 +28,7 @@ def report_results(results: List[TestResult]):
pass_set = []
for result in results:
if result.compilation_error is not None:
print(result.compilation_error)
fail_set.append([result.unique_name, 'compilation failed: ' +
result.unique_name + "\n" + result.compilation_error])
elif result.runtime_error is not None:
