New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix mem size mismatch from split/chunk in const folding #125199
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125199
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 9e9ccf6 with merge base 2c8237c (): NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@pytorchbot merge |
This PR needs to be approved by an authorized maintainer before merge. |
This pull request was exported from Phabricator. Differential Revision: D56663931 |
Summary: The chunk/split ops on the weights/constants is folded in a fx pass and each output tensor has the same storage size of the original tensor (which is 3x of its actual size if chunk(3)). However Backend calculates the mem size on device from tensor shape/stride/dtype. This causes the mismatch when copying weights/constants to device as allocated mem on device is always smaller than the size of weights/constants and results in a runtime error in loading weight/constant (T172125529). This diff fixes the issue by cloning the tensors after const folding so that the tensors has correct storage size. Test Plan: Before this change: (18432 = 48 * 64 * 2 * 3) ``` RuntimeError: Failed to load constant getitem_idx0 split (remaining=18432) at fbcode/caffe2/torch/fb/acc_runtime/afg/afg_bindings.cpp:3422: Request failed because an invalid parameter ``` ``` buck2 run mode/opt //caffe2/torch/fb/acc_runtime/afg/tests:test_operators-artemis -- -r test_mem_size_mismatch ``` ``` Ran 1 test in 7.048s OK ``` Reviewed By: jfix71 Differential Revision: D56663931
This pull request was exported from Phabricator. Differential Revision: D56663931 |
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally) |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Looking at the unrelated Windows timeout failure on #125199, it looks like we don't have a timeout value set for C++ tests atm. In this case, a C++ test on Windows timed out after 2+ hours. ``` 2024-05-02T23:35:34.0639067Z Running cpp/c10_TypeList_test 1/1 ... [2024-05-02 23:35:34.059021] 2024-05-02T23:35:34.0641108Z Executing ['pytest', 'C:\\actions-runner\\_work\\pytorch\\pytorch\\build\\win_tmp\\build\\torch\\test\\c10_TypeList_test.exe', '-m', 'not serial', '-v', '-vv', '-rfEX', '-n', '2', '--junit-xml-reruns', 'test-reports\\python-pytest\\test\\run_test\\test\\run_test-c898ddeff8f33cbf.xml', '-x', '--reruns=2'] ... [2024-05-02 23:35:34.062137] 2024-05-03T02:45:33.7862004Z Process SpawnPoolWorker-2: 2024-05-03T02:45:33.7927201Z Traceback (most recent call last): 2024-05-03T02:45:33.7928032Z File "C:\Jenkins\Miniconda3\lib\multiprocessing\process.py", line 315, in _bootstrap 2024-05-03T02:45:33.7928722Z self.run() 2024-05-03T02:45:33.7929722Z File "C:\Jenkins\Miniconda3\lib\multiprocessing\process.py", line 108, in run 2024-05-03T02:45:33.7931639Z self._target(*self._args, **self._kwargs) 2024-05-03T02:45:33.7932435Z File "C:\Jenkins\Miniconda3\lib\multiprocessing\pool.py", line 114, in worker 2024-05-03T02:45:33.7933338Z task = get() 2024-05-03T02:45:33.7933946Z File "C:\Jenkins\Miniconda3\lib\multiprocessing\queues.py", line 365, in get 2024-05-03T02:45:33.7935219Z res = self._reader.recv_bytes() 2024-05-03T02:45:33.7935897Z File "C:\Jenkins\Miniconda3\lib\multiprocessing\connection.py", line 221, in recv_bytes 2024-05-03T02:45:33.7936609Z buf = self._recv_bytes(maxlength) 2024-05-03T02:45:33.7937302Z File "C:\Jenkins\Miniconda3\lib\multiprocessing\connection.py", line 310, in _recv_bytes 2024-05-03T02:45:33.7938316Z waitres = _winapi.WaitForMultipleObjects( 2024-05-03T02:45:33.7938766Z KeyboardInterrupt ``` Retrying was working, but it was already too late to finish the job. I'm setting the same default `THRESHOLD * 3` timeout value here for C++ tests. Pull Request resolved: #125517 Approved by: https://github.com/clee2000
Summary:
The chunk/split ops on the weights/constants is folded in a fx pass and each output tensor has the same storage size of the original tensor (which is 3x of its actual size if chunk(3)). However Backend calculates the mem size on device from tensor shape/stride/dtype. This causes the mismatch when copying weights/constants to device as allocated mem on device is always smaller than the size of weights/constants and results in a runtime error in loading weight/constant (T172125529).
This diff fixes the issue by cloning the tensors after const folding so that the tensors has correct storage size.
Test Plan:
Before this change: (18432 = 48 * 64 * 2 * 3)
Reviewed By: jfix71
Differential Revision: D56663931