When using multiple validation datasets with `transformers.Trainer` and setting `dataloader_persistent_workers=True` in the `transformers.TrainingArguments`, all evaluations are done using the first validation dataset.
In the example above, the model only learns to predict the class 0, so we should have a big loss for the "bad" validation dataset and a small one for the "good" one.
This seems related to #28469 and #29538, whose fix does not support passing a dictionary of evaluation datasets:
```python
# def get_eval_dataloader in src/transformers/trainer.py
if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
    return self.accelerator.prepare(self._eval_dataloader)
```
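The effect of that early return can be shown with a small stand-in for the caching logic (the `CachingEvaluator` class below is a hypothetical illustration, not the actual `Trainer` code): once `_eval_dataloader` has been set, every later call returns the cached loader, regardless of which dataset was requested.

```python
class CachingEvaluator:
    """Toy stand-in for Trainer.get_eval_dataloader's caching behaviour."""

    def __init__(self, persistent_workers):
        self.persistent_workers = persistent_workers

    def get_eval_dataloader(self, dataset):
        # Mirrors the early return: a single cached attribute is reused
        # for every eval dataset when persistent workers are enabled.
        if hasattr(self, "_eval_dataloader") and self.persistent_workers:
            return self._eval_dataloader
        dataloader = list(dataset)  # stand-in for building a DataLoader
        if self.persistent_workers:
            self._eval_dataloader = dataloader
        return dataloader


evaluator = CachingEvaluator(persistent_workers=True)
good = evaluator.get_eval_dataloader(["good-1", "good-2"])
bad = evaluator.get_eval_dataloader(["bad-1", "bad-2"])
# Both "evaluations" end up iterating the first dataset's data.
print(good == bad)  # → True
```

With `persistent_workers=False` the cache is never populated, so each call builds a fresh loader for the dataset it was given, which matches the behaviour reported above.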
The evaluation dataloaders should probably also be stored in a dictionary, or the `_eval_dataloader` attribute should be suffixed with the `eval_dataset_name`.
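The dictionary variant could look roughly like the sketch below, reusing the same hypothetical stand-in class (not the actual `Trainer` code): the cache is keyed on the eval dataset name, so each named dataset keeps its own persistent dataloader.

```python
class KeyedCachingEvaluator:
    """Sketch of the proposed fix: one cached dataloader per eval dataset name."""

    def __init__(self, persistent_workers):
        self.persistent_workers = persistent_workers
        self._eval_dataloaders = {}  # eval_dataset_name -> dataloader

    def get_eval_dataloader(self, name, dataset):
        # Reuse a cached loader only for the *same* named dataset.
        if name in self._eval_dataloaders and self.persistent_workers:
            return self._eval_dataloaders[name]
        dataloader = list(dataset)  # stand-in for building a DataLoader
        if self.persistent_workers:
            self._eval_dataloaders[name] = dataloader
        return dataloader


evaluator = KeyedCachingEvaluator(persistent_workers=True)
good = evaluator.get_eval_dataloader("good", ["good-1", "good-2"])
bad = evaluator.get_eval_dataloader("bad", ["bad-1", "bad-2"])
print(good == bad)  # → False: each dataset now gets its own loader
```

This keeps the intent of #28469 (reusing persistent workers across evaluations) while still evaluating each dataset on its own data.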
I can look into opening a PR for this.
System Info

transformers version: 4.40.1

Who can help?

@muellerzr @pacman100

Information

Tasks

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction

With `dataloader_persistent_workers=True`:

With `dataloader_persistent_workers=False`:

Expected behavior