Skip to content

Commit

Permalink
[XLA:GPU] Update FindNonTrivialHero to work with HloInstructionAdaptor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629431715
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed May 2, 2024
1 parent 872be62 commit c210569
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 45 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ HloFusionAnalysis HloFusionAnalysis::Create(
std::vector<const HloInstruction*> heroes;
for (auto root : fusion->GetRoots()) {
roots.push_back(&root.instruction());
heroes.push_back(&FindNonTrivialHero(*roots.back(), *fusion));
heroes.push_back(&FindNonTrivialHero(root).instruction());
}

std::vector<const HloInstruction*> fusion_arguments;
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/hlo_traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class HloInstructionAdaptor {

// Use sparingly; prefer extending the interface.
const HloInstruction& instruction() const { return *instruction_; }
const HloFusionAdaptor& parent() const { return *parent_; }

private:
const HloInstruction* instruction_;
Expand Down
31 changes: 16 additions & 15 deletions third_party/xla/xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count,
}

static std::optional<HloInstructionAdaptor> FindNonTrivialHero(
HloInstructionAdaptor root, const HloFusionAdaptor& fusion,
const HloInstructionAdaptor& root,
const std::function<bool(const HloInstruction&)>& predicate) {
std::optional<HloInstructionAdaptor> hero = std::nullopt;
auto visitor = [&](HloInstructionAdaptor node) {
Expand All @@ -710,7 +710,7 @@ static std::optional<HloInstructionAdaptor> FindNonTrivialHero(
}
return TraversalResult::kAdvance;
};
HloBfsConsumersFirstTraversal({root}, fusion, visitor);
HloBfsConsumersFirstTraversal({root}, root.parent(), visitor);
if (!hero) {
return std::nullopt;
}
Expand All @@ -727,23 +727,23 @@ static std::optional<HloInstructionAdaptor> FindNonTrivialHero(
/*add_single_user_check=*/true);
};
bool visit_operands = false;
if (HloAnyOf(hero->GetUsers(), fusion, is_nontrivial, visit_operands)) {
if (HloAnyOf(hero->GetUsers(), hero->parent(), is_nontrivial,
visit_operands)) {
return std::nullopt;
}

return hero;
}

const HloInstruction& FindNonTrivialHero(const HloInstruction& instr,
const HloFusionAdaptor& fusion) {
HloInstructionAdaptor hero{instr, &fusion};
HloInstructionAdaptor FindNonTrivialHero(const HloInstructionAdaptor& instr) {
HloInstructionAdaptor hero = instr;

// Go up the chain of trivial element-wise (+bitcast, -copy) operations. Note
// that no memoization is needed because of the operand-count constraint: we
// never have to revisit the same nodes.
while (IsIntermediate(&hero.instruction(), /*allowed_operand_count=*/1,
&fusion) &&
fusion.ContainsInstruction(hero.GetOperand(0))) {
&hero.parent()) &&
hero.parent().ContainsInstruction(hero.GetOperand(0))) {
hero = hero.GetOperand(0);
}

Expand All @@ -753,25 +753,26 @@ const HloInstruction& FindNonTrivialHero(const HloInstruction& instr,
auto is_transpose = [](const HloInstruction& node) {
return FindTiledLogicalTranspose(node).has_value();
};
if (auto transpose = FindNonTrivialHero(hero, fusion, is_transpose)) {
return transpose->instruction();
if (auto transpose = FindNonTrivialHero(hero, is_transpose)) {
return *transpose;
}
auto is_concatenate = [](const HloInstruction& node) {
return node.opcode() == HloOpcode::kConcatenate;
};
if (auto concatenate = FindNonTrivialHero(hero, fusion, is_concatenate)) {
return concatenate->instruction();
if (auto concatenate = FindNonTrivialHero(hero, is_concatenate)) {
return *concatenate;
}
if (hero.opcode() != HloOpcode::kReduce) {
return instr;
}
return hero.instruction();
return hero;
}

const HloInstruction& FindNonTrivialHero(const HloInstruction& instr) {
CHECK_NE(instr.opcode(), HloOpcode::kFusion);
return FindNonTrivialHero(instr,
*HloFusionAdaptor::ForComputation(instr.parent()));
auto fusion_adaptor = HloFusionAdaptor::ForComputation(instr.parent());
HloInstructionAdaptor instr_adaptor(instr, fusion_adaptor.get());
return FindNonTrivialHero(instr_adaptor).instruction();
}

void VLogModule(int level, const llvm::Module& module) {
Expand Down
16 changes: 6 additions & 10 deletions third_party/xla/xla/service/gpu/ir_emission_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,12 @@ std::vector<const HloInstruction*> GetOutputDefiningDynamicUpdateSlices(

Shape GetShape(mlir::Value value);

// `is_boundary` returns `true` for edges that are on the boundary of the
// fusion, i.e., they go from an instruction inside the fusion to one outside,
// or vice versa.
// Note: when this is called with a fusion instruction, it will traverse into
// the fusion (unless the boundary function stops it).
const HloInstruction& FindNonTrivialHero(const HloInstruction& instr,
const HloFusionAdaptor& fusion);

// Like above, but assumes the instruction is inside an HloFusionInstruction.
// Returns the instruction itself if it is an HloFusionInstruction.
// Returns the first hero instruction reachable from `instr` as root. The hero
// instruction can be in a different computation if the parent HloFusionAdaptor
// is a producer-consumer fusion.
HloInstructionAdaptor FindNonTrivialHero(const HloInstructionAdaptor& instr);

// Same as above, but uses the parent computation of `instr` as the fusion.
// `instr` must not be a fusion instruction itself.
const HloInstruction& FindNonTrivialHero(const HloInstruction& instr);

/// Description of how to emit a given transposition.
Expand Down
31 changes: 14 additions & 17 deletions third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusion) {

HloInstruction* r = module->entry_computation()->root_instruction();
auto fusion = HloFusionAdaptor::ForInstruction(r);
const auto& result =
FindNonTrivialHero(fusion->GetRoots()[0].instruction(), *fusion);
const auto& result = FindNonTrivialHero(fusion->GetRoots()[0]);
EXPECT_EQ(result.name(), "reduce.0");
}

Expand Down Expand Up @@ -187,11 +186,9 @@ TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusionTwoRootUsers) {

HloInstruction* r = module->entry_computation()->root_instruction();
auto fusion = HloFusionAdaptor::ForInstruction(r);
const auto& result =
FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion);
const auto& result = FindNonTrivialHero(fusion->GetRoots()[1]);
EXPECT_EQ(result.name(), "reduce.1");
const auto& result2 =
FindNonTrivialHero(fusion->GetRoots()[2].instruction(), *fusion);
const auto& result2 = FindNonTrivialHero(fusion->GetRoots()[2]);
EXPECT_EQ(result2.name(), "reduce.1");
}

Expand Down Expand Up @@ -225,13 +222,11 @@ TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusionHeroAlsoUsedAsNonHero) {

HloInstruction* r = module->entry_computation()->root_instruction();
auto fusion = HloFusionAdaptor::ForInstruction(r);
const auto& result =
FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion);
const auto& result = FindNonTrivialHero(fusion->GetRoots()[1]);
// reduce.0 is also an operand of broadcast, but it is not a hero for that
// root.
EXPECT_EQ(result.name(), "broadcast");
const auto& result2 =
FindNonTrivialHero(fusion->GetRoots()[2].instruction(), *fusion);
const auto& result2 = FindNonTrivialHero(fusion->GetRoots()[2]);
EXPECT_EQ(result2.name(), "reduce.0");
}

Expand All @@ -258,11 +253,9 @@ TEST_F(IrEmissionUtilsTest, DoNotFindTransposeHeroEpilogueFusionTwoRootUsers) {

HloInstruction* r = module->entry_computation()->root_instruction();
auto fusion = HloFusionAdaptor::ForInstruction(r);
const auto& result =
FindNonTrivialHero(fusion->GetRoots()[0].instruction(), *fusion);
const auto& result = FindNonTrivialHero(fusion->GetRoots()[0]);
EXPECT_EQ(result.name(), "bitcast.1");
const auto& result2 =
FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion);
const auto& result2 = FindNonTrivialHero(fusion->GetRoots()[1]);
EXPECT_EQ(result2.name(), "sign.1");
}

Expand Down Expand Up @@ -370,14 +363,16 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));

HloInstruction* r = module->GetComputationWithName("f")->root_instruction();
HloInstruction* transpose =
module->entry_computation()->GetInstructionWithName("t");
HloInstruction* fusion =
module->entry_computation()->GetInstructionWithName("fusion");
auto fusion_adaptor =
HloFusionAdaptor::ForProducerConsumer(transpose, fusion);
EXPECT_EQ(&FindNonTrivialHero(*r, *fusion_adaptor), transpose);
HloInstructionAdaptor r(
*module->GetComputationWithName("f")->root_instruction(),
fusion_adaptor.get());
EXPECT_EQ(&FindNonTrivialHero(r).instruction(), transpose);
}

TEST_F(IrEmissionUtilsTest, FindNonTrivialHeroInsideFusion) {
Expand Down Expand Up @@ -409,7 +404,9 @@ ENTRY entry {
HloInstruction* fusion =
module->entry_computation()->GetInstructionWithName("fusion");
auto fusion_adaptor = HloFusionAdaptor::ForProducerConsumer(fusion, r);
EXPECT_EQ(&FindNonTrivialHero(*r, *fusion_adaptor), transpose);
EXPECT_EQ(&FindNonTrivialHero(HloInstructionAdaptor(*r, fusion_adaptor.get()))
.instruction(),
transpose);
}

TEST_F(IrEmissionUtilsTest, TransposeReachableViaTrivialAndNontrivialOps) {
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/rename_fusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ std::string MakeFusionHeroNames(const HloInstruction* instruction) {
absl::btree_set<absl::string_view> heroes;

for (auto root : fusion_adaptor->GetRoots()) {
heroes.insert(HloOpcodeString(
FindNonTrivialHero(root.instruction(), *fusion_adaptor).opcode()));
heroes.insert(HloOpcodeString(FindNonTrivialHero(root).opcode()));
}
return absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}});
}
Expand Down

0 comments on commit c210569

Please sign in to comment.