Quaternion Networks - Bug Fixes & Improvements #2464
base: develop
Conversation
Thanks! LGTM, see my questions.
@@ -189,6 +195,21 @@ def forward(self, x):
        """
        # (batch, channel, time)
        x = x.transpose(1, -1)

        if self.max_norm is not None:
Is renorming individual quaternion components actually strictly equivalent to renorming the quaternion?
No, I don't believe it is strictly equivalent. I couldn't find any references for how to approach it, so I went with the simplest idea, which was the component-wise renorm.
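A toy calculation makes the non-equivalence discussed above concrete. The numbers below are illustrative (not from the PR): clipping each component's norm to `max_norm` independently can leave the quaternion's overall magnitude above `max_norm`, while scaling all four components by one shared factor keeps it exactly at the bound.

```python
import math

# Illustrative quaternion (r, i, j, k) = (3, 0, 0, 4); magnitude = 5.
r, i, j, k = 3.0, 0.0, 0.0, 4.0
max_norm = 1.0

# Component-wise renorm: clip each component independently.
cw = [max(min(c, max_norm), -max_norm) for c in (r, i, j, k)]
cw_mag = math.sqrt(sum(c * c for c in cw))  # sqrt(1 + 1) ~= 1.414 > max_norm

# Magnitude renorm: rescale all four components by one shared factor.
mag = math.sqrt(r * r + i * i + j * j + k * k)  # 5.0
scale = min(1.0, max_norm / mag)
mw = [c * scale for c in (r, i, j, k)]
mw_mag = math.sqrt(sum(c * c for c in mw))  # exactly 1.0

print(cw_mag, mw_mag)
```

So the component-wise scheme enforces a per-component bound but not a bound on the quaternion magnitude, which is what the magnitude-based rewrite in this PR addresses.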
@Drew-Wagner, could you please fix the conflicts, merge the latest develop, and make the last modifications?
Force-pushed from a95986e to d7b5b7f
def renorm_quaternion_weights_inplace(
    r_weight, i_weight, j_weight, k_weight, max_norm
):
    """Renorms the magnitude of the quaternion-valued weights.

    Arguments
    ---------
    r_weight : torch.Parameter
    i_weight : torch.Parameter
    j_weight : torch.Parameter
    k_weight : torch.Parameter
    max_norm : float
        The maximum norm of the magnitude of the quaternion weights.
    """
    weight_magnitude = torch.sqrt(
        r_weight.data**2
        + i_weight.data**2
        + j_weight.data**2
        + k_weight.data**2
    )
    renormed_weight_magnitude = torch.renorm(
        weight_magnitude, p=2, dim=0, maxnorm=max_norm
    )
    factor = renormed_weight_magnitude / weight_magnitude

    r_weight.data *= factor
    i_weight.data *= factor
    j_weight.data *= factor
    k_weight.data *= factor
@TParcollet Please review this implementation which renorms the weights according to the magnitude of the quaternions, rather than by individual components
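To illustrate what the helper above guarantees, here is a minimal self-contained sketch of exercising it; the tensor shapes, seed, and `max_norm` value are illustrative choices, not from the PR. After the in-place renorm, the per-row L2 norm of the quaternion magnitudes is bounded by `max_norm`.

```python
import torch

def renorm_quaternion_weights_inplace(r_weight, i_weight, j_weight, k_weight, max_norm):
    # As in the PR: compute the per-quaternion magnitude, renorm it along
    # dim 0, and rescale all four components by the same factor.
    weight_magnitude = torch.sqrt(
        r_weight.data**2 + i_weight.data**2 + j_weight.data**2 + k_weight.data**2
    )
    renormed = torch.renorm(weight_magnitude, p=2, dim=0, maxnorm=max_norm)
    factor = renormed / weight_magnitude
    for w in (r_weight, i_weight, j_weight, k_weight):
        w.data *= factor

torch.manual_seed(0)
r, i, j, k = (torch.nn.Parameter(torch.randn(4, 8)) for _ in range(4))
renorm_quaternion_weights_inplace(r, i, j, k, max_norm=1.0)

# Each row's norm of the quaternion magnitudes is now <= max_norm.
mag = torch.sqrt(r**2 + i**2 + j**2 + k**2)
print(mag.norm(p=2, dim=1))
```

Note that `torch.renorm` only rescales rows whose norm exceeds `maxnorm`, so rows already within the bound are left untouched (their `factor` is 1).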
A view was incorrectly being applied to broadcast tensors together
- adds max_norm option
- max_norm - swap - rename .b -> .bias
- in_channels must be divided by groups when creating kernels - Add checks to ensure divisibility
- the mean was not being subtracted from the input
- rsqrt is 8x faster
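The last commit above swaps a division-by-square-root for PyTorch's fused `torch.rsqrt`. The speedup figure comes from the commit message; the sketch below only demonstrates that the two forms compute the same values, so the swap is safe.

```python
import torch

x = torch.rand(1024) + 0.1     # keep values away from zero

a = torch.rsqrt(x)             # fused reciprocal square root
b = 1.0 / torch.sqrt(x)        # two-op equivalent it replaces

print(torch.allclose(a, b))    # expect True: same values up to float rounding
```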
Force-pushed from d7b5b7f to 0697c73
What does this PR do?
This PR fixes several bugs which prevented the use of the quaternion network modules, and completes the collection by implementing avg and max pooling. No tests existed for quaternion networks. This PR introduces a minimum (and incomplete) set of tests.
Several bugs were present in the existing quaternion network modules and are fixed by this PR:
Several adjustments were made to improve compatibility of the QConv interface with regular convolution modules:
- A `swap` option was added for QConv2d
- A `max_norm` option was added for QConv and QLinear modules
- `.b` was renamed to `.bias`

A QPooling2d module is added which implements avg and max pooling.

Breaking Changes:
- `.b` was renamed to `.bias`; however, given the number of bugs present, it seems unlikely that anyone was depending on this.

Before submitting
PR review
Reviewer checklist