[XLA:GPU][IndexAnalysis] Add a method to remove unused dims and symbols. #66703

Merged 1 commit on May 2, 2024
3 changes: 0 additions & 3 deletions third_party/xla/xla/service/gpu/model/BUILD
@@ -442,9 +442,6 @@ xla_cc_test(
":indexing_analysis",
":indexing_map",
":indexing_test_utils",
"//xla:literal_util",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
148 changes: 90 additions & 58 deletions third_party/xla/xla/service/gpu/model/indexing_map.cc
@@ -1140,46 +1140,81 @@ UnusedVariables DetectUnusedVariables(const IndexingMap& indexing_map) {
return unused_vars;
}

SmallBitVector ConcatenateBitVectors(const SmallBitVector& lhs,
const SmallBitVector& rhs) {
SmallBitVector concat(lhs.size() + rhs.size(), false);
int id = 0;
for (int i = 0; i < lhs.size(); ++i, ++id) {
concat[id] = lhs[i];
}
for (int i = 0; i < rhs.size(); ++i, ++id) {
concat[id] = rhs[i];
}
return concat;
}

} // namespace

SmallBitVector IndexingMap::RemoveUnusedSymbols() {
if (IsUndefined()) return {};
bool IndexingMap::CompressVars(const llvm::SmallBitVector& unused_dims,
const llvm::SmallBitVector& unused_symbols) {
MLIRContext* mlir_context = GetMLIRContext();

UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
unsigned num_symbols_before = GetSymbolCount();
affine_map_ = mlir::compressSymbols(affine_map_, unused_vars.unused_symbols);
bool num_dims_changed = unused_dims.count() > 0;
bool num_symbols_changed = unused_symbols.count() > 0;
if (!num_dims_changed && !num_symbols_changed) return false;

unsigned num_symbols_after = affine_map_.getNumSymbols();
if (num_symbols_after == num_symbols_before) return {};
unsigned num_dims_before = GetDimensionCount();
unsigned num_symbols_before = GetSymbolCount();

// Remap symbols in the constraint expressions accordingly.
std::vector<RangeVar> compressed_range_vars;
std::vector<RTVar> compressed_rt_vars;
MLIRContext* mlir_context = GetMLIRContext();
int64_t used_symbols_count = 0;
std::vector<AffineExpr> symbol_replacements(
num_symbols_before, getAffineConstantExpr(0, mlir_context));
auto range_vars_count = range_vars_.size();
for (int i = 0; i < unused_vars.unused_symbols.size(); ++i) {
if (!unused_vars.unused_symbols[i]) {
if (i < range_vars_count) {
compressed_range_vars.push_back(range_vars_[i]);
} else {
compressed_rt_vars.push_back(rt_vars_[i - range_vars_count]);
// Compress DimVars.
SmallVector<AffineExpr, 2> dim_replacements;
if (num_dims_changed) {
affine_map_ = mlir::compressDims(affine_map_, unused_dims);
std::vector<DimVar> compressed_dim_vars;
dim_replacements = SmallVector<AffineExpr, 2>(
num_dims_before, getAffineConstantExpr(0, mlir_context));
int64_t used_dims_count = 0;
for (int i = 0; i < unused_dims.size(); ++i) {
if (!unused_dims[i]) {
compressed_dim_vars.push_back(dim_vars_[i]);
dim_replacements[i] = getAffineDimExpr(used_dims_count++, mlir_context);
}
symbol_replacements[i] =
getAffineSymbolExpr(used_symbols_count++, mlir_context);
}
dim_vars_ = std::move(compressed_dim_vars);
}

// Compress RangeVars and RTVars.
SmallVector<AffineExpr, 2> symbol_replacements;
if (num_symbols_changed) {
affine_map_ = mlir::compressSymbols(affine_map_, unused_symbols);
symbol_replacements = SmallVector<AffineExpr, 2>(
num_symbols_before, getAffineConstantExpr(0, mlir_context));
std::vector<RangeVar> compressed_range_vars;
std::vector<RTVar> compressed_rt_vars;
MLIRContext* mlir_context = GetMLIRContext();
int64_t used_symbols_count = 0;
auto range_vars_count = range_vars_.size();
for (int i = 0; i < unused_symbols.size(); ++i) {
if (!unused_symbols[i]) {
if (i < range_vars_count) {
compressed_range_vars.push_back(range_vars_[i]);
} else {
compressed_rt_vars.push_back(rt_vars_[i - range_vars_count]);
}
symbol_replacements[i] =
getAffineSymbolExpr(used_symbols_count++, mlir_context);
}
}
range_vars_ = std::move(compressed_range_vars);
rt_vars_ = std::move(compressed_rt_vars);
}
range_vars_ = std::move(compressed_range_vars);
rt_vars_ = std::move(compressed_rt_vars);

// Remove constraints.
std::vector<AffineExpr> to_remove;
std::vector<std::pair<AffineExpr, Interval>> to_add;
for (const auto& [expr, range] : constraints_) {
auto updated_expr = expr.replaceSymbols(symbol_replacements);
auto updated_expr =
expr.replaceDimsAndSymbols(dim_replacements, symbol_replacements);
if (updated_expr == expr) continue;
to_add.push_back({updated_expr, range});
to_remove.push_back(expr);
@@ -1190,50 +1225,47 @@ SmallBitVector IndexingMap::RemoveUnusedSymbols() {
for (const auto& [expr, range] : to_add) {
AddConstraint(expr, range);
}
return std::move(unused_vars.unused_symbols);
return true;
}

SmallBitVector IndexingMap::RemoveUnusedDimensions() {
SmallBitVector IndexingMap::RemoveUnusedSymbols() {
if (IsUndefined()) return {};

UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
unsigned num_dims_before = GetDimensionCount();
affine_map_ = mlir::compressDims(affine_map_, unused_vars.unused_dims);
if (!CompressVars(/*unused_dims=*/{}, unused_vars.unused_symbols)) {
return {};
}
return std::move(unused_vars.unused_symbols);
}

unsigned num_dims_after = affine_map_.getNumDims();
if (num_dims_after == num_dims_before) return {};
SmallBitVector IndexingMap::RemoveUnusedDimensions() {
if (IsUndefined()) return {};

// Remap dimensions in the constraint expressions accordingly.
std::vector<DimVar> compressed_dim_vars;
MLIRContext* mlir_context = GetMLIRContext();
int64_t used_dims_count = 0;
std::vector<AffineExpr> dim_replacements(
num_dims_before, getAffineConstantExpr(0, mlir_context));
for (int i = 0; i < unused_vars.unused_dims.size(); ++i) {
if (!unused_vars.unused_dims[i]) {
compressed_dim_vars.push_back(dim_vars_[i]);
dim_replacements[i] = getAffineDimExpr(used_dims_count++, mlir_context);
}
UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
dim_vars_ = std::move(compressed_dim_vars);
std::vector<AffineExpr> to_remove;
std::vector<std::pair<AffineExpr, Interval>> to_add;
for (const auto& [expr, range] : constraints_) {
auto updated_expr = expr.replaceDims(dim_replacements);
if (updated_expr == expr) continue;
to_add.push_back({updated_expr, range});
to_remove.push_back(expr);
if (!CompressVars(unused_vars.unused_dims, /*unused_symbols=*/{})) {
return {};
}
for (const auto& expr : to_remove) {
return std::move(unused_vars.unused_dims);
}

SmallBitVector IndexingMap::RemoveUnusedVars() {
if (IsUndefined()) return {};

UnusedVariables unused_vars = DetectUnusedVariables(*this);
for (AffineExpr expr : unused_vars.constraints_with_unused_vars_only) {
constraints_.erase(expr);
}
for (const auto& [expr, range] : to_add) {
AddConstraint(expr, range);
if (!CompressVars(unused_vars.unused_dims, unused_vars.unused_symbols)) {
return {};
}
return std::move(unused_vars.unused_dims);
return ConcatenateBitVectors(unused_vars.unused_dims,
unused_vars.unused_symbols);
}
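For reference, a minimal standalone sketch (not part of this diff) of the renumbering that CompressVars applies: variables flagged as unused are replaced by the constant 0, and the surviving ones are renumbered densely so that constraint expressions can be rewritten with replaceDimsAndSymbols. The helper name MakeDimReplacements and the small driver are illustrative only.

// Illustrative sketch of the dense dimension renumbering used by CompressVars.
#include <cstdint>

#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"

// Kept dimension i becomes a fresh d<j> with consecutive j; dropped
// dimensions map to the constant 0, since they no longer occur in the
// compressed map or its constraints.
llvm::SmallVector<mlir::AffineExpr, 2> MakeDimReplacements(
    const llvm::SmallBitVector& unused_dims, mlir::MLIRContext* ctx) {
  llvm::SmallVector<mlir::AffineExpr, 2> replacements(
      unused_dims.size(), mlir::getAffineConstantExpr(0, ctx));
  int64_t next_id = 0;
  for (int i = 0; i < static_cast<int>(unused_dims.size()); ++i) {
    if (!unused_dims[i]) {
      replacements[i] = mlir::getAffineDimExpr(next_id++, ctx);
    }
  }
  return replacements;
}

int main() {
  mlir::MLIRContext ctx;
  // With d0, d2, d4 unused, the replacements become {0, d0, 0, d1, 0}.
  llvm::SmallBitVector unused_dims(5, false);
  unused_dims.set(0);
  unused_dims.set(2);
  unused_dims.set(4);
  llvm::SmallVector<mlir::AffineExpr, 2> replacements =
      MakeDimReplacements(unused_dims, &ctx);
  (void)replacements;
  return 0;
}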

void IndexingMap::MergeModConstraints() {
11 changes: 11 additions & 0 deletions third_party/xla/xla/service/gpu/model/indexing_map.h
@@ -335,6 +335,11 @@ class IndexingMap {
// were removed, returns {}.
llvm::SmallBitVector RemoveUnusedSymbols();

// Removes unused dimensions and symbols from the `affine_map_` and
// constraints. Returns a bit vector of all variables [dimensions, symbols]
// that were removed. If none of the variables were removed, returns {}.
llvm::SmallBitVector RemoveUnusedVars();

// Rescales all symbols that are sufficiently constrained through `s? mod x =
// [N, N]` constraints. Returns true if a rescale took place, otherwise false.
bool RescaleSymbols();
@@ -363,6 +368,12 @@ class IndexingMap {
// Returns true if a replacement was performed, otherwise false.
bool ReplaceConstantRTVars(IndexingMapProvider indexing_map_provider);

// Removes DimVars, RangeVars, and RTVars that correspond to the unused
// dimensions and symbols. If unused_dims is empty, then dims won't be removed.
// The same applies to unused_symbols. Returns true if anything was removed.
bool CompressVars(const llvm::SmallBitVector& unused_dims,
const llvm::SmallBitVector& unused_symbols);

mlir::AffineMap affine_map_;
std::vector<DimVar> dim_vars_;
std::vector<RangeVar> range_vars_;
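The new public entry point declared above can be exercised as in the sketch below. This is a hedged example, not part of this diff: it assumes the xla::gpu namespace and the ParseAffineMap helper from indexing_test_utils.h (both used by the tests that follow), and the function name ExampleRemoveUnusedVars is illustrative. The returned bit vector concatenates the dimension mask with the symbol mask, so for a map with N dimensions bit i < N refers to d<i>, bit N + j refers to s<j>, and a set bit marks a removed variable.

// Sketch only; mirrors the call pattern of the tests in indexing_map_test.cc.
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/indexing_test_utils.h"

namespace xla::gpu {

llvm::SmallBitVector ExampleRemoveUnusedVars(mlir::MLIRContext* ctx) {
  // d0, d2, and s1 never appear on the right-hand side, so they are unused.
  IndexingMap indexing_map = IndexingMap::FromTensorSizes(
      ParseAffineMap("(d0, d1, d2)[s0, s1] -> (d1 + s0)", ctx),
      {4, 8, 16}, {32, 64});  // Dim sizes, then symbol sizes, as in the tests.
  // Expected mask: {1, 0, 1, 0, 1} -- bits 0..2 cover d0..d2, bits 3..4 cover
  // s0..s1, and set bits mark the variables that were removed.
  return indexing_map.RemoveUnusedVars();
}

}  // namespace xla::gpu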
58 changes: 56 additions & 2 deletions third_party/xla/xla/service/gpu/model/indexing_map_test.cc
@@ -26,8 +26,6 @@ limitations under the License.
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/affine_map_printer.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_test_utils.h"
@@ -48,6 +46,15 @@ class IndexingMapTest : public HloTestBase {
AffineMapPrinter printer_;
};

std::vector<bool> ConvertToSTL(const llvm::SmallBitVector& bit_vector) {
std::vector<bool> result;
result.reserve(bit_vector.size());
for (int i = 0; i < bit_vector.size(); ++i) {
result.push_back(bit_vector[i]);
}
return result;
}

TEST_F(IndexingMapTest, RTVar) {
auto zero_dim_map = AffineMap::get(&mlir_context_);
std::vector<RTVar> rt_vars{RTVar{Interval{0, 2},
@@ -222,6 +229,53 @@ TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintUsesOnlyUnusedDim) {
)"));
}

TEST_F(IndexingMapTest, RemoveUnusedDimensions_ConstraintsWithManyDims) {
IndexingMap indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap("(d0, d1, d2, d3, d4)[s0, s1] -> (s0 * 4 + d1 + d3 - 42)",
&mlir_context_),
{1, 2, 3, 4, 5}, {32, 64});
indexing_map.AddConstraint(
ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459});
indexing_map.RemoveUnusedDimensions();
// dimensions d0, d2, d4 will be removed and d1 and d3 will become d0 and d1.
EXPECT_THAT(indexing_map, MatchIndexingMap(R"(
(d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42)
domain:
d0 in [0, 1]
d1 in [0, 3]
s0 in [0, 31]
s1 in [0, 63]
d0 + s0 * 4 + d1 in [24, 459]
)"));
}

TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) {
IndexingMap indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap(
"(d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42)",
&mlir_context_),
{1, 2, 3, 4, 5}, {32, 64, 96});
indexing_map.AddConstraint(
ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459});
indexing_map.AddConstraint(ParseAffineExpr("s0 + s2", &mlir_context_),
Interval{0, 512});
auto unused_vars = indexing_map.RemoveUnusedVars();
// dimensions d0, d2, d4 and symbol s1 will be removed.
EXPECT_THAT(indexing_map, MatchIndexingMap(R"(
(d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42)
domain:
d0 in [0, 1]
d1 in [0, 3]
s0 in [0, 31]
s1 in [0, 95]
d0 + s0 * 4 + d1 in [24, 459]
s0 + s1 in [0, 512]
)"));
EXPECT_THAT(ConvertToSTL(unused_vars),
::testing::ElementsAreArray(
{true, false, true, false, true, false, true, false}));
}

TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) {
IndexingMap indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_),