Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A faster implementation of NCC using cumulative summation #558

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from

Conversation

kvttt
Copy link

@kvttt kvttt commented Oct 26, 2023

Runs 14x faster using a safer approach (to prevent overflow) and 24x faster using an unsafe approach (Tensorflow implementation). Also provide the option to use double precision (to prevent overflow).

Tested on Tensorflow (2.11.0) and PyTorch (2.1.0) with both GPU (RTX 3090) and CPU (i7-8700K). Attached are the run time and accuracy on GPU:

----------
Tensorflow
----------
Safe cumsum impl: 0.29789113998413086 s.
Unsafe cumsum impl: 0.17483901977539062 s.
Original impl: 4.185877799987793 s.
Safe cumsum vs original: [5.820766e-10].
Unsafe cumsum vs original: [5.820766e-10].
-------
PyTorch
-------
Safe cumsum impl: 0.10962605476379395 s.
Unsafe cumsum impl: 0.03331756591796875 s.
Original impl: 1.2561259269714355 s.
Safe cumsum vs original: 2.3283064365386963e-10.
Unsafe cumsum vs original: 0.0.

14x faster using a safer approach (to prevent overflow) and 24x faster using an unsafe approach (tested on a single GPU). Also provide the option to use double precision (to prevent overflow).
Updated the PyTorch impl. Tested on CPU and GPU.
@kvttt
Copy link
Author

kvttt commented Oct 26, 2023

Original implementation in voxelmorph (PyTorch):

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total MFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      aten::convolution         0.01%     209.000us         7.69%     107.911ms       2.158ms       0.000us         0.00%        1.220s      24.394ms           0 b           0 b     937.50 Mb           0 b            50            --  
                                     aten::_convolution         0.02%     280.000us         7.68%     107.702ms       2.154ms       0.000us         0.00%        1.220s      24.394ms           0 b           0 b     937.50 Mb           0 b            50            --  
                                aten::cudnn_convolution         5.73%      80.381ms         7.66%     107.422ms       2.148ms        1.220s        98.60%        1.220s      24.394ms           0 b           0 b     937.50 Mb     937.50 Mb            50            --  
void implicit_convolveNd_sgemm<float, 3, 1024, 5, 5,...         0.00%       0.000us         0.00%       0.000us       0.000us        1.220s        98.60%        1.220s      24.394ms           0 b           0 b           0 b           0 b            50            --  
                                           aten::conv3d         0.02%     234.000us         7.71%     108.115ms       2.162ms       0.000us         0.00%        1.195s      23.908ms           0 b           0 b     937.50 Mb      18.75 Mb            50            --  
                                              aten::mul         0.15%       2.064ms         1.33%      18.663ms     109.782us       9.856ms         0.80%      10.069ms      59.229us           0 b           0 b       3.11 Gb       3.11 Gb           170       835.584  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       7.455ms         0.60%       7.455ms      62.125us           0 b           0 b           0 b           0 b           120            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.985ms         0.40%       4.985ms      63.101us           0 b           0 b           0 b           0 b            79            --  
                                              aten::sub         0.04%     516.000us         1.47%      20.609ms     515.225us       2.837ms         0.23%       2.837ms      70.925us           0 b           0 b     750.00 Mb     750.00 Mb            40            --  
                                              aten::add         0.04%     494.000us         0.07%     917.000us      22.366us       2.632ms         0.21%       2.632ms      64.195us           0 b           0 b     750.00 Mb     750.00 Mb            41       196.608  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.401ms         0.19%       2.401ms      48.020us           0 b           0 b           0 b           0 b            50            --  
                                              aten::div         0.04%     573.000us         8.54%     119.890ms       3.996ms       1.681ms         0.14%       1.681ms      56.033us           0 b           0 b     562.50 Mb     562.50 Mb            30            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     954.000us         0.08%     954.000us      47.700us           0 b           0 b           0 b           0 b            20            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     727.000us         0.06%     727.000us      72.700us           0 b           0 b           0 b           0 b            10            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     493.000us         0.04%     493.000us      44.818us           0 b           0 b           0 b           0 b            11            --  
                                             aten::mean         0.03%     404.000us         0.73%      10.232ms       1.023ms     310.000us         0.03%     310.000us      31.000us           0 b           0 b       5.00 Kb       4.50 Kb            10            --  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us     310.000us         0.03%     310.000us      31.000us           0 b           0 b           0 b           0 b            10            --  
                                               aten::to         0.00%      59.000us        70.22%     985.226ms      98.523ms       0.000us         0.00%     134.000us      13.400us           0 b           0 b      30.00 Kb           0 b            10            --  
                                         aten::_to_copy         0.01%     131.000us        70.21%     985.167ms      98.517ms       0.000us         0.00%     134.000us      13.400us           0 b           0 b      30.00 Kb           0 b            10            --  
                                             cudaMalloc         0.40%       5.595ms         0.40%       5.595ms     266.429us     120.000us         0.01%     120.000us       5.714us           0 b           0 b           0 b           0 b            21            --  
                                  cudaStreamIsCapturing         0.00%      28.000us         0.00%      28.000us       0.418us     119.000us         0.01%     119.000us       1.776us           0 b           0 b           0 b           0 b            67            --  
                                    aten::empty_strided         0.01%      79.000us         0.06%     864.000us      86.400us       0.000us         0.00%      97.000us       9.700us           0 b           0 b      30.00 Kb      30.00 Kb            10            --  
                                       cudaLaunchKernel        13.52%     189.643ms        13.52%     189.643ms     526.786us      71.000us         0.01%      71.000us       0.197us           0 b           0 b           0 b           0 b           360            --  
                                            aten::copy_         0.01%     205.000us        70.14%     984.172ms      98.417ms       7.000us         0.00%      37.000us       3.700us           0 b           0 b           0 b           0 b            10            --  
                                        cudaMemcpyAsync         0.01%     114.000us         0.01%     114.000us      11.400us      30.000us         0.00%      30.000us       3.000us           0 b           0 b           0 b           0 b            10            --  
                                              aten::neg         0.01%     157.000us         1.25%      17.503ms       1.750ms      15.000us         0.00%      15.000us       1.500us           0 b           0 b       5.00 Kb       5.00 Kb            10            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      15.000us         0.00%      15.000us       1.500us           0 b           0 b           0 b           0 b            10            --  
                                             aten::add_         0.01%     101.000us         0.01%     141.000us      15.667us       9.000us         0.00%       9.000us       1.000us           0 b           0 b           0 b           0 b             9            --  
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       7.000us         0.00%       7.000us       0.700us           0 b           0 b           0 b           0 b            10            --  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.00%       2.000us       0.182us           0 b           0 b           0 b           0 b            11            --  
                                             aten::ones         0.01%      89.000us         0.01%     205.000us      20.500us       0.000us         0.00%       0.000us       0.000us      28.48 Kb           0 b           0 b           0 b            10            --  
                                            aten::empty         0.01%      74.000us         0.01%      74.000us       7.400us       0.000us         0.00%       0.000us       0.000us      28.48 Kb      28.48 Kb           0 b           0 b            10            --  
                                            aten::fill_         0.00%      42.000us         0.00%      42.000us       4.200us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            10            --  
                                  cudaStreamSynchronize        70.12%     983.853ms        70.12%     983.853ms      98.385ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            10            --  
                                               [memory]         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us     -28.48 Kb     -28.48 Kb      -6.04 Gb      -6.04 Gb           532            --  
                                     cudaGetDeviceCount         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
                                   cudaDriverGetVersion         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
                                 cudaDeviceGetAttribute         0.01%      92.000us         0.01%      92.000us       2.421us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            38            --  
                             cudaGetDeviceProperties_v2         0.01%     106.000us         0.01%     106.000us     106.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
                              cudaStreamCreateWithFlags         0.18%       2.507ms         0.18%       2.507ms     156.688us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            16            --  
                                        cudaMemsetAsync         0.01%     109.000us         0.01%     109.000us       9.909us       0.000us         0.00%       0.000us       0.000us           0 b           0 b         512 b         512 b            11            --  
                                          cudaHostAlloc         0.06%     800.000us         0.06%     800.000us     800.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
                               cudaHostGetDevicePointer         0.00%       2.000us         0.00%       2.000us       2.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
                                               cudaFree         0.88%      12.329ms         0.88%      12.329ms       4.110ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             3            --  
                                   cudaGetSymbolAddress         0.01%     201.000us         0.01%     201.000us     201.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
                                  cudaStreamGetPriority         0.00%       5.000us         0.00%       5.000us       0.100us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            50            --  
                       cudaDeviceGetStreamPriorityRange         0.00%       6.000us         0.00%       6.000us       0.120us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            50            --  
                                       aten::as_strided         0.00%      49.000us         0.00%      49.000us       4.900us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            10            --  
                                  cudaDeviceSynchronize         8.66%     121.545ms         8.66%     121.545ms     121.545ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.403s
Self CUDA time total: 1.237s

Safe cumsum implementation (PyTorch):

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total MFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_         3.29%       1.457ms         7.76%       3.433ms       8.582us      13.832ms        24.77%      13.832ms      34.580us           0 b           0 b           0 b           0 b           400            --  
                                           aten::cumsum         3.68%       1.629ms         8.37%       3.700ms      24.667us      10.860ms        19.45%      13.366ms      89.107us           0 b           0 b       3.06 Gb       2.05 Gb           150            --  
                                             aten::sub_         2.57%       1.138ms         4.74%       2.095ms      13.967us      12.417ms        22.23%      12.417ms      82.780us           0 b           0 b           0 b           0 b           150            --  
                                            aten::clone         1.98%     874.000us        10.28%       4.546ms      22.730us       0.000us         0.00%      10.743ms      53.715us           0 b           0 b       4.20 Gb    -220.00 Mb           200            --  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      10.717ms        19.19%      10.717ms      82.438us           0 b           0 b           0 b           0 b           130            --  
                                              aten::mul         7.45%       3.296ms         9.33%       4.128ms      24.282us       9.856ms        17.65%       9.856ms      57.976us           0 b           0 b       3.11 Gb       3.11 Gb           170       835.584  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       8.816ms        15.79%       8.816ms      58.773us           0 b           0 b           0 b           0 b           150            --  
void at::native::tensor_kernel_scan_outer_dim<float,...         0.00%       0.000us         0.00%       0.000us       0.000us       7.906ms        14.16%       7.906ms      79.060us           0 b           0 b           0 b           0 b           100            --  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       5.016ms         8.98%       5.016ms      50.160us           0 b           0 b           0 b           0 b           100            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.498ms         8.05%       4.498ms      56.225us           0 b           0 b           0 b           0 b            80            --  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.903ms         6.99%       3.903ms      78.060us           0 b           0 b           0 b           0 b            50            --  
                                              aten::pad         0.25%     109.000us         7.41%       3.277ms      65.540us       0.000us         0.00%       3.811ms      76.220us           0 b           0 b       1.07 Gb           0 b            50            --  
                                  aten::constant_pad_nd         1.65%     729.000us         7.16%       3.168ms      63.360us       0.000us         0.00%       3.811ms      76.220us           0 b           0 b       1.07 Gb           0 b            50            --  
void at::native::tensor_kernel_scan_innermost_dim<fl...         0.00%       0.000us         0.00%       0.000us       0.000us       2.954ms         5.29%       2.954ms      59.080us           0 b           0 b           0 b           0 b            50            --  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       2.944ms         5.27%       2.944ms      73.600us           0 b           0 b           0 b           0 b            40            --  
                                              aten::sub         0.91%     404.000us         1.30%     576.000us      14.400us       2.923ms         5.23%       2.923ms      73.075us           0 b           0 b     750.00 Mb     750.00 Mb            40            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.876ms         5.15%       2.876ms      58.694us           0 b           0 b           0 b           0 b            49            --  
                                              aten::add         0.93%     410.000us         1.35%     597.000us      14.561us       2.627ms         4.70%       2.627ms      64.073us           0 b           0 b     750.00 Mb     750.00 Mb            41       196.608  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.414ms         4.32%       2.414ms      48.280us           0 b           0 b           0 b           0 b            50            --  
                                       aten::contiguous         0.19%      85.000us         2.56%       1.131ms      22.620us       0.000us         0.00%       2.406ms      48.120us           0 b           0 b    1000.00 Mb      40.00 Mb            50            --  
                                              aten::div         0.73%     321.000us         1.11%     492.000us      16.400us       1.704ms         3.05%       1.704ms      56.800us           0 b           0 b     562.50 Mb     562.50 Mb            30            --  
                                            aten::fill_         0.64%     285.000us         1.33%     589.000us      11.780us       1.301ms         2.33%       1.301ms      26.020us           0 b           0 b           0 b           0 b            50            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.301ms         2.33%       1.301ms      26.020us           0 b           0 b           0 b           0 b            50            --  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     993.000us         1.78%     993.000us      49.650us           0 b           0 b           0 b           0 b            20            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     711.000us         1.27%     711.000us      71.100us           0 b           0 b           0 b           0 b            10            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     480.000us         0.86%     480.000us      43.636us           0 b           0 b           0 b           0 b            11            --  
                                             aten::mean         0.56%     249.000us         0.82%     361.000us      36.100us     309.000us         0.55%     309.000us      30.900us           0 b           0 b       5.00 Kb       4.50 Kb            10            --  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us     301.000us         0.54%     301.000us      30.100us           0 b           0 b           0 b           0 b            10            --  
                                              aten::neg         0.40%     179.000us         0.52%     228.000us      22.800us      10.000us         0.02%      10.000us       1.000us           0 b           0 b       5.00 Kb       5.00 Kb            10            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         0.02%      10.000us       1.000us           0 b           0 b           0 b           0 b            10            --  
                                             aten::add_         0.17%      74.000us         0.26%     115.000us      12.778us       9.000us         0.02%       9.000us       1.000us           0 b           0 b           0 b           0 b             9            --  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       8.000us         0.01%       8.000us       0.800us           0 b           0 b           0 b           0 b            10            --  
                                       cudaLaunchKernel         8.99%       3.975ms         8.99%       3.975ms       5.230us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           760            --  
                                            aten::empty         1.09%     481.000us         1.09%     481.000us       4.810us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       2.05 Gb       2.05 Gb           100            --  
                                           aten::narrow         0.93%     411.000us         2.06%     911.000us       3.037us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           300            --  
                                            aten::slice        15.34%       6.783ms        15.92%       7.039ms       2.133us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3300            --  
                                       aten::as_strided         0.60%     266.000us         0.60%     266.000us       0.080us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3310            --  
                                    aten::empty_strided         1.70%     752.000us         1.70%     752.000us       5.013us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       3.22 Gb       3.22 Gb           150            --  
                                        cudaMemcpyAsync         3.26%       1.442ms         3.26%       1.442ms       9.613us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           150            --  
                                               [memory]         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     -12.49 Gb     -12.49 Gb           812            --  
                                               aten::to         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           150            --  
                                  cudaStreamIsCapturing         0.00%       1.000us         0.00%       1.000us       1.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
                                             cudaMalloc         0.49%     217.000us         0.49%     217.000us     217.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
                                       aten::empty_like         0.26%     117.000us         0.62%     274.000us       5.480us       0.000us         0.00%       0.000us       0.000us           0 b           0 b    1000.00 Mb     220.00 Mb            50            --  
                                        cudaMemsetAsync         0.12%      54.000us         0.12%      54.000us       5.400us       0.000us         0.00%       0.000us       0.000us           0 b           0 b         512 b         512 b            10            --  
                                  cudaDeviceSynchronize        41.80%      18.486ms        41.80%      18.486ms      18.486ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 44.224ms
Self CUDA time total: 55.848ms

Unsafe cumsum implementation (PyTorch):

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total MFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      24.347ms        42.54%      24.347ms      69.563us           0 b           0 b           0 b           0 b           350            --  
                                              aten::sub         5.92%       3.858ms         9.03%       5.881ms      24.504us      16.260ms        28.41%      16.260ms      67.750us           0 b           0 b       4.39 Gb       4.39 Gb           240            --  
                                              aten::add         5.01%       3.262ms         7.76%       5.055ms      26.466us      13.553ms        23.68%      13.633ms      71.377us           0 b           0 b       3.48 Gb       3.48 Gb           191       933.888  
                                           aten::cumsum         4.30%       2.799ms        13.92%       9.065ms      60.433us      11.330ms        19.80%      11.547ms      76.980us           0 b           0 b       3.22 Gb       3.22 Gb           150            --  
                                              aten::mul         6.95%       4.524ms         8.84%       5.756ms      33.859us       9.812ms        17.15%       9.884ms      58.141us           0 b           0 b       3.11 Gb       3.11 Gb           170       835.584  
void at::native::tensor_kernel_scan_outer_dim<float,...         0.00%       0.000us         0.00%       0.000us       0.000us       8.050ms        14.07%       8.050ms      80.500us           0 b           0 b           0 b           0 b           100            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       7.411ms        12.95%       7.411ms      61.758us           0 b           0 b           0 b           0 b           120            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.996ms         8.73%       4.996ms      63.241us           0 b           0 b           0 b           0 b            79            --  
                                  aten::constant_pad_nd         1.87%       1.216ms        32.69%      21.292ms     425.840us       0.000us         0.00%       3.893ms      77.860us           0 b           0 b       1.07 Gb           0 b            50            --  
                                              aten::pad         0.97%     631.000us        32.99%      21.487ms     429.740us       0.000us         0.00%       3.588ms      71.760us           0 b           0 b       1.07 Gb      88.00 Mb            50            --  
void at::native::tensor_kernel_scan_innermost_dim<fl...         0.00%       0.000us         0.00%       0.000us       0.000us       3.280ms         5.73%       3.280ms      65.600us           0 b           0 b           0 b           0 b            50            --  
                                            aten::copy_         1.13%     735.000us        17.39%      11.325ms     226.500us       2.508ms         4.38%       2.508ms      50.160us           0 b           0 b           0 b           0 b            50            --  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       2.508ms         4.38%       2.508ms      50.160us           0 b           0 b           0 b           0 b            50            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.401ms         4.20%       2.401ms      48.020us           0 b           0 b           0 b           0 b            50            --  
                                              aten::div         0.77%     504.000us         1.14%     742.000us      24.733us       2.125ms         3.71%       2.125ms      70.833us           0 b           0 b     562.50 Mb     562.50 Mb            30            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.415ms         2.47%       1.415ms      70.750us           0 b           0 b           0 b           0 b            20            --  
                                            aten::fill_         0.68%     444.000us         9.46%       6.162ms     123.240us       1.312ms         2.29%       1.385ms      27.700us           0 b           0 b           0 b           0 b            50            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.312ms         2.29%       1.312ms      26.240us           0 b           0 b           0 b           0 b            50            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     710.000us         1.24%     710.000us      71.000us           0 b           0 b           0 b           0 b            10            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     479.000us         0.84%     479.000us      43.545us           0 b           0 b           0 b           0 b            11            --  
                                       cudaLaunchKernel        42.17%      27.466ms        42.17%      27.466ms      30.182us     442.000us         0.77%     442.000us       0.486us           0 b           0 b           0 b           0 b           910            --  
                                             aten::mean         0.65%     424.000us         0.98%     639.000us      63.900us     309.000us         0.54%     309.000us      30.900us           0 b           0 b       5.00 Kb       5.00 Kb            10            --  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us     303.000us         0.53%     303.000us      30.300us           0 b           0 b           0 b           0 b            10            --  
                                              aten::neg         0.30%     196.000us         0.43%     278.000us      27.800us      10.000us         0.02%      10.000us       1.000us           0 b           0 b       5.00 Kb       5.00 Kb            10            --  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         0.02%      10.000us       1.000us           0 b           0 b           0 b           0 b            10            --  
                                             aten::add_         0.18%     119.000us         0.28%     182.000us      20.222us       9.000us         0.02%       9.000us       1.000us           0 b           0 b           0 b           0 b             9            --  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.01%       6.000us       0.600us           0 b           0 b           0 b           0 b            10            --  
                                            aten::empty         0.76%     498.000us         1.30%     844.000us      16.880us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       1.07 Gb       1.07 Gb            50            --  
                                  cudaStreamIsCapturing         0.01%       9.000us         0.01%       9.000us       3.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             3            --  
                                             cudaMalloc         1.47%     960.000us         1.47%     960.000us     320.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             3            --  
                                           aten::narrow         1.21%     791.000us         2.49%       1.619ms       5.397us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           300            --  
                                            aten::slice        15.17%       9.880ms        16.80%      10.939ms       4.756us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          2300            --  
                                       aten::as_strided         1.65%       1.072ms         1.65%       1.072ms       0.464us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          2310            --  
                                               aten::to         0.00%       3.000us         0.00%       3.000us       0.020us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           150            --  
                                               [memory]         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     -15.83 Gb     -15.83 Gb          1012            --  
                                        cudaMemsetAsync         0.18%     115.000us         0.18%     115.000us      11.500us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            10            --  
                                  cudaDeviceSynchronize         8.64%       5.626ms         8.64%       5.626ms       5.626ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1            --  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 65.132ms
Self CUDA time total: 57.228ms

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant