
Quaternion Networks - Bug Fixes & Improvements #2464

Open · Drew-Wagner wants to merge 20 commits into develop from quaternion-network-improvements
Conversation

Drew-Wagner

What does this PR do?

This PR fixes several bugs that prevented the use of the quaternion network modules, and completes the collection by implementing average and max pooling. No tests existed for the quaternion networks; this PR introduces a minimal (though still incomplete) set of tests.

Several bugs were present in the existing quaternion network modules and are fixed by this PR:

  • Bias was uninitialized for Conv2d and Conv1d when bias=False was set, leading to NaN values in the forward and/or backward pass.
  • QBatchNorm's training implementation did not match its evaluation implementation: the mean was not being subtracted from the input.
  • In QBatchNorm, a view was being used incorrectly to broadcast tensors together.
  • Performance issue with QBatchNorm: torch.rsqrt is several times faster than 1 / torch.sqrt (I measured a 3-8x speedup). A sketch of the corrected normalization step follows this list.
  • Groups were broken for QConv modules: the input_channels dimension of the weights must be divided by the number of groups.
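
A minimal sketch of the corrected QBatchNorm training step, with illustrative names rather than the actual q_normalization.py code (the fix subtracts the batch mean before scaling and uses torch.rsqrt for the inverse standard deviation):

import torch

def qbatchnorm_train_step(x, gamma, beta, eps=1e-5):
    # Per-channel statistics over the batch dimension (shapes are illustrative).
    mean = x.mean(dim=0, keepdim=True)
    var = x.var(dim=0, unbiased=False, keepdim=True)

    # Bug fix: subtract the mean, which was previously omitted in training mode.
    centered = x - mean

    # Performance fix: torch.rsqrt is several times faster than 1 / torch.sqrt.
    inv_std = torch.rsqrt(var + eps)

    return gamma * centered * inv_std + beta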

Several adjustments were made to improve compatibility of the QConv interface with regular convolution modules:

  • The swap option was added for QConv2d
  • A max_norm option was added for QConv and QLinear modules
  • The bias parameter / buffer was renamed from .b to .bias

A QPooling2d module was added (a sketch follows the list below) which implements:

  • component-wise average pooling
  • max pooling by magnitude of the quaternions
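
A rough sketch of how these two pooling modes might work, assuming the channel dimension stores the four quaternion components as consecutive quarters (r, i, j, k); the function names and layout here are assumptions, not the module's actual API:

import torch
import torch.nn.functional as F

def quaternion_avg_pool2d(x, kernel_size):
    # Component-wise average pooling: each of the r, i, j, k channel groups
    # is pooled independently, so a plain avg_pool2d suffices.
    return F.avg_pool2d(x, kernel_size)

def quaternion_max_pool2d(x, kernel_size):
    # Max pooling by quaternion magnitude: within each window, pick the
    # position whose quaternion has the largest norm, then gather all four
    # components from that position.
    r, i, j, k = torch.chunk(x, 4, dim=1)
    magnitude = torch.sqrt(r**2 + i**2 + j**2 + k**2)
    _, indices = F.max_pool2d(magnitude, kernel_size, return_indices=True)

    flat_idx = indices.flatten(2)  # indices into the flattened spatial plane
    gather = lambda t: t.flatten(2).gather(2, flat_idx).view_as(indices)
    return torch.cat([gather(t) for t in (r, i, j, k)], dim=1)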

Breaking Changes:

  • The bias parameter of QConv and QLinear modules was renamed from .b to .bias; however, given the number of bugs present, it seems unlikely that anyone was depending on it.
  • The number of input and output channels for QConv modules must now be divisible by the number of groups.

Before submitting
  • Did you read the contributor guideline?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Does your code adhere to project-specific code style and conventions?

PR review

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified
  • Confirm that the changes adhere to compatibility requirements (e.g., Python version, platform)
  • Review the self-review checklist to ensure the code is ready for review

mravanelli added the bug (Something isn't working) label on Mar 19, 2024
@TParcollet (Collaborator) left a comment:

Thanks! LGTM, see my questions.

@@ -189,6 +195,21 @@ def forward(self, x):
"""
# (batch, channel, time)
x = x.transpose(1, -1)

if self.max_norm is not None:

Collaborator comment on this hunk:
Is renorming the individual quaternion components actually strictly equivalent to renorming the quaternion?

Drew-Wagner (Author) replied:
No, I don't believe it is strictly equivalent. I couldn't find any references for how to approach it, so I went with the simplest idea, which was the component-wise renorm.
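
For context, a minimal sketch of what a component-wise renorm means here (a hypothetical helper, not the code under review): each of the four weight tensors is renormed independently, which bounds each component's norm but not the norm of the full quaternion.

import torch

def renorm_componentwise_(r_w, i_w, j_w, k_w, max_norm):
    # Renorm every component tensor on its own; contrast this with the
    # magnitude-based helper introduced later in this PR.
    for w in (r_w, i_w, j_w, k_w):
        w.data.copy_(torch.renorm(w.data, p=2, dim=0, maxnorm=max_norm))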

speechbrain/nnet/quaternion_networks/q_CNN.py — outdated comments (resolved)
speechbrain/nnet/quaternion_networks/q_normalization.py — outdated comments (resolved)
@mravanelli (Collaborator) commented:

@Drew-Wagner, could you please fix the conflicts, merge the latest develop, and make the last modifications?

Drew-Wagner force-pushed the quaternion-network-improvements branch from a95986e to d7b5b7f on April 14, 2024, 14:55
Comment on lines +858 to +886
def renorm_quaternion_weights_inplace(
r_weight, i_weight, j_weight, k_weight, max_norm
):
"""Renorms the magnitude of the quaternion-valued weights.

Arguments
---------
r_weight : torch.Parameter
i_weight : torch.Parameter
j_weight : torch.Parameter
k_weight : torch.Parameter
max_norm : float
The maximum norm of the magnitude of the quaternion weights
"""
weight_magnitude = torch.sqrt(
r_weight.data**2
+ i_weight.data**2
+ j_weight.data**2
+ k_weight.data**2
)
renormed_weight_magnitude = torch.renorm(
weight_magnitude, p=2, dim=0, maxnorm=max_norm
)
factor = renormed_weight_magnitude / weight_magnitude

r_weight.data *= factor
i_weight.data *= factor
j_weight.data *= factor
k_weight.data *= factor

Drew-Wagner (Author) commented:
@TParcollet Please review this implementation, which renorms the weights according to the magnitude of the quaternions rather than by individual components.
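
A minimal usage sketch of this helper (a standalone check under assumed shapes, not how the PR wires it into the modules):

import torch

out_features, in_features, max_norm = 8, 16, 1.0
r, i, j, k = (torch.nn.Parameter(torch.randn(out_features, in_features))
              for _ in range(4))

renorm_quaternion_weights_inplace(r, i, j, k, max_norm=max_norm)

# After renorming, the per-row L2 norm of the quaternion magnitude is
# bounded by max_norm (rows already below the bound are left untouched).
magnitude = torch.sqrt(r**2 + i**2 + j**2 + k**2)
assert torch.all(magnitude.norm(p=2, dim=1) <= max_norm + 1e-5)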

Drew-Wagner force-pushed the quaternion-network-improvements branch from d7b5b7f to 0697c73 on April 17, 2024, 20:55
Labels: bug (Something isn't working)
Projects: none yet
Linked issues: none yet
3 participants