MPerClassSampler: 'numpy.int32' object is not iterable #611

Closed
rheum opened this issue Apr 17, 2023 · 4 comments
Labels
question A general question about the library

Comments

@rheum
Contributor

rheum commented Apr 17, 2023

Hey all. I believe I have found a bug with MPerClassSampler. Using a dataloader which makes use of that sampler throws an exception. As far as I can tell, my HTR code meets all documented preconditions of MPerClassSampler.

My setup:
pytorch_metric_learning.__version__ == "2.1.0"
torch.__version__ == '1.13.1+cu116'

How to reproduce:

import torch
from torch.utils.data import TensorDataset, DataLoader
from pytorch_metric_learning.samplers import MPerClassSampler

ds = TensorDataset(torch.Tensor([[1], [2], [3]]), torch.Tensor([0, 0, 1]))
batch_sampler = MPerClassSampler(labels=[0, 0, 1], m=1, batch_size=2)
train_loader = DataLoader(ds, num_workers=0, batch_sampler=batch_sampler)
next(iter(train_loader))
Traceback (most recent call last):
  File "...\AppData\Roaming\Python\Python310\site-packages\IPython\core\interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-11-0fd2476fd2ff>", line 1, in <module>
    next(iter(train_loader))
  File "...\AppData\Roaming\Python\Python310\site-packages\torch\utils\data\dataloader.py", line 628, in __next__
    data = self._next_data()
  File "...\AppData\Roaming\Python\Python310\site-packages\torch\utils\data\dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "...\AppData\Roaming\Python\Python310\site-packages\torch\utils\data\_utils\fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
TypeError: 'numpy.int32' object is not iterable
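
The last frame of the traceback iterates over `possibly_batched_index`, which suggests that each item yielded by the object passed as `batch_sampler` is a single index rather than a list of indices. Below is a minimal sketch of that failure mode. The `fetch` helper and the toy `dataset` list are hypothetical stand-ins written for illustration; they are not PyTorch's actual code.

```python
import numpy as np

# Toy stand-in for a map-style dataset: (feature, label) pairs.
dataset = [([1.0], 0), ([2.0], 0), ([3.0], 1)]


def fetch(possibly_batched_index):
    # Hypothetical helper mirroring the list comprehension in the traceback.
    return [dataset[idx] for idx in possibly_batched_index]


# A batch sampler yields a list of indices per step, so iteration works.
batch = fetch([0, 2])

# A plain sampler yields one index per step; iterating over it fails.
try:
    fetch(np.int32(0))
except TypeError as err:
    message = str(err)
```

Under this reading, the exception comes from how the sampler is plugged into the DataLoader, not from the dataset itself.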
@KevinMusgrave
Owner

KevinMusgrave commented Apr 17, 2023

Can you try changing batch_sampler=batch_sampler to sampler=batch_sampler?

@rheum
Contributor Author

rheum commented Apr 17, 2023

Thanks for the quick response!
While there's no longer an exception, the batch_size doesn't seem to work:

image

@KevinMusgrave
Owner

You have to pass batch_size=2 to the DataLoader as well.

The reason is that MPerClassSampler is a sampler, not a batch_sampler, so the batch size is actually determined by the DataLoader.

The batch_size argument for MPerClassSampler is just for internal checks.

It would make more sense for it to be a batch_sampler, then you wouldn't have to specify batch_size in 2 places. I think this has come up a few times, so I've made a separate issue for fixing this: #612

@rheum
Contributor Author

rheum commented Apr 18, 2023

Thank you. Turns out I was distracted by the explicit mention of batch_sampler in the docs for HierarchicalSampler, whereas for MPerClassSampler there is no information on how it is used. Maybe as a quick-fix, this could be clarified in the doc. I'll create a PR later today.

@rheum rheum closed this as completed Apr 18, 2023
@KevinMusgrave KevinMusgrave added the question A general question about the library label Apr 18, 2023