Skip to content

Commit

Permalink
[XLA:GPU][IndexAnalysis] Add a method to remove unused dims.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629330924
  • Loading branch information
pifon2a authored and tensorflower-gardener committed Apr 30, 2024
1 parent 6b11e60 commit e5b3a0d
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 26 deletions.
114 changes: 88 additions & 26 deletions third_party/xla/xla/service/gpu/model/indexing_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ using mlir::AffineMap;
using mlir::AffineSymbolExpr;
using mlir::getAffineBinaryOpExpr;
using mlir::getAffineConstantExpr;
using mlir::getAffineDimExpr;
using mlir::MLIRContext;

class AffineExprSimplifier {
Expand Down Expand Up @@ -1087,60 +1088,78 @@ bool IsFunctionOfUnusedDimsAndSymbolsOnly(
return true;
}

} // namespace
struct UnusedVariables {
SmallBitVector unused_dims;
SmallBitVector unused_symbols;
SmallVector<AffineExpr> constraints_with_unused_vars_only;
};

SmallBitVector IndexingMap::RemoveUnusedSymbols() {
if (IsUndefined()) return {};
// Detects unused dimensions and symbols in the inde
UnusedVariables DetectUnusedVariables(const IndexingMap& indexing_map) {
AffineMap affine_map = indexing_map.GetAffineMap();

// Remove unused symbols from the affine_map.
unsigned num_symbols_before = affine_map_.getNumSymbols();
SmallBitVector unused_symbols_bit_vector =
mlir::getUnusedSymbolsBitVector({affine_map_});
SmallBitVector unused_dims_bit_vector =
mlir::getUnusedDimsBitVector({affine_map_});
UnusedVariables unused_vars;
// Find unused dimensions and symbols in the affine_map.
unused_vars.unused_dims = mlir::getUnusedDimsBitVector({affine_map});
unused_vars.unused_symbols = mlir::getUnusedSymbolsBitVector({affine_map});

// Check if the symbols that are unused in `affine_map` are also unused in
// expressions.
std::vector<std::pair<AffineExpr, UsedParameters>> candidates_to_remove;
for (const auto& [expr, range] : constraints_) {
SmallVector<std::pair<AffineExpr, UsedParameters>, 2>
unused_constraints_candidates;
for (const auto& [expr, range] : indexing_map.GetConstraints()) {
UsedParameters used_parameters = GetUsedParameters(expr);
// If the expression uses only symbols and dims that are "unused" in
// `affine_map`, then we can remove it.
if (IsFunctionOfUnusedDimsAndSymbolsOnly(used_parameters,
unused_dims_bit_vector,
unused_symbols_bit_vector)) {
candidates_to_remove.push_back({expr, used_parameters});
unused_vars.unused_dims,
unused_vars.unused_symbols)) {
unused_constraints_candidates.push_back({expr, used_parameters});
continue;
}
// Otherwise, we need to mark all symbols of these expr as "used".
// Otherwise, we need to mark all dims and symbols of these expr as "used".
for (int64_t dim_id : used_parameters.dimension_ids) {
unused_vars.unused_dims[dim_id] = false;
}
for (int64_t symbol_id : used_parameters.symbol_ids) {
unused_symbols_bit_vector[symbol_id] = false;
unused_vars.unused_symbols[symbol_id] = false;
}
}
for (const auto& [expr, used_parameters] : candidates_to_remove) {
for (const auto& [expr, used_parameters] : unused_constraints_candidates) {
if (IsFunctionOfUnusedDimsAndSymbolsOnly(used_parameters,
unused_dims_bit_vector,
unused_symbols_bit_vector)) {
constraints_.erase(expr);
unused_vars.unused_dims,
unused_vars.unused_symbols)) {
unused_vars.constraints_with_unused_vars_only.push_back(expr);
}
}
return unused_vars;
}

// Compress `affine_map` using the updated `unused_symbols_bit_vector`.
affine_map_ = mlir::compressSymbols(affine_map_, unused_symbols_bit_vector);
} // namespace

SmallBitVector IndexingMap::RemoveUnusedSymbols() {
if (IsUndefined()) return {};

UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
unsigned num_symbols_before = GetSymbolCount();
affine_map_ = mlir::compressSymbols(affine_map_, unused_vars.unused_symbols);

// Remap symbols in the constraint expressions accordingly.
unsigned num_symbols_after = affine_map_.getNumSymbols();
if (num_symbols_after == num_symbols_before) return {};

// Remap symbols in the constraint expressions accordingly.
std::vector<RangeVar> compressed_range_vars;
std::vector<RTVar> compressed_rt_vars;
MLIRContext* mlir_context = GetMLIRContext();
int64_t used_symbols_count = 0;
std::vector<AffineExpr> symbol_replacements(
num_symbols_before, getAffineConstantExpr(0, mlir_context));
auto range_vars_count = range_vars_.size();
for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) {
if (!unused_symbols_bit_vector[i]) {
for (int i = 0; i < unused_vars.unused_symbols.size(); ++i) {
if (!unused_vars.unused_symbols[i]) {
if (i < range_vars_count) {
compressed_range_vars.push_back(range_vars_[i]);
} else {
Expand All @@ -1166,7 +1185,50 @@ SmallBitVector IndexingMap::RemoveUnusedSymbols() {
for (const auto& [expr, range] : to_add) {
AddConstraint(expr, range);
}
return unused_symbols_bit_vector;
return std::move(unused_vars.unused_symbols);
}

SmallBitVector IndexingMap::RemoveUnusedDimensions() {
if (IsUndefined()) return {};

UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
unsigned num_dims_before = GetDimensionCount();
affine_map_ = mlir::compressDims(affine_map_, unused_vars.unused_dims);

unsigned num_dims_after = affine_map_.getNumDims();
if (num_dims_after == num_dims_before) return {};

// Remap dimensions in the constraint expressions accordingly.
std::vector<DimVar> compressed_dim_vars;
MLIRContext* mlir_context = GetMLIRContext();
int64_t used_dims_count = 0;
std::vector<AffineExpr> dim_replacements(
num_dims_before, getAffineConstantExpr(0, mlir_context));
for (int i = 0; i < unused_vars.unused_dims.size(); ++i) {
if (!unused_vars.unused_dims[i]) {
compressed_dim_vars.push_back(dim_vars_[i]);
dim_replacements[i] = getAffineDimExpr(used_dims_count++, mlir_context);
}
}
dim_vars_ = std::move(compressed_dim_vars);
std::vector<AffineExpr> to_remove;
std::vector<std::pair<AffineExpr, Interval>> to_add;
for (const auto& [expr, range] : constraints_) {
auto updated_expr = expr.replaceDims(dim_replacements);
if (updated_expr == expr) continue;
to_add.push_back({updated_expr, range});
to_remove.push_back(expr);
}
for (const auto& expr : to_remove) {
constraints_.erase(expr);
}
for (const auto& [expr, range] : to_add) {
AddConstraint(expr, range);
}
return std::move(unused_vars.unused_dims);
}

void IndexingMap::MergeModConstraints() {
Expand Down
5 changes: 5 additions & 0 deletions third_party/xla/xla/service/gpu/model/indexing_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,11 @@ class IndexingMap {

bool IsUndefined() const { return affine_map_ == mlir::AffineMap(); }

// Removes unused dimensions from the `affine_map_` and constraints.
// Returns a bit vector of dimensions that were removed. If none of the
// dimensions were removed, returns {}.
llvm::SmallBitVector RemoveUnusedDimensions();

// Removes unused symbols from the `affine_map_` and constraints.
// Returns a bit vector of symbols that were removed. If none of the symbols
// were removed, returns {}.
Expand Down
39 changes: 39 additions & 0 deletions third_party/xla/xla/service/gpu/model/indexing_map_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,45 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) {
)"));
}

TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintUsesDim) {
IndexingMap indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap("(d0, d1)[s0, s1] -> (d1, s0, s1)", &mlir_context_),
{50, 60}, {70, 20});
// This constraint cannot be removed, because it contains a "used dim".
indexing_map.AddConstraint(ParseAffineExpr("s0 + d0", &mlir_context_),
Interval{1, 100});
indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_),
Interval{0, 0});
indexing_map.RemoveUnusedDimensions();
EXPECT_THAT(indexing_map, MatchIndexingMap(R"(
(d0, d1)[s0, s1] -> (d1, s0, s1)
domain:
d0 in [0, 49]
d1 in [0, 59]
s0 in [0, 69]
s1 in [0, 19]
d0 + s0 in [1, 100]
s0 mod 3 in [0, 0]
)"));
}

TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintUsesOnlyUnusedDim) {
IndexingMap indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap("(d0, d1)[s0, s1] -> (s0, d1, s1)", &mlir_context_),
{50, 60}, {70, 20});
// This constraint can be removed, because it contains only the unused dim.
indexing_map.AddConstraint(ParseAffineExpr("d0 mod 3", &mlir_context_),
Interval{0, 0});
indexing_map.RemoveUnusedDimensions();
EXPECT_THAT(indexing_map, MatchIndexingMap(R"(
(d0)[s0, s1] -> (s0, d0, s1)
domain:
d0 in [0, 59]
s0 in [0, 69]
s1 in [0, 19]
)"));
}

TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) {
IndexingMap indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_),
Expand Down

0 comments on commit e5b3a0d

Please sign in to comment.