
fatal error: math.h: No such file or directory #28

Open
snakers4 opened this issue Sep 15, 2021 · 2 comments

Comments

@snakers4

Hi,

I am trying to run Taylor Softmax.

(0)

I ran `python3 setup.py install` and got:

```
root@7c09a3f30c39:/home/keras/notebook/nvme_raid/aveysov/pytorch-loss# python3 setup.py install
running install
running bdist_egg
running egg_info
creating pytorch_loss.egg-info
writing pytorch_loss.egg-info/PKG-INFO
writing dependency_links to pytorch_loss.egg-info/dependency_links.txt
writing top-level names to pytorch_loss.egg-info/top_level.txt
writing manifest file 'pytorch_loss.egg-info/SOURCES.txt'
reading manifest file 'pytorch_loss.egg-info/SOURCES.txt'
writing manifest file 'pytorch_loss.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib.linux-x86_64-3.7
creating build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/swish.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/frelu.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/generalized_iou_loss.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/pc_softmax.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/focal_loss_old.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/focal_loss.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/one_hot.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/soft_dice_loss.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/amsoftmax.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/taylor_softmax.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/triplet_loss.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/__init__.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/label_smooth.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/hswish.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/ema.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/test.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/dice_loss.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/large_margin_softmax.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/lovasz_softmax.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/mish.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/conv_ops.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/ohem_loss.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/affinity_loss.py -> build/lib.linux-x86_64-3.7/pytorch_loss
copying pytorch_loss/dual_focal_loss.py -> build/lib.linux-x86_64-3.7/pytorch_loss
running build_ext
building 'focal_cpp' extension
creating /home/keras/notebook/nvme_raid/aveysov/pytorch-loss/build/temp.linux-x86_64-3.7
creating /home/keras/notebook/nvme_raid/aveysov/pytorch-loss/build/temp.linux-x86_64-3.7/csrc
Emitting ninja build file /home/keras/notebook/nvme_raid/aveysov/pytorch-loss/build/temp.linux-x86_64-3.7/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/1] /usr/local/cuda/bin/nvcc  -I/opt/conda/lib/python3.7/site-packages/torch/include -I/opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.7/site-packages/torch/include/TH -I/opt/conda/lib/python3.7/site-packages/torch/include/THC -I/usr/local/cuda/include -I/opt/conda/include/python3.7m -c -c /home/keras/notebook/nvme_raid/aveysov/pytorch-loss/csrc/focal_kernel.cu -o /home/keras/notebook/nvme_raid/aveysov/pytorch-loss/build/temp.linux-x86_64-3.7/csrc/focal_kernel.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=focal_cpp -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++14
FAILED: /home/keras/notebook/nvme_raid/aveysov/pytorch-loss/build/temp.linux-x86_64-3.7/csrc/focal_kernel.o
/usr/local/cuda/bin/nvcc  -I/opt/conda/lib/python3.7/site-packages/torch/include -I/opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.7/site-packages/torch/include/TH -I/opt/conda/lib/python3.7/site-packages/torch/include/THC -I/usr/local/cuda/include -I/opt/conda/include/python3.7m -c -c /home/keras/notebook/nvme_raid/aveysov/pytorch-loss/csrc/focal_kernel.cu -o /home/keras/notebook/nvme_raid/aveysov/pytorch-loss/build/temp.linux-x86_64-3.7/csrc/focal_kernel.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=focal_cpp -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++14
In file included from /usr/local/cuda/include/crt/math_functions.h:8958:0,
                 from /usr/local/cuda/include/crt/common_functions.h:295,
                 from /usr/local/cuda/include/cuda_runtime.h:115,
                 from <command-line>:0:
/usr/include/c++/7/cmath:45:15: fatal error: math.h: No such file or directory

compilation terminated.
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1672, in _run_ninja_build
    env=env)
  File "/opt/conda/lib/python3.7/subprocess.py", line 512, in run
    output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:
```

I ran the `python3 setup.py install` command in my Dockerized research environment, which is derived from the official PyTorch GPU image:

```dockerfile
ARG BASE_IMAGE=pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel
FROM $BASE_IMAGE
```

I remember that when I faced similar problems in the past, I did something like this to compile some CUDA kernels, but I later removed these lines (it was a while ago!):

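```dockerfile
# Refresh the package lists first; base images often ship without them,
# in which case a bare `apt-get install` fails to find the packages.
RUN apt-get update && \
    apt-get install gcc-5 g++-5 g++-5-multilib gfortran-5 -y && \
    update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 60 --slave /usr/bin/g++ g++ /usr/bin/g++-5 --slave /usr/bin/gfortran gfortran /usr/bin/gfortran-5 && \
    update-alternatives --query gcc
RUN gcc --version
```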

Could you elaborate a bit on this? I am not very familiar with how the C++ ecosystem works.

(1)
As far as I can see, there is a standard autograd implementation and a custom CUDA implementation.
Since I am not very proficient with C++ and CUDA, may I ask what the reasoning was behind adding a custom CUDA kernel? Was the autograd version too slow or too memory-intensive?
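For context on question (1), here is a minimal pure-Python sketch of the Taylor-softmax forward pass, assuming the common definition: replace `exp(x)` in softmax with its Taylor series truncated at an even order `n`. It is not the repo's `taylor_softmax.py`; the function names `taylor_exp` and `taylor_softmax` and the default `order=2` are illustrative assumptions. In a tensor-based autograd version, each step here (powers, sum, normalization) would typically be a separate elementwise op with its own kernel launch, which is the kind of overhead a fused CUDA kernel can reduce.

```python
import math
from typing import List, Sequence


def taylor_exp(x: float, order: int) -> float:
    # Truncated Taylor series of exp(x): sum_{i=0}^{order} x**i / i!
    return sum(x ** i / math.factorial(i) for i in range(order + 1))


def taylor_softmax(row: Sequence[float], order: int = 2) -> List[float]:
    # An even-order truncation of exp(x) is strictly positive for every real x,
    # so normalizing the weights yields a valid probability distribution.
    if order < 0 or order % 2 != 0:
        raise ValueError("order must be a non-negative even integer")
    weights = [taylor_exp(x, order) for x in row]
    total = sum(weights)
    return [w / total for w in weights]


if __name__ == "__main__":
    probs = taylor_softmax([2.0, 0.5, -1.0], order=2)
    print(probs, sum(probs))
```

With `order=2` the weights are `1 + x + x**2 / 2`, so the scores above map to 5, 1.625 and 0.5 before normalization.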

Many thanks for your advice and code!

@CoinCheung
Owner

Hi,

Could you roll back to CUDA 10.2 and try again? I have not tried CUDA 11, so there might be unknown problems.

@CoinCheung
Owner

I write CUDA kernels because they can be a bit faster and more memory-efficient in some cases. For example, when you implement label-smoothing cross entropy with PyTorch, you might need a one-hot tensor, which requires more memory than the plain label tensor (C times larger, where C is the number of classes). And if you want your loss to skip ignore_labels, you would use another operator, which launches another CUDA kernel and reads GPU global memory one more time, adding some performance overhead. These problems can be partially avoided by implementing a custom CUDA kernel, which in theory saves some memory and lets you train your model faster.
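To make the memory argument concrete, here is a minimal pure-Python sketch (no PyTorch, no CUDA) contrasting the two approaches for label-smoothing cross entropy. It assumes the common smoothing variant where the target for class `j` is `eps / C + (1 - eps) * [j == y]`, and it uses `IGNORE_INDEX = -100` as a hypothetical stand-in for the ignore_labels mentioned above; all names are illustrative and this is not the repo's kernel. The unfused version builds an N × C target buffer and filters ignored samples in a second pass, while the fused version reads only the integer label and skips ignored samples inside the same loop, which is, at a high level, what a fused kernel can do.

```python
import math
from typing import List, Sequence

# Hypothetical sentinel for samples the loss should skip (PyTorch's
# CrossEntropyLoss uses -100 as its default ignore_index).
IGNORE_INDEX = -100


def log_softmax(row: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: x_j - logsumexp(x)
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def ls_ce_one_hot(logits: Sequence[Sequence[float]], labels: Sequence[int], eps: float) -> float:
    """Unfused: materialize an (N, C) smoothed one-hot target, then drop
    ignored samples in a separate pass."""
    n, c = len(logits), len(logits[0])
    targets = [[eps / c] * c for _ in range(n)]  # N * C extra floats
    for i, y in enumerate(labels):
        if y != IGNORE_INDEX:
            targets[i][y] += 1.0 - eps
    losses = [
        -sum(t * lp for t, lp in zip(targets[i], log_softmax(logits[i])))
        for i in range(n)
    ]
    kept = [loss for loss, y in zip(losses, labels) if y != IGNORE_INDEX]
    return sum(kept) / len(kept)


def ls_ce_fused(logits: Sequence[Sequence[float]], labels: Sequence[int], eps: float) -> float:
    """Fused: a single pass that reads only the integer label, needs no
    (N, C) target buffer, and skips ignored samples in the same loop."""
    total, count = 0.0, 0
    for row, y in zip(logits, labels):
        if y == IGNORE_INDEX:
            continue
        lp = log_softmax(row)
        # -sum_j q_j * log p_j with q_j = eps / C + (1 - eps) * [j == y]
        total += -(1.0 - eps) * lp[y] - (eps / len(row)) * sum(lp)
        count += 1
    return total / count
```

Both functions return the same loss (up to floating-point rounding), for example on `logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, -1.0, 0.0]]` with `labels = [0, IGNORE_INDEX, 2]`. The difference is that `ls_ce_one_hot` allocates N × C target values, while `ls_ce_fused` keeps only a running sum and a count.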
