TabTransformer TypeError caused by MixtureDensityHead input_dim being None #424
Comments
I hacked around a bit and got input_dim working, but it seems MixtureDensityHead is not fully supported yet as an option for TabTransformer? The loss calculation fails, because y_hat is a tuple of pi, sigma and mu rather than a tensor.
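For context on why the loss breaks: a Mixture Density Network head returns the parameters of a Gaussian mixture rather than a point prediction, so a loss written for a single prediction tensor cannot consume it. The sketch below, in plain Python with no PyTorch dependency, computes the negative log-likelihood of a scalar target given a (pi, sigma, mu) tuple. The names mdn_nll and mse and the tuple layout are illustrative assumptions, not PyTorch Tabular's actual loss code.

```python
import math
from typing import Sequence, Tuple

# Assumed layout for illustration: y_hat = (pi, sigma, mu), one entry per Gaussian component.
MixtureParams = Tuple[Sequence[float], Sequence[float], Sequence[float]]


def mdn_nll(y_hat: MixtureParams, y: float) -> float:
    """Negative log-likelihood of a scalar target y under a 1-D Gaussian mixture."""
    pi, sigma, mu = y_hat
    if not (len(pi) == len(sigma) == len(mu)):
        raise ValueError("pi, sigma and mu must have one entry per component")
    # Per component: log(pi_k) + log N(y | mu_k, sigma_k^2)
    log_terms = [
        math.log(p) - math.log(s) - 0.5 * math.log(2 * math.pi) - 0.5 * ((y - m) / s) ** 2
        for p, s, m in zip(pi, sigma, mu)
    ]
    # Log-sum-exp over components for numerical stability
    top = max(log_terms)
    log_likelihood = top + math.log(sum(math.exp(t - top) for t in log_terms))
    return -log_likelihood


def mse(y_hat: float, y: float) -> float:
    """A point-prediction loss: expects one number per target, not a tuple."""
    return (y_hat - y) ** 2


if __name__ == "__main__":
    y_hat = ([0.3, 0.7], [1.0, 0.5], [0.0, 2.0])
    print("mixture NLL:", mdn_nll(y_hat, 1.5))
    try:
        mse(y_hat, 1.5)
    except TypeError as err:
        # Subtracting a float from a tuple fails, mirroring the loss failure in this thread
        print("point loss cannot take the tuple:", err)
```

The point is the shape of the problem: the MDN output needs a likelihood-based loss that unpacks all three parameter sets, which is why it cannot simply replace a regular head.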
TabTransformer does need some special processing to make it work with Mixture Density Networks... but I thought I had done that. Can you post reproducible, self-contained code to replicate the issue (including sample data, etc.)? Maybe something weird is going on. It's weird seeing the ...
Yeah, I tried it quite naively, so there's a high chance the error is on me, but here you go.
Okay, so I finally got some time to check this out. MDN is supposed to be used separately, with its own config, and not just as a head. I'll add some protection against this usage. This is how MDN models are used:
This should work
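The "protection against this usage" mentioned above could take the form of an early check that rejects an MDN-only head in a regular model config, so users get a readable message instead of the later TypeError from nn.Linear. The sketch below is plain Python; validate_head, MDN_ONLY_HEADS and the is_mdn_model flag are hypothetical names for illustration, not PyTorch Tabular's actual implementation.

```python
# Hypothetical set of head names that only make sense inside a dedicated MDN model.
MDN_ONLY_HEADS = {"MixtureDensityHead"}


def validate_head(head: str, is_mdn_model: bool) -> None:
    """Raise early if an MDN-only head is requested from a regular model config.

    Sketch only: the parameter names and message are assumptions.
    """
    if head in MDN_ONLY_HEADS and not is_mdn_model:
        raise ValueError(
            f"{head} cannot be used as a plain head; use the dedicated MDN model "
            "config with this model as its backbone instead."
        )


# A regular head passes silently; MixtureDensityHead on a regular model is rejected.
validate_head("LinearHead", is_mdn_model=False)
try:
    validate_head("MixtureDensityHead", is_mdn_model=False)
except ValueError as err:
    print(err)
```

Failing at config time keeps the error next to the user's mistake rather than deep inside network construction.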
Describe the bug
While experimenting with different settings and simply passing
head="MixtureDensityHead"
to the TabTransformerConfig, I ran into the following error:
lib/python3.9/site-packages/torch/nn/modules/linear.py", line 96, in __init__
    self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
TypeError: empty(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got NoneType"
This happens because the line
self.pi = nn.Linear(self.hparams.input_dim, self.hparams.num_gaussian)
in the _build_network() method of MixtureDensityHead runs while self.hparams.input_dim is None, so nn.Linear receives None as in_features.
To Reproduce
My TabTransformerConfig:
Versions
Python: 3.9
Pytorch-tabular: 1.1.0
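The root cause under "Describe the bug" is that the head is built before its input dimension is known. A minimal sketch of the pattern that avoids it, assuming the head should take its input size from the backbone's output size: fill input_dim before construction and fail loudly if it is still missing. HeadConfig, resolve_input_dim and linear_weight_shape are hypothetical names, and the backbone output size of 32 in the usage line is an arbitrary illustrative value.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class HeadConfig:
    # Hypothetical stand-in for the head's hyperparameters.
    num_gaussian: int
    input_dim: Optional[int] = None


def resolve_input_dim(config: HeadConfig, backbone_output_dim: Optional[int]) -> HeadConfig:
    """Return a copy of config with input_dim taken from the backbone if it was unset."""
    input_dim = config.input_dim if config.input_dim is not None else backbone_output_dim
    if input_dim is None:
        # The situation in this issue: nn.Linear(None, num_gaussian) would raise a TypeError.
        raise ValueError("input_dim is None: the head must know the backbone's output size")
    return replace(config, input_dim=input_dim)


def linear_weight_shape(config: HeadConfig) -> Tuple[int, int]:
    """Weight shape nn.Linear(input_dim, num_gaussian) allocates: (out_features, in_features)."""
    return (config.num_gaussian, config.input_dim)


resolved = resolve_input_dim(HeadConfig(num_gaussian=3), backbone_output_dim=32)
print(linear_weight_shape(resolved))  # (3, 32)
```

The (out_features, in_features) order matches the torch.empty call in the traceback, where in_features at position 2 was the None value.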