Long queries throw "Dtype object" error due to hardcoded max_length in the batch #95

Open
salrowili opened this issue Nov 8, 2023 · 0 comments

I think there is a bug in src/tevatron/driver/jax_train.py, at this line:

https://github.com/texttron/tevatron/blob/0e939457444f78284ab0471da74a0c74bc76a833/src/tevatron/driver/jax_train.py#L147C43-L147C56

The issue is caused by hardcoding max_length to 32 for queries, on the assumption that no query will exceed this length. That breaks when data_args.q_max_len is set above 32. I have a custom dataset in which a few queries reach roughly 128 tokens. It would be great if you could fix this, because the error Python raises is hard to interpret and gives no indication that this line is the cause. It took me two days to trace the problem back to it. I worked around the issue by setting max_length to 128 instead of 32. I think one solution would be to replace 32 with data_args.q_max_len:

    return dict(tokenizer.pad(qq, max_length=data_args.q_max_len, padding='max_length', return_tensors='np')), dict(
        tokenizer.pad(dd, max_length=data_args.p_max_len, padding='max_length', return_tensors='np'))
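
For anyone hitting the same error, here is a minimal sketch of what I believe is happening, written with plain NumPy instead of the real tokenizer. It assumes that padding to a fixed max_length only extends short rows and never truncates long ones, so a query longer than 32 tokens leaves the batch ragged. The helpers pad_without_truncation and to_batch are hypothetical names for illustration, not Tevatron or Transformers functions, and the object-array fallback imitates how older NumPy versions handled ragged input.

```python
import numpy as np


def pad_without_truncation(seqs, max_length, pad_id=0):
    """Hypothetical stand-in for padding='max_length'.

    Rows shorter than max_length are right-padded with pad_id.
    Rows longer than max_length are returned unchanged (assumption:
    padding alone never truncates).
    """
    return [s + [pad_id] * max(0, max_length - len(s)) for s in seqs]


def to_batch(rows):
    """Stack rows into an int32 array when they form a rectangle.

    Ragged rows cannot form a numeric 2-D array, so they end up in a
    1-D array with dtype=object (built explicitly here so the result
    does not depend on the installed NumPy version).
    """
    if len({len(r) for r in rows}) == 1:
        return np.array(rows, dtype=np.int32)
    out = np.empty(len(rows), dtype=object)
    for i, r in enumerate(rows):
        out[i] = r
    return out


# Two fake queries: 10 token ids and 48 token ids.
queries = [list(range(1, 11)), list(range(1, 49))]

# Padding to 32 leaves the 48-token query untouched, so the rows differ in length.
padded_32 = pad_without_truncation(queries, max_length=32)
row_lengths_32 = [len(r) for r in padded_32]
too_small = to_batch(padded_32)

# Padding to a length that covers the longest query gives a rectangular batch.
large_enough = to_batch(pad_without_truncation(queries, max_length=128))

print(row_lengths_32, too_small.dtype)
print(large_enough.shape, large_enough.dtype)
```

An object-dtype array like too_small is not something JAX can convert into a device array, which would be consistent with the "Dtype object" message, although I have not verified the exact code path inside the training script.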

Thank you
Sultan
