ImportError within FasterAutoAugmentSearcher #17

Open
adam-mehdi opened this issue Apr 18, 2021 · 1 comment

@adam-mehdi

Hello,

AutoAlbument's FasterAutoAugmentSearcher fails with the import error `cannot import name 'Batch' from 'torchtext.data'` when searching for a policy. This bug recently appeared in PyTorch Lightning in general, and it is fixed by installing Lightning from GitHub (`pip install git+https://github.com/PyTorchLightning/pytorch-lightning`) instead of from PyPI (`pip install pytorch-lightning`). I suspect that the version of Lightning used by FasterAutoAugmentSearcher needs to be upgraded.
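A quick way to confirm the cause is to check where the installed torchtext exposes `Batch`. torchtext 0.9.0 moved its old data API, including `Batch`, to `torchtext.legacy.data`, so a Lightning release that still runs `from torchtext.data import Batch` (as `apply_func.py` does in the traceback below) fails against it. The sketch below is a hypothetical diagnostic helper (`find_batch_module` is not part of torchtext, Lightning, or AutoAlbument); it only tries the two import paths and reports which one works.

```python
import importlib


def find_batch_module(candidates=("torchtext.data", "torchtext.legacy.data"), attr="Batch"):
    """Return the first module name in `candidates` that defines `attr`, else None."""
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # torchtext is missing, or this submodule does not exist in the installed version
            continue
        if hasattr(module, attr):
            return module_name
    return None


if __name__ == "__main__":
    location = find_batch_module()
    if location == "torchtext.data":
        print("Batch is importable from torchtext.data; the old Lightning import should work.")
    elif location is not None:
        print(f"Batch only lives in {location}; code that imports torchtext.data.Batch will fail.")
    else:
        print("Batch was not found under either path (or torchtext is not installed).")
```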

Here's an example of the problem using the CIFAR100 dataset.

AutoAlbument Search

!pip install -U git+https://github.com/albumentations-team/autoalbument
!autoalbument-create --config-dir /content/ --task classification --num-classes 100
!autoalbument-search --config-dir /content

dataset.py

from torchvision.datasets import CIFAR100


class SearchDataset(CIFAR100):
    # The class name must match data.dataset._target_ (dataset.SearchDataset) in the search config.
    def __init__(self, root="content/cifar100", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        # self.data holds HWC uint8 NumPy images; self.targets holds integer labels.
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            # AutoAlbument passes an Albumentations transform, called with a keyword argument.
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label
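To check the `__getitem__` contract without downloading CIFAR100, the same logic can be exercised on in-memory data. `InMemorySearchDataset` and `flip_rows` below are hypothetical stand-ins: nested lists play the role of the HWC image arrays, and `flip_rows` mimics the Albumentations calling convention (keyword `image` in, dict with an `"image"` key out).

```python
class InMemorySearchDataset:
    """Hypothetical stand-in mirroring the __len__/__getitem__ logic of dataset.py."""

    def __init__(self, data, targets, transform=None):
        self.data = data          # images (here: nested lists instead of NumPy arrays)
        self.targets = targets    # integer labels
        self.transform = transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Albumentations-style call: keyword `image`, result is a dict
            image = self.transform(image=image)["image"]
        return image, label


def flip_rows(image):
    """Hypothetical Albumentations-like transform that reverses the row order."""
    return {"image": image[::-1]}
```

With Albumentations installed, a real transform such as `A.HorizontalFlip(p=1.0)` can be passed the same way, since Albumentations transforms are called as `transform(image=...)` and return a dict.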

Search Output

_version: 2
task: classification
policy_model:
  task_factor: 0.1
  gp_factor: 10
  temperature: 0.05
  num_sub_policies: 40
  num_chunks: 4
  operation_count: 4
  operations:
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_r: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_g: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_b: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.RandomBrightness
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.RandomContrast
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Solarize
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.HorizontalFlip
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.VerticalFlip
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Rotate
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftX
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftY
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Scale
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.CutoutFixedNumberOfHoles
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.CutoutFixedSize
classification_model:
  _target_: autoalbument.faster_autoaugment.models.ClassificationModel
  num_classes: 100
  architecture: resnet18
  pretrained: false
data:
  dataset:
    _target_: dataset.SearchDataset
  input_dtype: uint8
  preprocessing: null
  normalization:
    mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
  dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 16
    shuffle: true
    num_workers: 8
    pin_memory: true
    drop_last: true
searcher:
  _target_: autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher
trainer:
  _target_: pytorch_lightning.Trainer
  gpus: 1
  benchmark: true
  max_epochs: 20
  resume_from_checkpoint: null
optim:
  main:
    _target_: torch.optim.Adam
    lr: 0.001
    betas:
    - 0
    - 0.999
  policy:
    _target_: torch.optim.Adam
    lr: 0.001
    betas:
    - 0
    - 0.999
callbacks:
- _target_: autoalbument.callbacks.MonitorAverageParameterChange
- _target_: autoalbument.callbacks.SavePolicy
- _target_: pytorch_lightning.callbacks.ModelCheckpoint
  save_last: true
  dirpath: checkpoints
logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  save_dir: /content/outputs/2021-04-18/13-18-49/tensorboard_logs
seed: 42

Working directory: /content/outputs/2021-04-18/13-18-49
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/hydra/_internal/utils.py", line 544, in _locate
    import_module(mod)
  File "/usr/lib/python3.7/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1006, in _gcd_import
  File "<frozen importlib._bootstrap>", line 983, in _find_and_load
  File "<frozen importlib._bootstrap>", line 967, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/usr/local/lib/python3.7/dist-packages/autoalbument/faster_autoaugment/search.py", line 2, in <module>
    from pytorch_lightning import seed_everything
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/__init__.py", line 66, in <module>
    from pytorch_lightning import metrics
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/__init__.py", line 14, in <module>
    from pytorch_lightning.metrics.metric import Metric
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/metric.py", line 23, in <module>
    from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/utils.py", line 18, in <module>
    from pytorch_lightning.utilities import rank_zero_warn
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/__init__.py", line 24, in <module>
    from pytorch_lightning.utilities.apply_func import move_data_to_device
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/apply_func.py", line 25, in <module>
    from torchtext.data import Batch
ImportError: cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/hydra/utils.py", line 61, in call
    type_or_callable = _locate(cls)
  File "/usr/local/lib/python3.7/dist-packages/hydra/_internal/utils.py", line 548, in _locate
    ) from e
ImportError: Encountered error: `cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)` when loading module 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/autoalbument/cli/search.py", line 54, in main
    searcher = instantiate(cfg.searcher, cfg=cfg)
  File "/usr/local/lib/python3.7/dist-packages/hydra/utils.py", line 70, in call
    raise HydraException(f"Error calling '{cls}' : {e}") from e
hydra.errors.HydraException: Error calling 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher' : Encountered error: `cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)` when loading module 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
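The traceback shows that the failure happens while Hydra resolves the searcher's `_target_`: Hydra's `_locate` imports `autoalbument.faster_autoaugment.search`, which imports `pytorch_lightning`, and the torchtext `ImportError` propagates up. A simplified preflight check along these lines can surface such errors before a search run is launched. `locate` and `preflight` are hypothetical illustrative helpers, not Hydra's API, and `locate` only handles `module.attribute` paths.

```python
import importlib


def locate(path):
    """Resolve 'package.module.Name' to an object (simplified, not Hydra's own logic)."""
    module_path, _, attr_name = path.rpartition(".")
    if not module_path:
        raise ValueError(f"not a dotted path: {path!r}")
    # Any ImportError raised while the module itself imports its dependencies surfaces here.
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


def preflight(targets):
    """Try to resolve every target; return {path: 'ErrorType: message'} for failures."""
    failures = {}
    for path in targets:
        try:
            locate(path)
        except Exception as exc:  # ImportError, AttributeError, ValueError, ...
            failures[path] = f"{type(exc).__name__}: {exc}"
    return failures


if __name__ == "__main__":
    # _target_ values taken from the search config above
    targets = [
        "autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher",
        "pytorch_lightning.Trainer",
        "torch.optim.Adam",
    ]
    for path, error in preflight(targets).items():
        print(f"{path}\n    {error}")
```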
@aliochat

pip install torchtext==0.8.1 solved the issue for me.
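To tell whether a given torchtext version needs this workaround, a hypothetical helper can compare the version string against 0.9.0, the release that moved `Batch` to `torchtext.legacy.data`. This is a sketch with a deliberately simple parser that assumes dotted numeric versions, optionally followed by a local tag such as `+cu101`; `packaging.version` is a more robust choice where it is available.

```python
import re


def parse_version(version):
    """Turn '0.9.0+cu101' or '0.8.1' into a tuple of ints such as (0, 9, 0)."""
    core = version.split("+", 1)[0]      # drop local build tags like '+cu101'
    parts = []
    for piece in core.split("."):
        match = re.match(r"\d+", piece)  # keep the leading digits of pieces like '0rc1'
        parts.append(int(match.group()) if match else 0)
    return tuple(parts)


def batch_in_torchtext_data(version):
    """True if this torchtext version is older than 0.9.0, i.e. still exposes torchtext.data.Batch."""
    return parse_version(version) < (0, 9)
```

For example, `batch_in_torchtext_data(torchtext.__version__)` is `True` with the 0.8.1 pin above. Pinning torchtext may also change the installed torch version, since torchtext releases are built against a specific torch release.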
