Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Update FindNonTrivialHero to work with HloInstructionAdaptor. #66713

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ HloFusionAnalysis HloFusionAnalysis::Create(
std::vector<const HloInstruction*> heroes;
for (auto root : fusion->GetRoots()) {
roots.push_back(&root.instruction());
heroes.push_back(&FindNonTrivialHero(*roots.back(), *fusion));
heroes.push_back(&FindNonTrivialHero(root).instruction());
}

std::vector<const HloInstruction*> fusion_arguments;
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/hlo_traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class HloInstructionAdaptor {

// Use sparingly; prefer extending the interface.
const HloInstruction& instruction() const { return *instruction_; }
const HloFusionAdaptor& parent() const { return *parent_; }

private:
const HloInstruction* instruction_;
Expand Down
31 changes: 16 additions & 15 deletions third_party/xla/xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count,
}

static std::optional<HloInstructionAdaptor> FindNonTrivialHero(
HloInstructionAdaptor root, const HloFusionAdaptor& fusion,
const HloInstructionAdaptor& root,
const std::function<bool(const HloInstruction&)>& predicate) {
std::optional<HloInstructionAdaptor> hero = std::nullopt;
auto visitor = [&](HloInstructionAdaptor node) {
Expand All @@ -710,7 +710,7 @@ static std::optional<HloInstructionAdaptor> FindNonTrivialHero(
}
return TraversalResult::kAdvance;
};
HloBfsConsumersFirstTraversal({root}, fusion, visitor);
HloBfsConsumersFirstTraversal({root}, root.parent(), visitor);
if (!hero) {
return std::nullopt;
}
Expand All @@ -727,23 +727,23 @@ static std::optional<HloInstructionAdaptor> FindNonTrivialHero(
/*add_single_user_check=*/true);
};
bool visit_operands = false;
if (HloAnyOf(hero->GetUsers(), fusion, is_nontrivial, visit_operands)) {
if (HloAnyOf(hero->GetUsers(), hero->parent(), is_nontrivial,
visit_operands)) {
return std::nullopt;
}

return hero;
}

const HloInstruction& FindNonTrivialHero(const HloInstruction& instr,
const HloFusionAdaptor& fusion) {
HloInstructionAdaptor hero{instr, &fusion};
HloInstructionAdaptor FindNonTrivialHero(const HloInstructionAdaptor& instr) {
HloInstructionAdaptor hero = instr;

// Go up the chain of trivial element-wise (+bitcast, -copy) operations. Note
// that no memoization is needed because of the operand-count constraint
// (`allowed_operand_count=1`): we never have to revisit the same nodes.
while (IsIntermediate(&hero.instruction(), /*allowed_operand_count=*/1,
&fusion) &&
fusion.ContainsInstruction(hero.GetOperand(0))) {
&hero.parent()) &&
hero.parent().ContainsInstruction(hero.GetOperand(0))) {
hero = hero.GetOperand(0);
}

Expand All @@ -753,25 +753,26 @@ const HloInstruction& FindNonTrivialHero(const HloInstruction& instr,
auto is_transpose = [](const HloInstruction& node) {
return FindTiledLogicalTranspose(node).has_value();
};
if (auto transpose = FindNonTrivialHero(hero, fusion, is_transpose)) {
return transpose->instruction();
if (auto transpose = FindNonTrivialHero(hero, is_transpose)) {
return *transpose;
}
auto is_concatenate = [](const HloInstruction& node) {
return node.opcode() == HloOpcode::kConcatenate;
};
if (auto concatenate = FindNonTrivialHero(hero, fusion, is_concatenate)) {
return concatenate->instruction();
if (auto concatenate = FindNonTrivialHero(hero, is_concatenate)) {
return *concatenate;
}
if (hero.opcode() != HloOpcode::kReduce) {
return instr;
}
return hero.instruction();
return hero;
}

const HloInstruction& FindNonTrivialHero(const HloInstruction& instr) {
CHECK_NE(instr.opcode(), HloOpcode::kFusion);
return FindNonTrivialHero(instr,
*HloFusionAdaptor::ForComputation(instr.parent()));
auto fusion_adaptor = HloFusionAdaptor::ForComputation(instr.parent());
HloInstructionAdaptor instr_adaptor(instr, fusion_adaptor.get());
return FindNonTrivialHero(instr_adaptor).instruction();
}

void VLogModule(int level, const llvm::Module& module) {
Expand Down
16 changes: 6 additions & 10 deletions third_party/xla/xla/service/gpu/ir_emission_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,12 @@ std::vector<const HloInstruction*> GetOutputDefiningDynamicUpdateSlices(

Shape GetShape(mlir::Value value);

// `is_boundary` returns `true` for edges that are on the boundary of the
// fusion, i.e., they go from an instruction inside the fusion to one outside,
// or vice versa.
// Note: when this is called with a fusion instruction, it will traverse into
// the fusion (unless the boundary function stops it).
const HloInstruction& FindNonTrivialHero(const HloInstruction& instr,
const HloFusionAdaptor& fusion);

// Like above, but assumes the instruction is inside an HloFusionInstruction.
// Returns the instruction itself if it is an HloFusionInstruction.
// Returns the first hero instruction reachable from `instr` as root. The hero
// instruction can be in a different computation if the parent
// HloFusionAdaptor is a producer-consumer fusion.
HloInstructionAdaptor FindNonTrivialHero(const HloInstructionAdaptor& instr);

// Same as above, but the fusion is built from the parent computation of
// `instr`. `instr` must not itself be a fusion instruction.
const HloInstruction& FindNonTrivialHero(const HloInstruction& instr);

/// Description of how to emit a given transposition.
Expand Down
31 changes: 14 additions & 17 deletions third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusion) {

HloInstruction* r = module->entry_computation()->root_instruction();
auto fusion = HloFusionAdaptor::ForInstruction(r);
const auto& result =
FindNonTrivialHero(fusion->GetRoots()[0].instruction(), *fusion);
const auto& result = FindNonTrivialHero(fusion->GetRoots()[0]);
EXPECT_EQ(result.name(), "reduce.0");
}

Expand Down Expand Up @@ -187,11 +186,9 @@ TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusionTwoRootUsers) {

HloInstruction* r = module->entry_computation()->root_instruction();
auto fusion = HloFusionAdaptor::ForInstruction(r);
const auto& result =
FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion);
const auto& result = FindNonTrivialHero(fusion->GetRoots()[1]);
EXPECT_EQ(result.name(), "reduce.1");
const auto& result2 =
FindNonTrivialHero(fusion->GetRoots()[2].instruction(), *fusion);
const auto& result2 = FindNonTrivialHero(fusion->GetRoots()[2]);
EXPECT_EQ(result2.name(), "reduce.1");
}

Expand Down Expand Up @@ -225,13 +222,11 @@ TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusionHeroAlsoUsedAsNonHero) {

HloInstruction* r = module->entry_computation()->root_instruction();
auto fusion = HloFusionAdaptor::ForInstruction(r);
const auto& result =
FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion);
const auto& result = FindNonTrivialHero(fusion->GetRoots()[1]);
// reduce.0 is also an operand of broadcast, but it is not a hero for that
// root.
EXPECT_EQ(result.name(), "broadcast");
const auto& result2 =
FindNonTrivialHero(fusion->GetRoots()[2].instruction(), *fusion);
const auto& result2 = FindNonTrivialHero(fusion->GetRoots()[2]);
EXPECT_EQ(result2.name(), "reduce.0");
}

Expand All @@ -258,11 +253,9 @@ TEST_F(IrEmissionUtilsTest, DoNotFindTransposeHeroEpilogueFusionTwoRootUsers) {

HloInstruction* r = module->entry_computation()->root_instruction();
auto fusion = HloFusionAdaptor::ForInstruction(r);
const auto& result =
FindNonTrivialHero(fusion->GetRoots()[0].instruction(), *fusion);
const auto& result = FindNonTrivialHero(fusion->GetRoots()[0]);
EXPECT_EQ(result.name(), "bitcast.1");
const auto& result2 =
FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion);
const auto& result2 = FindNonTrivialHero(fusion->GetRoots()[1]);
EXPECT_EQ(result2.name(), "sign.1");
}

Expand Down Expand Up @@ -370,14 +363,16 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));

HloInstruction* r = module->GetComputationWithName("f")->root_instruction();
HloInstruction* transpose =
module->entry_computation()->GetInstructionWithName("t");
HloInstruction* fusion =
module->entry_computation()->GetInstructionWithName("fusion");
auto fusion_adaptor =
HloFusionAdaptor::ForProducerConsumer(transpose, fusion);
EXPECT_EQ(&FindNonTrivialHero(*r, *fusion_adaptor), transpose);
HloInstructionAdaptor r(
*module->GetComputationWithName("f")->root_instruction(),
fusion_adaptor.get());
EXPECT_EQ(&FindNonTrivialHero(r).instruction(), transpose);
}

TEST_F(IrEmissionUtilsTest, FindNonTrivialHeroInsideFusion) {
Expand Down Expand Up @@ -409,7 +404,9 @@ ENTRY entry {
HloInstruction* fusion =
module->entry_computation()->GetInstructionWithName("fusion");
auto fusion_adaptor = HloFusionAdaptor::ForProducerConsumer(fusion, r);
EXPECT_EQ(&FindNonTrivialHero(*r, *fusion_adaptor), transpose);
EXPECT_EQ(&FindNonTrivialHero(HloInstructionAdaptor(*r, fusion_adaptor.get()))
.instruction(),
transpose);
}

TEST_F(IrEmissionUtilsTest, TransposeReachableViaTrivialAndNontrivialOps) {
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/rename_fusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ std::string MakeFusionHeroNames(const HloInstruction* instruction) {
absl::btree_set<absl::string_view> heroes;

for (auto root : fusion_adaptor->GetRoots()) {
heroes.insert(HloOpcodeString(
FindNonTrivialHero(root.instruction(), *fusion_adaptor).opcode()));
heroes.insert(HloOpcodeString(FindNonTrivialHero(root).opcode()));
}
return absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}});
}
Expand Down