[ATen][CUDA][AMP] Fix dtype mismatch in linalg_vector_norm (#125175)
Fixes #125174

Pull Request resolved: #125175
Approved by: https://github.com/eqy, https://github.com/lezcano
Aidyn-A authored and pytorchmergebot committed May 1, 2024
1 parent c59cce3 commit 47ba7a7
Showing 2 changed files with 20 additions and 11 deletions.
10 changes: 8 additions & 2 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -2839,10 +2839,16 @@ TORCH_IMPL_FUNC(linalg_vector_norm_out)(const Tensor& self, const Scalar& scalar
   }
 
   if (is_reduce_over_1D_vector) {
+    Tensor self_;
+    if (opt_dtype.has_value()) {
+      self_ = self.to(*opt_dtype);
+    } else {
+      self_ = self;
+    }
     if (ord != 0.0) {
-      keepdim ? at::abs_outf(self, const_cast<Tensor&>(result)) : at::abs_outf(self.squeeze(reduce_dim), const_cast<Tensor&>(result));
+      keepdim ? at::abs_outf(self_, const_cast<Tensor&>(result)) : at::abs_outf(self_.squeeze(reduce_dim), const_cast<Tensor&>(result));
     } else {
-      keepdim ? at::ne_outf(self, 0, const_cast<Tensor&>(result)) : at::ne_outf(self.squeeze(reduce_dim), 0, const_cast<Tensor&>(result));
+      keepdim ? at::ne_outf(self_, 0, const_cast<Tensor&>(result)) : at::ne_outf(self_.squeeze(reduce_dim), 0, const_cast<Tensor&>(result));
     }
     return;
   }
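
A rough Python rendering of the patched fast path may help readers skimming the C++ (the function and argument names here are hypothetical; the authoritative logic is the hunk above):

import torch

def reduce_over_1d_vector(self, ord, reduce_dim, keepdim, result, opt_dtype=None):
    # Hypothetical sketch of the fast path for norms over length-1 dimensions.
    # The fix is the conversion below: honor the requested output dtype before
    # calling the out= kernels, whose dtype checks fail when `result` is
    # float32 (e.g. under autocast) while `self` is still float16.
    self_ = self.to(opt_dtype) if opt_dtype is not None else self
    src = self_ if keepdim else self_.squeeze(reduce_dim)
    if ord != 0.0:
        return torch.abs(src, out=result)  # norm of a length-1 vector is |x|
    return torch.ne(src, 0, out=result)    # ord=0 counts nonzero entries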
21 changes: 12 additions & 9 deletions test/test_linalg.py
@@ -1228,6 +1228,7 @@ def test_vector_norm(self, device, dtype):
         # torch.linalg.norm given a flattened tensor
         ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
         input_sizes = [
+            (1, ),
             (10, ),
             (4, 5),
             (3, 4, 5),
@@ -1281,15 +1282,17 @@ def run_test_case(input, ord, dim, keepdim, norm_dtype):
             else:
                 raise RuntimeError("Unsupported dtype")
 
-        for input_size, ord, keepdim, norm_dtype in product(input_sizes, ord_vector, [True, False], norm_dtypes):
-            input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
-            for dim in [None, random.randint(0, len(input_size) - 1)]:
-                run_test_case(
-                    input,
-                    ord,
-                    dim,
-                    keepdim,
-                    norm_dtype)
+        for amp in [False, True]:
+            with torch.autocast(device_type=device, enabled=amp):
+                for input_size, ord, keepdim, norm_dtype in product(input_sizes, ord_vector, [True, False], norm_dtypes):
+                    input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
+                    for dim in [None, random.randint(0, len(input_size) - 1)]:
+                        run_test_case(
+                            input,
+                            ord,
+                            dim,
+                            keepdim,
+                            norm_dtype)
 
     def test_vector_norm_dim_tuple_arg(self, device):
         test_cases = [
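
The new amp sweep above is what exercises the fixed fast path. A minimal repro sketch of the failure mode it guards against, assuming a CUDA device (before this fix the call raised a dtype-mismatch RuntimeError, since autocast requests a float32 result while leaving the half-precision input untouched):

import torch

x = torch.randn(1, device="cuda", dtype=torch.float16)
with torch.autocast(device_type="cuda"):
    # A 1-element input takes the is_reduce_over_1D_vector fast path.
    out = torch.linalg.vector_norm(x, ord=2)
print(out.dtype)  # torch.float32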
