diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index b07f2883b7e10f..f5c0d979d7c48e 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -442,9 +442,6 @@ xla_cc_test( ":indexing_analysis", ":indexing_map", ":indexing_test_utils", - "//xla:literal_util", - "//xla:shape_util", - "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index f5ca20b76fb48a..f12f923342276c 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -1140,46 +1140,81 @@ UnusedVariables DetectUnusedVariables(const IndexingMap& indexing_map) { return unused_vars; } +SmallBitVector ConcatenateBitVectors(const SmallBitVector& lhs, + const SmallBitVector& rhs) { + SmallBitVector concat(lhs.size() + rhs.size(), false); + int id = 0; + for (int i = 0; i < lhs.size(); ++i, ++id) { + concat[id] = lhs[i]; + } + for (int i = 0; i < rhs.size(); ++i, ++id) { + concat[id] = rhs[i]; + } + return concat; +} + } // namespace -SmallBitVector IndexingMap::RemoveUnusedSymbols() { - if (IsUndefined()) return {}; +bool IndexingMap::CompressVars(const llvm::SmallBitVector& unused_dims, + const llvm::SmallBitVector& unused_symbols) { + MLIRContext* mlir_context = GetMLIRContext(); - UnusedVariables unused_vars = DetectUnusedVariables(*this); - for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) { - constraints_.erase(expr); - } - unsigned num_symbols_before = GetSymbolCount(); - affine_map_ = mlir::compressSymbols(affine_map_, unused_vars.unused_symbols); + bool num_dims_changed = unused_dims.count() > 0; + bool num_symbols_changed = unused_symbols.count() > 0; + if (!num_dims_changed && !num_symbols_changed) return false; - 
unsigned num_symbols_after = affine_map_.getNumSymbols(); - if (num_symbols_after == num_symbols_before) return {}; + unsigned num_dims_before = GetDimensionCount(); + unsigned num_symbols_before = GetSymbolCount(); - // Remap symbols in the constraint expressions accordingly. - std::vector compressed_range_vars; - std::vector compressed_rt_vars; - MLIRContext* mlir_context = GetMLIRContext(); - int64_t used_symbols_count = 0; - std::vector symbol_replacements( - num_symbols_before, getAffineConstantExpr(0, mlir_context)); - auto range_vars_count = range_vars_.size(); - for (int i = 0; i < unused_vars.unused_symbols.size(); ++i) { - if (!unused_vars.unused_symbols[i]) { - if (i < range_vars_count) { - compressed_range_vars.push_back(range_vars_[i]); - } else { - compressed_rt_vars.push_back(rt_vars_[i - range_vars_count]); + // Compress DimVars. + SmallVector dim_replacements; + if (num_dims_changed) { + affine_map_ = mlir::compressDims(affine_map_, unused_dims); + std::vector compressed_dim_vars; + dim_replacements = SmallVector( + num_dims_before, getAffineConstantExpr(0, mlir_context)); + int64_t used_dims_count = 0; + for (int i = 0; i < unused_dims.size(); ++i) { + if (!unused_dims[i]) { + compressed_dim_vars.push_back(dim_vars_[i]); + dim_replacements[i] = getAffineDimExpr(used_dims_count++, mlir_context); } - symbol_replacements[i] = - getAffineSymbolExpr(used_symbols_count++, mlir_context); } + dim_vars_ = std::move(compressed_dim_vars); + } + + // Compress RangeVars and RTVars. 
+ SmallVector symbol_replacements; + if (num_symbols_changed) { + affine_map_ = mlir::compressSymbols(affine_map_, unused_symbols); + symbol_replacements = SmallVector( + num_symbols_before, getAffineConstantExpr(0, mlir_context)); + std::vector compressed_range_vars; + std::vector compressed_rt_vars; + + int64_t used_symbols_count = 0; + auto range_vars_count = range_vars_.size(); + for (int i = 0; i < unused_symbols.size(); ++i) { + if (!unused_symbols[i]) { + if (i < range_vars_count) { + compressed_range_vars.push_back(range_vars_[i]); + } else { + compressed_rt_vars.push_back(rt_vars_[i - range_vars_count]); + } + symbol_replacements[i] = + getAffineSymbolExpr(used_symbols_count++, mlir_context); + } + } + range_vars_ = std::move(compressed_range_vars); + rt_vars_ = std::move(compressed_rt_vars); } - range_vars_ = std::move(compressed_range_vars); - rt_vars_ = std::move(compressed_rt_vars); + + // Remove constraints. std::vector to_remove; std::vector> to_add; for (const auto& [expr, range] : constraints_) { - auto updated_expr = expr.replaceSymbols(symbol_replacements); + auto updated_expr = + expr.replaceDimsAndSymbols(dim_replacements, symbol_replacements); if (updated_expr == expr) continue; to_add.push_back({updated_expr, range}); to_remove.push_back(expr); @@ -1190,50 +1225,47 @@ SmallBitVector IndexingMap::RemoveUnusedSymbols() { for (const auto& [expr, range] : to_add) { AddConstraint(expr, range); } - return std::move(unused_vars.unused_symbols); + return true; } -SmallBitVector IndexingMap::RemoveUnusedDimensions() { +SmallBitVector IndexingMap::RemoveUnusedSymbols() { if (IsUndefined()) return {}; UnusedVariables unused_vars = DetectUnusedVariables(*this); for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) { constraints_.erase(expr); } - unsigned num_dims_before = GetDimensionCount(); - affine_map_ = mlir::compressDims(affine_map_, unused_vars.unused_dims); + if 
(!CompressVars(/*unused_dims=*/{}, unused_vars.unused_symbols)) { + return {}; + } + return std::move(unused_vars.unused_symbols); +} - unsigned num_dims_after = affine_map_.getNumDims(); - if (num_dims_after == num_dims_before) return {}; +SmallBitVector IndexingMap::RemoveUnusedDimensions() { + if (IsUndefined()) return {}; - // Remap dimensions in the constraint expressions accordingly. - std::vector compressed_dim_vars; - MLIRContext* mlir_context = GetMLIRContext(); - int64_t used_dims_count = 0; - std::vector dim_replacements( - num_dims_before, getAffineConstantExpr(0, mlir_context)); - for (int i = 0; i < unused_vars.unused_dims.size(); ++i) { - if (!unused_vars.unused_dims[i]) { - compressed_dim_vars.push_back(dim_vars_[i]); - dim_replacements[i] = getAffineDimExpr(used_dims_count++, mlir_context); - } + UnusedVariables unused_vars = DetectUnusedVariables(*this); + for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) { + constraints_.erase(expr); } - dim_vars_ = std::move(compressed_dim_vars); - std::vector to_remove; - std::vector> to_add; - for (const auto& [expr, range] : constraints_) { - auto updated_expr = expr.replaceDims(dim_replacements); - if (updated_expr == expr) continue; - to_add.push_back({updated_expr, range}); - to_remove.push_back(expr); + if (!CompressVars(unused_vars.unused_dims, /*unused_symbols=*/{})) { + return {}; } - for (const auto& expr : to_remove) { + return std::move(unused_vars.unused_dims); +} + +SmallBitVector IndexingMap::RemoveUnusedVars() { + if (IsUndefined()) return {}; + + UnusedVariables unused_vars = DetectUnusedVariables(*this); + for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) { constraints_.erase(expr); } - for (const auto& [expr, range] : to_add) { - AddConstraint(expr, range); + if (!CompressVars(unused_vars.unused_dims, unused_vars.unused_symbols)) { + return {}; } - return std::move(unused_vars.unused_dims); + return ConcatenateBitVectors(unused_vars.unused_dims, + 
unused_vars.unused_symbols); } void IndexingMap::MergeModConstraints() { diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 9f24095a32c3cd..c615c2a3c09a3c 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -335,6 +335,11 @@ class IndexingMap { // were removed, returns {}. llvm::SmallBitVector RemoveUnusedSymbols(); + // Removes unused dimensions and symbols from the `affine_map_` and + // constraints. Returns a bit vector of all variables [dimensions, symbols] + // that were removed. If none of the variables were removed, returns {}. + llvm::SmallBitVector RemoveUnusedVars(); + // Rescales all symbols that are sufficiently constrained through `s? mod x = // [N, N]` constraints. Returns true if a rescale took place, otherwise false. bool RescaleSymbols(); @@ -363,6 +368,12 @@ class IndexingMap { // Returns true if a replacement was performed, otherwise false. bool ReplaceConstantRTVars(IndexingMapProvider indexing_map_provider); + // Removes DimVars, RangeVars, RTVars that correspond to the unused dimensions + // and symbols. If unused_dims is empty, then dims won't be removed. The same + // applies to unused_symbols. Returns true if anything was removed. + bool CompressVars(const llvm::SmallBitVector& unused_dims, + const llvm::SmallBitVector& unused_symbols); + mlir::AffineMap affine_map_; std::vector dim_vars_; std::vector range_vars_; diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index dadf857d7f9745..20807546347c0c 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -26,8 +26,6 @@ limitations under the License. 
#include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" @@ -48,6 +46,15 @@ class IndexingMapTest : public HloTestBase { AffineMapPrinter printer_; }; +std::vector ConvertToSTL(const llvm::SmallBitVector& bit_vector) { + std::vector result; + result.reserve(bit_vector.size()); + for (int i = 0; i < bit_vector.size(); ++i) { + result.push_back(bit_vector[i]); + } + return result; +} + TEST_F(IndexingMapTest, RTVar) { auto zero_dim_map = AffineMap::get(&mlir_context_); std::vector rt_vars{RTVar{Interval{0, 2}, @@ -222,6 +229,53 @@ TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintUsesOnlyUnusedDim) { )")); } +TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintsWithManyDims) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1, d2, d3, d4)[s0, s1] -> (s0 * 4 + d1 + d3 - 42)", + &mlir_context_), + {1, 2, 3, 4, 5}, {32, 64}); + indexing_map.AddConstraint( + ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459}); + indexing_map.RemoveUnusedDimensions(); + // dimensions d0, d2, d4 will be removed and d1 and d3 will become d0 and d1. 
+ EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42) + domain: + d0 in [0, 1] + d1 in [0, 3] + s0 in [0, 31] + s1 in [0, 63] + d0 + s0 * 4 + d1 in [24, 459] + )")); +} + +TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap( + "(d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42)", + &mlir_context_), + {1, 2, 3, 4, 5}, {32, 64, 96}); + indexing_map.AddConstraint( + ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459}); + indexing_map.AddConstraint(ParseAffineExpr("s0 + s2", &mlir_context_), + Interval{0, 512}); + auto unused_vars = indexing_map.RemoveUnusedVars(); + // dimensions d0, d2, d4 and symbol s1 will be removed. + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42) + domain: + d0 in [0, 1] + d1 in [0, 3] + s0 in [0, 31] + s1 in [0, 95] + d0 + s0 * 4 + d1 in [24, 459] + s0 + s1 in [0, 512] + )")); + EXPECT_THAT(ConvertToSTL(unused_vars), + ::testing::ElementsAreArray( + {true, false, true, false, true, false, true, false})); +} + TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_),