Skip to content

Commit

Permalink
[XLA:GPU][IndexAnalysis] Add a method to remove unused dims and symbols.
Browse files Browse the repository at this point in the history
Both RemoveUnusedSymbols and RemoveUnusedDimensions first detect the unused dims and symbols. Running that detection twice is wasteful when we want to remove both symbols and dimensions, so this change adds a method that removes both in a single pass.

PiperOrigin-RevId: 629340717
  • Loading branch information
pifon2a authored and tensorflower-gardener committed May 2, 2024
1 parent 97092c8 commit 616b53c
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 63 deletions.
3 changes: 0 additions & 3 deletions third_party/xla/xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,6 @@ xla_cc_test(
":indexing_analysis",
":indexing_map",
":indexing_test_utils",
"//xla:literal_util",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
Expand Down
148 changes: 90 additions & 58 deletions third_party/xla/xla/service/gpu/model/indexing_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1140,46 +1140,81 @@ UnusedVariables DetectUnusedVariables(const IndexingMap& indexing_map) {
return unused_vars;
}

// Returns a bit vector that holds the bits of `lhs` followed by the bits of
// `rhs`; the result has `lhs.size() + rhs.size()` bits.
SmallBitVector ConcatenateBitVectors(const SmallBitVector& lhs,
                                     const SmallBitVector& rhs) {
  const auto lhs_size = lhs.size();
  // Start from an all-zero vector and flip only the positions that are set in
  // the inputs.
  SmallBitVector result(lhs_size + rhs.size(), false);
  for (unsigned bit : lhs.set_bits()) {
    result.set(bit);
  }
  // Bits coming from `rhs` are placed right after the last bit of `lhs`.
  for (unsigned bit : rhs.set_bits()) {
    result.set(lhs_size + bit);
  }
  return result;
}

} // namespace

SmallBitVector IndexingMap::RemoveUnusedSymbols() {
if (IsUndefined()) return {};
bool IndexingMap::CompressVars(const llvm::SmallBitVector& unused_dims,
const llvm::SmallBitVector& unused_symbols) {
MLIRContext* mlir_context = GetMLIRContext();

UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
unsigned num_symbols_before = GetSymbolCount();
affine_map_ = mlir::compressSymbols(affine_map_, unused_vars.unused_symbols);
bool num_dims_changed = unused_dims.count() > 0;
bool num_symbols_changed = unused_symbols.count() > 0;
if (!num_dims_changed && !num_symbols_changed) return false;

unsigned num_symbols_after = affine_map_.getNumSymbols();
if (num_symbols_after == num_symbols_before) return {};
unsigned num_dims_before = GetDimensionCount();
unsigned num_symbols_before = GetSymbolCount();

// Remap symbols in the constraint expressions accordingly.
std::vector<RangeVar> compressed_range_vars;
std::vector<RTVar> compressed_rt_vars;
MLIRContext* mlir_context = GetMLIRContext();
int64_t used_symbols_count = 0;
std::vector<AffineExpr> symbol_replacements(
num_symbols_before, getAffineConstantExpr(0, mlir_context));
auto range_vars_count = range_vars_.size();
for (int i = 0; i < unused_vars.unused_symbols.size(); ++i) {
if (!unused_vars.unused_symbols[i]) {
if (i < range_vars_count) {
compressed_range_vars.push_back(range_vars_[i]);
} else {
compressed_rt_vars.push_back(rt_vars_[i - range_vars_count]);
// Compress DimVars.
SmallVector<AffineExpr, 2> dim_replacements;
if (num_dims_changed) {
affine_map_ = mlir::compressDims(affine_map_, unused_dims);
std::vector<DimVar> compressed_dim_vars;
dim_replacements = SmallVector<AffineExpr, 2>(
num_dims_before, getAffineConstantExpr(0, mlir_context));
int64_t used_dims_count = 0;
for (int i = 0; i < unused_dims.size(); ++i) {
if (!unused_dims[i]) {
compressed_dim_vars.push_back(dim_vars_[i]);
dim_replacements[i] = getAffineDimExpr(used_dims_count++, mlir_context);
}
symbol_replacements[i] =
getAffineSymbolExpr(used_symbols_count++, mlir_context);
}
dim_vars_ = std::move(compressed_dim_vars);
}

// Compress RangeVars and RTVars.
SmallVector<AffineExpr, 2> symbol_replacements;
if (num_symbols_changed) {
affine_map_ = mlir::compressSymbols(affine_map_, unused_symbols);
symbol_replacements = SmallVector<AffineExpr, 2>(
num_symbols_before, getAffineConstantExpr(0, mlir_context));
std::vector<RangeVar> compressed_range_vars;
std::vector<RTVar> compressed_rt_vars;
MLIRContext* mlir_context = GetMLIRContext();
int64_t used_symbols_count = 0;
auto range_vars_count = range_vars_.size();
for (int i = 0; i < unused_symbols.size(); ++i) {
if (!unused_symbols[i]) {
if (i < range_vars_count) {
compressed_range_vars.push_back(range_vars_[i]);
} else {
compressed_rt_vars.push_back(rt_vars_[i - range_vars_count]);
}
symbol_replacements[i] =
getAffineSymbolExpr(used_symbols_count++, mlir_context);
}
}
range_vars_ = std::move(compressed_range_vars);
rt_vars_ = std::move(compressed_rt_vars);
}
range_vars_ = std::move(compressed_range_vars);
rt_vars_ = std::move(compressed_rt_vars);

// Remove constraints.
std::vector<AffineExpr> to_remove;
std::vector<std::pair<AffineExpr, Interval>> to_add;
for (const auto& [expr, range] : constraints_) {
auto updated_expr = expr.replaceSymbols(symbol_replacements);
auto updated_expr =
expr.replaceDimsAndSymbols(dim_replacements, symbol_replacements);
if (updated_expr == expr) continue;
to_add.push_back({updated_expr, range});
to_remove.push_back(expr);
Expand All @@ -1190,50 +1225,47 @@ SmallBitVector IndexingMap::RemoveUnusedSymbols() {
for (const auto& [expr, range] : to_add) {
AddConstraint(expr, range);
}
return std::move(unused_vars.unused_symbols);
return true;
}

SmallBitVector IndexingMap::RemoveUnusedDimensions() {
SmallBitVector IndexingMap::RemoveUnusedSymbols() {
if (IsUndefined()) return {};

UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
unsigned num_dims_before = GetDimensionCount();
affine_map_ = mlir::compressDims(affine_map_, unused_vars.unused_dims);
if (!CompressVars(/*unused_dims=*/{}, unused_vars.unused_symbols)) {
return {};
}
return std::move(unused_vars.unused_symbols);
}

unsigned num_dims_after = affine_map_.getNumDims();
if (num_dims_after == num_dims_before) return {};
SmallBitVector IndexingMap::RemoveUnusedDimensions() {
if (IsUndefined()) return {};

// Remap dimensions in the constraint expressions accordingly.
std::vector<DimVar> compressed_dim_vars;
MLIRContext* mlir_context = GetMLIRContext();
int64_t used_dims_count = 0;
std::vector<AffineExpr> dim_replacements(
num_dims_before, getAffineConstantExpr(0, mlir_context));
for (int i = 0; i < unused_vars.unused_dims.size(); ++i) {
if (!unused_vars.unused_dims[i]) {
compressed_dim_vars.push_back(dim_vars_[i]);
dim_replacements[i] = getAffineDimExpr(used_dims_count++, mlir_context);
}
UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
dim_vars_ = std::move(compressed_dim_vars);
std::vector<AffineExpr> to_remove;
std::vector<std::pair<AffineExpr, Interval>> to_add;
for (const auto& [expr, range] : constraints_) {
auto updated_expr = expr.replaceDims(dim_replacements);
if (updated_expr == expr) continue;
to_add.push_back({updated_expr, range});
to_remove.push_back(expr);
if (!CompressVars(unused_vars.unused_dims, /*unused_symbols=*/{})) {
return {};
}
for (const auto& expr : to_remove) {
return std::move(unused_vars.unused_dims);
}

SmallBitVector IndexingMap::RemoveUnusedVars() {
if (IsUndefined()) return {};

UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
for (const auto& [expr, range] : to_add) {
AddConstraint(expr, range);
if (!CompressVars(unused_vars.unused_dims, unused_vars.unused_symbols)) {
return {};
}
return std::move(unused_vars.unused_dims);
return ConcatenateBitVectors(unused_vars.unused_dims,
unused_vars.unused_symbols);
}

void IndexingMap::MergeModConstraints() {
Expand Down
11 changes: 11 additions & 0 deletions third_party/xla/xla/service/gpu/model/indexing_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,11 @@ class IndexingMap {
// were removed, returns {}.
llvm::SmallBitVector RemoveUnusedSymbols();

// Removes unused dimensions and symbols from the `affine_map_` and
// constraints. Returns a bit vector of all variables [dimensions, symbols]
// (dimension bits first, followed by symbol bits) in which a set bit marks a
// variable that was removed. If the map is undefined or nothing was removed,
// returns {}.
llvm::SmallBitVector RemoveUnusedVars();

// Rescales all symbols that are sufficiently constrained through `s? mod x =
// [N, N]` constraints. Returns true if a rescale took place, otherwise false.
bool RescaleSymbols();
Expand Down Expand Up @@ -363,6 +368,12 @@ class IndexingMap {
// Returns true if a replacement was performed, otherwise false.
bool ReplaceConstantRTVars(IndexingMapProvider indexing_map_provider);

// Removes DimVars, RangeVars, RTVars that correspond to the unused dimensions
// and symbols. A set bit marks a variable to remove. If unused_dims is empty
// (or has no set bits), then dims won't be removed. The same applies to
// unused_symbols. The remaining dimensions and symbols are renumbered in
// `affine_map_` and in the constraint expressions. Returns true, if anything
// was removed.
// NOTE(review): a non-empty bit vector appears to need one bit per existing
// dimension/symbol, since it is used to index the current variables -- confirm
// against callers.
bool CompressVars(const llvm::SmallBitVector& unused_dims,
const llvm::SmallBitVector& unused_symbols);

mlir::AffineMap affine_map_;
std::vector<DimVar> dim_vars_;
std::vector<RangeVar> range_vars_;
Expand Down
58 changes: 56 additions & 2 deletions third_party/xla/xla/service/gpu/model/indexing_map_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ limitations under the License.
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/affine_map_printer.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_test_utils.h"
Expand All @@ -48,6 +46,15 @@ class IndexingMapTest : public HloTestBase {
AffineMapPrinter printer_;
};

// Copies `bit_vector` into a std::vector<bool> of the same length, so that it
// can be checked with gMock container matchers.
std::vector<bool> ConvertToSTL(const llvm::SmallBitVector& bit_vector) {
  std::vector<bool> bits(bit_vector.size(), false);
  // Only the set positions need to be touched; everything else stays false.
  for (unsigned set_bit : bit_vector.set_bits()) {
    bits[set_bit] = true;
  }
  return bits;
}

TEST_F(IndexingMapTest, RTVar) {
auto zero_dim_map = AffineMap::get(&mlir_context_);
std::vector<RTVar> rt_vars{RTVar{Interval{0, 2},
Expand Down Expand Up @@ -222,6 +229,53 @@ TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintUsesOnlyUnusedDim) {
)"));
}

// Checks that RemoveUnusedDimensions() drops every dimension that is used
// neither by the map results nor by a constraint, renumbers the remaining
// dimensions in both places, leaves the symbols untouched, and reports the
// removed dimensions in the returned bit vector.
TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintsWithManyDims) {
  IndexingMap indexing_map = IndexingMap::FromTensorSizes(
      ParseAffineMap("(d0, d1, d2, d3, d4)[s0, s1] -> (s0 * 4 + d1 + d3 - 42)",
                     &mlir_context_),
      {1, 2, 3, 4, 5}, {32, 64});
  indexing_map.AddConstraint(
      ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459});
  auto unused_dims = indexing_map.RemoveUnusedDimensions();
  // dimensions d0, d2, d4 will be removed and d1 and d3 will become d0 and d1.
  // The unused symbol s1 stays, because only dimensions are compressed here.
  EXPECT_THAT(indexing_map, MatchIndexingMap(R"(
(d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42)
domain:
d0 in [0, 1]
d1 in [0, 3]
s0 in [0, 31]
s1 in [0, 63]
d0 + s0 * 4 + d1 in [24, 459]
)"));
  // The returned bit vector has one bit per original dimension; set bits mark
  // the removed ones (d0, d2, d4).
  EXPECT_THAT(ConvertToSTL(unused_dims),
              ::testing::ElementsAreArray({true, false, true, false, true}));
}

// Checks that RemoveUnusedVars() compresses dimensions and symbols in one
// call and returns the removed variables as [dimensions, symbols].
TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) {
  // Only d1, d3 and s0 appear in the result expression; s2 is referenced
  // solely by the second constraint below.
  auto affine_map = ParseAffineMap(
      "(d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42)",
      &mlir_context_);
  IndexingMap map_under_test =
      IndexingMap::FromTensorSizes(affine_map, {1, 2, 3, 4, 5}, {32, 64, 96});

  map_under_test.AddConstraint(
      ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459});
  map_under_test.AddConstraint(ParseAffineExpr("s0 + s2", &mlir_context_),
                               Interval{0, 512});

  const llvm::SmallBitVector removed_vars = map_under_test.RemoveUnusedVars();

  // d0, d2, d4 and s1 are dropped; the survivors are renumbered, so the old
  // s2 shows up as s1 in the expected map.
  EXPECT_THAT(map_under_test, MatchIndexingMap(R"(
(d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42)
domain:
d0 in [0, 1]
d1 in [0, 3]
s0 in [0, 31]
s1 in [0, 95]
d0 + s0 * 4 + d1 in [24, 459]
s0 + s1 in [0, 512]
)"));
  // Five dimension bits followed by three symbol bits.
  EXPECT_THAT(ConvertToSTL(removed_vars),
              ::testing::ElementsAreArray(
                  {true, false, true, false, true, false, true, false}));
}

TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) {
IndexingMap indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_),
Expand Down

0 comments on commit 616b53c

Please sign in to comment.