Suggestion with Solution about loss function #605
This is just an idea. I am a huge fan of the Hydra template, so I decided to make my first contribution. If you think this is a good idea, I will fix the pytest part.
I believe it would be advantageous to separate the loss functions and their weighted addition in `./src/models/mnist_module.py`. The original code uses `self.criterion = torch.nn.CrossEntropyLoss()` as the only loss function, in `loss = self.criterion(logits, y)`. However, I think there is a better way: if we change the loss to `loss_fns: list[torch.nn.Module]` and configure it through a `list[dict]`, users get much more flexibility. For example, in the model step, I do:
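A minimal sketch of how such a model step could combine several weighted losses, assuming each entry of the `list[dict]` config holds a loss module under a `"fn"` key and a float under a `"weight"` key. `WeightedMultiLoss`, `model_step`, and those keys are hypothetical names chosen for illustration; they are not the template's API or the exact code proposed in this issue.

```python
import torch
from torch import nn


class WeightedMultiLoss(nn.Module):
    """Hypothetical helper: weighted sum of several loss modules.

    `loss_cfgs` mimics a list[dict] config; each entry has a loss module
    under "fn" and an optional float under "weight" (defaults to 1.0).
    """

    def __init__(self, loss_cfgs: list[dict]) -> None:
        super().__init__()
        # ModuleList registers the losses so .to(device) also moves their buffers.
        self.loss_fns = nn.ModuleList([cfg["fn"] for cfg in loss_cfgs])
        self.weights = [float(cfg.get("weight", 1.0)) for cfg in loss_cfgs]

    def forward(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Scalar zero on the same device/dtype as the logits.
        total = logits.new_zeros(())
        for weight, fn in zip(self.weights, self.loss_fns):
            total = total + weight * fn(logits, y)
        return total


def model_step(net: nn.Module, criterion: WeightedMultiLoss, batch):
    """Hypothetical stand-in for a LightningModule's model step."""
    x, y = batch
    logits = net(x)
    # The combined criterion replaces the single self.criterion(logits, y).
    loss = criterion(logits, y)
    preds = torch.argmax(logits, dim=1)
    return loss, preds, y
```

With a single entry, `WeightedMultiLoss([{"fn": nn.CrossEntropyLoss(), "weight": 1.0}])` reproduces the original single-criterion behaviour, so the change stays backward compatible while letting users stack extra terms.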
This revised approach retains the functionality you designed while allowing more loss functions to be included. Users simply need to add their custom loss functions to `src/models/components/loss_fn.py`, and the rest is taken care of.
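As an illustration of what a user might add to that file, here is a sketch of a custom loss that follows the same `(logits, y)` call signature as `torch.nn.CrossEntropyLoss`, so it could be listed next to the built-in losses in the `list[dict]` config. Focal loss is only a hypothetical example chosen for this sketch; it is not part of the template or of the original proposal.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FocalLoss(nn.Module):
    """Hypothetical example loss for src/models/components/loss_fn.py.

    Multi-class focal loss: mean of -(1 - p_t)^gamma * log(p_t),
    where p_t is the predicted probability of the true class.
    """

    def __init__(self, gamma: float = 2.0) -> None:
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # logits: (N, C) raw scores; y: (N,) integer class indices.
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, y.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        # Down-weight well-classified examples by (1 - p_t)^gamma.
        return (-((1 - pt) ** self.gamma) * log_pt).mean()
```

With `gamma=0` this reduces to plain cross-entropy, which makes it easy to sanity-check against `nn.CrossEntropyLoss` before mixing it with other weighted terms.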