[inductor] Avoid bool being upcast to int #109913
base: gh/peterbell10/621/base
Conversation
Stack from ghstack (oldest at bottom):

Currently the inductor code for `x.any(-1)` does this strange dance:

```python
tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask)
tmp1 = tmp0.to(tl.int64)
tmp2 = (tmp1 != 0)
```

This happens because `register_lowering` does type promotion with the dimension argument, and so promotes to `int64`, which we then cast back to bool. A better fix would be to fix `register_lowering`, but for now I just remove the unnecessary type promotion from `aten.any`.

In the current code we also see:

```python
tmp5 = tl.where(rmask & xmask, tmp3, 0)
```

which promotes the boolean value to int, since `0` is an int32 in triton. This fixes it to generate a boolean constant instead.

Finally, there is also a triton bug where the `tl.load` itself upcasts to `tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final kernel code looks like:

```python
tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.full([1, 1], 0, tl.int1)
tmp4 = tl.where(rmask & xmask, tmp1, tmp3)
tmp5 = triton_helpers.any(tmp4, 1)[:, None]
```

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov
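As a reading aid (not part of the PR), here is a minimal repro sketch, assuming a CUDA build of PyTorch with inductor available; running it with `TORCH_LOGS=output_code` dumps the generated Triton kernel, where the bool handling described above can be inspected:

```python
# Sketch: compile a reduction over the last dim of a bool tensor and
# dump the Triton kernel inductor generates for it.
# Run as: TORCH_LOGS=output_code python repro.py
import torch

@torch.compile
def any_last_dim(x):
    return x.any(-1)

x = torch.rand(64, 128, device="cuda") > 0.5  # bool tensor; reduced dim = 128
print(any_last_dim(x))  # shape (64,), dtype torch.bool
```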
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109913
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 6 Unrelated Failures as of commit 93fec8e with merge base 87ea6fb:
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Currently the inductor code for `x.any(-1)` does a this strange dance: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask) tmp1 = tmp0.to(tl.int64) tmp2 = (tmp1 != 0) ``` This happens because `register_lowering` is doing type promotion with the dimension argument, and so promotes to `int64` which we then cast back to bool. A better fix would be to fix `register_lowering` but for now I just remove the unnecessary type promotion from `aten.any`. In the current code we also see: ```python tmp5 = tl.where(rmask & xmask, tmp3, 0) ``` which promotes the boolean value to int since `0` is an int32 in triton. This fixes it to generate a boolean constant instead. Finally there is also a triton bug where the `tl.load` itself upcasts to `tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final kernel code looks like: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1) tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) tmp3 = tl.full([1, 1], 0, tl.int1) tmp4 = tl.where(rmask & xmask, tmp1, tmp3) tmp5 = triton_helpers.any(tmp4, 1)[:, None] ``` ghstack-source-id: 4d6db6f9d7bdd302845dd082d3e171c5c60931a4 Pull Request resolved: #109913
Currently the inductor code for `x.any(-1)` does a this strange dance: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask) tmp1 = tmp0.to(tl.int64) tmp2 = (tmp1 != 0) ``` This happens because `register_lowering` is doing type promotion with the dimension argument, and so promotes to `int64` which we then cast back to bool. A better fix would be to fix `register_lowering` but for now I just remove the unnecessary type promotion from `aten.any`. In the current code we also see: ```python tmp5 = tl.where(rmask & xmask, tmp3, 0) ``` which promotes the boolean value to int since `0` is an int32 in triton. This fixes it to generate a boolean constant instead. Finally there is also a triton bug where the `tl.load` itself upcasts to `tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final kernel code looks like: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1) tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) tmp3 = tl.full([1, 1], 0, tl.int1) tmp4 = tl.where(rmask & xmask, tmp1, tmp3) tmp5 = triton_helpers.any(tmp4, 1)[:, None] ``` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Currently the inductor code for `x.any(-1)` does a this strange dance: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask) tmp1 = tmp0.to(tl.int64) tmp2 = (tmp1 != 0) ``` This happens because `register_lowering` is doing type promotion with the dimension argument, and so promotes to `int64` which we then cast back to bool. A better fix would be to fix `register_lowering` but for now I just remove the unnecessary type promotion from `aten.any`. In the current code we also see: ```python tmp5 = tl.where(rmask & xmask, tmp3, 0) ``` which promotes the boolean value to int since `0` is an int32 in triton. This fixes it to generate a boolean constant instead. Finally there is also a triton bug where the `tl.load` itself upcasts to `tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final kernel code looks like: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1) tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) tmp3 = tl.full([1, 1], 0, tl.int1) tmp4 = tl.where(rmask & xmask, tmp1, tmp3) tmp5 = triton_helpers.any(tmp4, 1)[:, None] ``` ghstack-source-id: 5c1fb44d998fa5b13ff50d9fbdf9a177baf37bce Pull Request resolved: #109913
Currently the inductor code for `x.any(-1)` does a this strange dance: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask) tmp1 = tmp0.to(tl.int64) tmp2 = (tmp1 != 0) ``` This happens because `register_lowering` is doing type promotion with the dimension argument, and so promotes to `int64` which we then cast back to bool. A better fix would be to fix `register_lowering` but for now I just remove the unnecessary type promotion from `aten.any`. In the current code we also see: ```python tmp5 = tl.where(rmask & xmask, tmp3, 0) ``` which promotes the boolean value to int since `0` is an int32 in triton. This fixes it to generate a boolean constant instead. Finally there is also a triton bug where the `tl.load` itself upcasts to `tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final kernel code looks like: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1) tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) tmp3 = tl.full([1, 1], 0, tl.int1) tmp4 = tl.where(rmask & xmask, tmp1, tmp3) tmp5 = triton_helpers.any(tmp4, 1)[:, None] ``` ghstack-source-id: 5c1fb44d998fa5b13ff50d9fbdf9a177baf37bce Pull Request resolved: pytorch#109913
Currently the inductor code for `x.any(-1)` does a this strange dance: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask) tmp1 = tmp0.to(tl.int64) tmp2 = (tmp1 != 0) ``` This happens because `register_lowering` is doing type promotion with the dimension argument, and so promotes to `int64` which we then cast back to bool. A better fix would be to fix `register_lowering` but for now I just remove the unnecessary type promotion from `aten.any`. In the current code we also see: ```python tmp5 = tl.where(rmask & xmask, tmp3, 0) ``` which promotes the boolean value to int since `0` is an int32 in triton. This fixes it to generate a boolean constant instead. Finally there is also a triton bug where the `tl.load` itself upcasts to `tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final kernel code looks like: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1) tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) tmp3 = tl.full([1, 1], 0, tl.int1) tmp4 = tl.where(rmask & xmask, tmp1, tmp3) tmp5 = triton_helpers.any(tmp4, 1)[:, None] ``` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Currently the inductor code for `x.any(-1)` does a this strange dance: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask) tmp1 = tmp0.to(tl.int64) tmp2 = (tmp1 != 0) ``` This happens because `register_lowering` is doing type promotion with the dimension argument, and so promotes to `int64` which we then cast back to bool. A better fix would be to fix `register_lowering` but for now I just remove the unnecessary type promotion from `aten.any`. In the current code we also see: ```python tmp5 = tl.where(rmask & xmask, tmp3, 0) ``` which promotes the boolean value to int since `0` is an int32 in triton. This fixes it to generate a boolean constant instead. Finally there is also a triton bug where the `tl.load` itself upcasts to `tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final kernel code looks like: ```python tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1) tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) tmp3 = tl.full([1, 1], 0, tl.int1) tmp4 = tl.where(rmask & xmask, tmp1, tmp3) tmp5 = triton_helpers.any(tmp4, 1)[:, None] ``` ghstack-source-id: fc2536d67024ce4b57c23f0b309e793225aff0cd Pull Request resolved: #109913
torch/_inductor/codegen/triton.py
Outdated
```python
ttype = triton_compute_type(src_dtype)
other = self.cse.generate(
    self.compute,
    f"tl.full({[1] * self.triton_tensor_ndim()}, {default}, {ttype})",
```
`tl.where` does not broadcast on the number of dims?
This generates the same shape as `ops.constant`, so there's a chance it gets CSE'd and cleans up the code a bit. No performance or correctness issues.
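For context, a sketch of what the generated constant looks like (the concrete values here are illustrative assumptions, not taken from a specific trace):

```python
# With triton_tensor_ndim() == 2, default == 0, and ttype == "tl.int1",
# the f-string above renders to:
#
#     tl.full([1, 1], 0, tl.int1)
#
# ops.constant emits a constant of the same [1, 1] shape, so common
# subexpression elimination can fold the two into one definition, and
# tl.where broadcasts the [1, 1] constant against the [XBLOCK, RBLOCK]
# operand.
```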
@pytorchbot merge
Merge failed. Reason: This PR needs a required label. If not applicable, please add the appropriate exemption label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label ...`. For more information, see the merge rules. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
❌ 🤖 pytorchbot command failed. Try `@pytorchbot --help` for more info.
@pytorchbot revert -m "causing performance regression in relevant metrics, @malfet I believe you are the correct person to help identify and fix the issues. More details: check internal OPS count for ads metrics in the internal related diff" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
@peterbell10 your PR has been successfully reverted.
This reverts commit 9299869. Reverted #109913 on behalf of https://github.com/jeanschmidt due to causing performance regression in relevant metrics, @malfet I believe you are the correct person to help identify and fix the issues. More details: check internal OPS count for ads metrics in the internal related diff ([comment](#109913 (comment)))
Please note that the revert category should be …
@malfet are you able to share any details of the test failure?
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as `Stale`.
@malfet ping
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as `Stale`.