RandomJPEG can't handle input on a device other than CPU #2867

Open
ditwoo opened this issue Apr 3, 2024 · 5 comments · May be fixed by #2883
Labels
bug 🐛 Something isn't working help wanted Extra attention is needed

Comments

@ditwoo

ditwoo commented Apr 3, 2024

Describe the bug

RandomJPEG throws an error when the input tensor is not on the CPU. As a workaround, you have to manually pass a tensor that is already on the right device as the jpeg_quality parameter when creating RandomJPEG.

Reproduction steps

  1. Code that triggers the bug:
import torch
from kornia.augmentation import RandomJPEG

device = "cuda"
jpegq = (1.0, 50.0)
aug = RandomJPEG(jpeg_quality=jpegq, p=1.0)

example_input = torch.randn((3, 224, 224)).to(device)
res = aug(example_input)

Running it raises:

/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "/home/dmdr/Documents/Code/Python/aaa/ptrainer/tmp.py", line 27, in <module>
    res = aug(example_input)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/base.py", line 210, in forward
    output = self.apply_func(in_tensor, params, flags)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/_2d/base.py", line 129, in apply_func
    output = self.transform_inputs(in_tensor, params, flags, trans_matrix)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/base.py", line 261, in transform_inputs
    output = self.apply_transform(in_tensor, params, flags, transform=transform)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/_2d/intensity/jpeg.py", line 56, in apply_transform
    jpeg_output: Tensor = jpeg_codec_differentiable(input, params["jpeg_quality"])
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/utils/image.py", line 231, in _wrapper
    output = f(input, *args, **kwargs)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 484, in jpeg_codec_differentiable
    y_encoded, cb_encoded, cr_encoded = _jpeg_encode(
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 281, in _jpeg_encode
    y_encoded: Tensor = _quantize(
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 177, in _quantize
    quantization_table[:, None] * _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

  2. Workaround: create the jpeg_quality tensor manually and place it on the target device:
import torch
from kornia.augmentation import RandomJPEG

device = "cuda"
jpegq = torch.tensor([1.0, 50.0], device=device)  # quality range created directly on the input's device
aug = RandomJPEG(jpeg_quality=jpegq, p=1.0)

example_input = torch.randn((3, 224, 224)).to(device)
res = aug(example_input)
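The workaround above can be generalized so the caller does not need to know the device in advance. The sketch below is a hypothetical wrapper (`PerDeviceAugmentation` is not a Kornia class) that builds one augmentation per input device through a user-supplied factory and caches it. It assumes, as the workaround suggests, that passing `jpeg_quality` as a tensor already on the input's device is enough to avoid the error.

```python
from typing import Callable, Dict

import torch
from torch import nn


class PerDeviceAugmentation(nn.Module):
    """Hypothetical wrapper: builds (and caches) one augmentation per input device.

    `factory` receives the input's device and must return a module whose
    internal tensors already live on that device.
    """

    def __init__(self, factory: Callable[[torch.device], nn.Module]) -> None:
        super().__init__()
        self.factory = factory
        # Plain dict on purpose: cached modules are not registered as submodules,
        # so they are not touched by .to() or included in state_dict().
        self._cache: Dict[str, nn.Module] = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        key = str(x.device)
        if key not in self._cache:
            # First input seen on this device: build a matching augmentation.
            self._cache[key] = self.factory(x.device)
        return self._cache[key](x)
```

For RandomJPEG the factory could be `lambda d: RandomJPEG(jpeg_quality=torch.tensor([1.0, 50.0], device=d), p=1.0)`, which mirrors the manual workaround. Caching per device avoids rebuilding the augmentation on every call.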

Expected behavior

I expect RandomJPEG to detect the device of the input tensor and perform all operations on that device, as RandomAffine and the other augmentations do.
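One way to get that behaviour is to move every internal tensor to the input's device (and dtype) right before they are combined, which is exactly where the traceback fails (`quantization_table[:, None] * _jpeg_quality_to_scale(jpeg_quality)[...]`). The sketch below assumes the IJG/libjpeg quality-to-scale convention (scale = 5000/q for q < 50, otherwise 200 - 2q) and IJG-style rounding; `quality_to_scale` and `scaled_quantization_tables` are hypothetical names, not Kornia's internal API, and the 8x8 base table is supplied by the caller.

```python
import torch
from torch import Tensor


def quality_to_scale(jpeg_quality: Tensor) -> Tensor:
    # IJG/libjpeg convention: qualities below 50 enlarge the table, above 50 shrink it.
    return torch.where(jpeg_quality < 50, 5000.0 / jpeg_quality, 200.0 - 2.0 * jpeg_quality)


def scaled_quantization_tables(base_table: Tensor, jpeg_quality, image: Tensor) -> Tensor:
    """Return one scaled (8, 8) table per quality value, on image.device and in image.dtype."""
    # Move the table and the quality values (a Python tuple, a CPU tensor, ...)
    # to the image's device and dtype before they are mixed in one expression.
    table = base_table.to(device=image.device, dtype=image.dtype)
    quality = torch.as_tensor(jpeg_quality, device=image.device, dtype=image.dtype).reshape(-1)
    scale = quality_to_scale(quality)
    # (1, 8, 8) * (B, 1, 1) -> (B, 8, 8), then IJG-style rounding and clamping to [1, 255].
    scaled = torch.floor((table[None] * scale[:, None, None] + 50.0) / 100.0)
    return scaled.clamp(min=1.0, max=255.0)


image = torch.rand(2, 3, 224, 224)  # would live on "cuda" in the issue's setting
tables = scaled_quantization_tables(torch.full((8, 8), 16.0), (50.0, 100.0), image)
print(tables.shape, tables.device)  # torch.Size([2, 8, 8]) cpu
```

Because the device is taken from `image` rather than fixed at construction time, the same code path works whether the input is on CPU or CUDA.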

Environment

PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.0
Libc version: glibc-2.35

Python version: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.5.0-26-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080 Ti
Nvidia driver version: 525.147.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      43 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             32
On-line CPU(s) list:                0-31
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen Threadripper 1950X 16-Core Processor
CPU family:                         23
Model:                              1
Thread(s) per core:                 2
Core(s) per socket:                 16
Socket(s):                          1
Stepping:                           1
Frequency boost:                    enabled
CPU max MHz:                        3400,0000
CPU min MHz:                        2200,0000
BogoMIPS:                           6786.44
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid amd_dcm aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1 clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif overflow_recov succor smca sev
Virtualization:                     AMD-V
L1d cache:                          512 KiB (16 instances)
L1i cache:                          1 MiB (16 instances)
L2 cache:                           8 MiB (16 instances)
L3 cache:                           32 MiB (4 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT vulnerable
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] efficientnet_pytorch==0.7.1
[pip3] numpy==1.26.4
[pip3] onnx==1.12.0
[pip3] onnx-tf==1.10.0
[pip3] onnxconverter-common==1.13.0
[pip3] onnxruntime==1.15.1
[pip3] pytorch-lightning==2.2.1
[pip3] pytorch-toolbelt==0.6.3
[pip3] segmentation-models-pytorch==0.3.3
[pip3] torch==2.0.1
[pip3] torchmetrics==1.2.1
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] efficientnet-pytorch      0.7.1                    pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pytorch-lightning         2.2.1                    pypi_0    pypi
[conda] pytorch-toolbelt          0.6.3                    pypi_0    pypi
[conda] segmentation-models-pytorch 0.3.3                    pypi_0    pypi
[conda] torch                     2.0.1                    pypi_0    pypi
[conda] torchmetrics              1.2.1                    pypi_0    pypi
[conda] torchvision               0.15.2                   pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi

Additional context

No response

@ditwoo added the help wanted label Apr 3, 2024
@edgarriba
Member

In the end, the augmentations are nn.Modules. I believe this is the same behaviour you get when you forward a tensor through a regular PyTorch model: the tensor should match the device of the module's parameters, not the other way around. @johnnv1 @shijianjian
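One subtlety with the nn.Module analogy: `Module.to()` only converts parameters and registered buffers. A tensor stored as a plain attribute stays where it was, so moving the module does not help if a value such as the `jpeg_quality` range is kept that way (whether RandomJPEG stores it like this is an assumption here, not confirmed from Kornia's source). The hypothetical `QualityHolder` module below shows the rule with a dtype conversion so it runs on CPU-only machines; `.to("cuda")` follows the same rule for devices.

```python
import torch
from torch import nn


class QualityHolder(nn.Module):
    """Hypothetical module: one tensor registered as a buffer, one as a plain attribute."""

    def __init__(self) -> None:
        super().__init__()
        # Registered buffer: converted by .to(), included in state_dict().
        self.register_buffer("quality_buffer", torch.tensor([1.0, 50.0]))
        # Plain attribute: invisible to .to(), keeps its original device/dtype.
        self.quality_plain = torch.tensor([1.0, 50.0])


m = QualityHolder().to(torch.float64)
print(m.quality_buffer.dtype)  # torch.float64
print(m.quality_plain.dtype)   # torch.float32
```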

@ditwoo
Author

ditwoo commented Apr 4, 2024

> the augmentations in the end are nn.Module 's, this same behaviour i believe you face when you forward a tensor to a regular model in pytorch that the tensor should match with the params device and not the other way around @johnnv1 @shijianjian

Yeah, it should work like an nn.Module. But when I apply the same logic to a RandomJPEG object (create the object, move it to a CUDA device, create an input tensor on the same CUDA device, then call the RandomJPEG object with that tensor), I still get an error. Here is an example:

import torch
from kornia.augmentation import RandomJPEG

device = "cuda"
jpegq = (1.0, 50.0)
aug = RandomJPEG(jpeg_quality=jpegq, p=1.0).to(device)

example_input = torch.randn((3, 224, 224)).to(device)
res = aug(example_input)

And here is the error about mismatched devices:

/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "/home/dmdr/Documents/Code/Python/aaa/ptrainer/tmp.py", line 27, in <module>
    res = aug(example_input)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/base.py", line 210, in forward
    output = self.apply_func(in_tensor, params, flags)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/_2d/base.py", line 129, in apply_func
    output = self.transform_inputs(in_tensor, params, flags, trans_matrix)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/base.py", line 261, in transform_inputs
    output = self.apply_transform(in_tensor, params, flags, transform=transform)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/_2d/intensity/jpeg.py", line 56, in apply_transform
    jpeg_output: Tensor = jpeg_codec_differentiable(input, params["jpeg_quality"])
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/utils/image.py", line 231, in _wrapper
    output = f(input, *args, **kwargs)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 484, in jpeg_codec_differentiable
    y_encoded, cb_encoded, cr_encoded = _jpeg_encode(
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 281, in _jpeg_encode
    y_encoded: Tensor = _quantize(
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 177, in _quantize
    quantization_table[:, None] * _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@edgarriba
Member

@ditwoo thanks! We'll try to fix it, unless you want to give it a shot.

@edgarriba added the bug 🐛 label Apr 7, 2024
@ditwoo
Author

ditwoo commented Apr 8, 2024

@edgarriba I can write a PR with a fix.

@edgarriba
Member

@ditwoo thanks! Much appreciated.

@johnnv1 linked a pull request Apr 14, 2024 that will close this issue

2 participants