
Compute dimension sums in Elemwise.grad at run-time #1260

Draft · wants to merge 4 commits into base: main
Conversation

@brandonwillard (Member) commented Oct 16, 2022

This PR provides initial fixes for #1089 by supporting ambiguous broadcasting cases. In its current state, it is an investigation into the requirements and unforeseen issues involved in supporting those cases.

The approach used here moves the compile/construction-time conditional sum logic in Elemwise.grad into the graph in order to handle dimensions lacking complete broadcast information (see here). In this context, having complete broadcast information means that we know definitively whether or not a dimension's shape is 1 (i.e. var.type.shape[d] is not None, so we can tell whether var.type.shape[d] == 1 for a dimension d).
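Concretely, the guarded sum can be pushed into the graph with an ifelse on the run-time shape. The following is a minimal sketch of the idea, not this PR's actual implementation; the helper name is hypothetical, and it assumes Aesara's ifelse and unbroadcast:

```python
import aesara.tensor as at
from aesara.ifelse import ifelse


def sum_if_broadcasted(gx, x, axis):
    # Hypothetical helper: sum the gradient `gx` along `axis` only when
    # `x` turns out to have length 1 there at run-time (i.e. the
    # dimension was actually broadcast in the forward computation).
    summed = at.unbroadcast(at.sum(gx, axis=axis, keepdims=True), axis)
    return ifelse(at.eq(x.shape[axis], 1), summed, gx)
```

The unbroadcast is there so that both branches have matching types, which IfElse requires; this type-matching question is itself one of the things this PR needs to work out.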

Since we sometimes want to interpret var.type.shape[d] == None as var.type.shape[d] != 1, we need to address the old logic in Elemwise.grad for Variables whose TensorTypes cannot be used to determine, at gradient-graph construction time, which axes need to be summed. By moving the axis-summing conditions into the graph, we can always perform the correct computations, even when the requisite broadcasting information is only available at run-time.

More importantly, the extra logic can be removed whenever the requisite information becomes available at compile-time. By using existing Ops, we can leverage existing rewrites and our shape inference to automatically remove the extra logic. How we choose to represent that logic (i.e. which Ops to use) will be important, so expect this PR to iterate on that.
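For illustration, when static shape information is present, the run-time check becomes a constant and should be removable by existing rewrites; a small sketch, assuming the static-shape TensorType API:

```python
import aesara.tensor as at

x1 = at.TensorType("float64", shape=(1, None))("x1")     # dim 0 known to be 1
x2 = at.TensorType("float64", shape=(None, None))("x2")  # dim 0 unknown

# For `x1`, a condition like `at.eq(x1.shape[0], 1)` is statically
# decidable, so constant folding and shape inference can strip the
# conditional sum entirely; for `x2`, the check must remain in the
# graph until run-time.
```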

One important point that needs to be addressed—for at least an ifelse approach—is the forking of conditional branches introduced by nested conditionals and iterated gradients. There's ultimately no avoiding this issue when the requisite shape information isn't available, but there are definitely better ways to handle it.
Regardless, we can always make different default assumptions (e.g. that var.type.shape != 1 in certain cases) and/or represent constraints that provide complete broadcast information and, as a result, produce graphs without the extra logic in the vast majority of use cases. #1122 and #1170 cover this topic in different ways. (N.B. This topic is effectively independent of the new support being added in this PR, although the issue this PR addresses could be, and largely is, mitigated by such changes.)

  • It looks like second-order gradients generate graphs that make use of DimShuffle.input_broadcastable and DimShuffle's broadcastable-dimension-dropping feature, which currently requires complete broadcasting information. DimShuffle will probably need to be changed so that it can at least attempt to drop dimensions that aren't known to be broadcastable.
    This issue has to do with gradients of IfElse when one branch is broadcastable and the other isn't; more specifically, it arises when the IfElse gradients are evaluated at a point that conforms to, or is intended for, only one branch and its distinct shape. As a work-around, specify_shape is being used to make sure that the gradients in each branch match their original shapes (see the sketch after this list); something significantly better seems possible, but we need to investigate that separately.
  • Update the static shape information provided during tests (e.g. set shape=(..., 1, ...) when broadcastable input values are used). This will allow us to produce the original graphs expected by the tests.
  • Confirm that new gradient graphs with adequate broadcast information are reduced to equivalent graphs without the extra logic.
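As referenced in the first item above, here is a minimal sketch of the specify_shape work-around; the variable names are illustrative and this is not the exact code used in the PR:

```python
import aesara
import aesara.tensor as at

# A branch input with partially known static shape:
b = at.TensorType("float64", shape=(1, None))("b")

cost = (2.0 * b).sum()
g_b = aesara.grad(cost, b)

# Re-assert `b`'s static shape on its gradient so that downstream
# graphs (e.g. the other IfElse branch) see the branch's original
# shape information:
g_b = at.specify_shape(g_b, b.type.shape)
```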

@brandonwillard brandonwillard marked this pull request as draft October 16, 2022 01:48
@brandonwillard brandonwillard changed the title from "Fix deprecated TensorType.broadcastable usage in Elemwise.grad and RandomVariable.make_node" to "Fix deprecated TensorType.broadcastable usage in Elemwise.grad" Oct 16, 2022
@brandonwillard brandonwillard added the "Op implementation" and "bug" labels Oct 16, 2022
@ricardoV94 and @brandonwillard each left a comment that was marked as off-topic.

@brandonwillard brandonwillard changed the title from "Fix deprecated TensorType.broadcastable usage in Elemwise.grad" to "Compute dimension sums at run-time in Elemwise.grad results" Nov 4, 2022
@brandonwillard brandonwillard changed the title from "Compute dimension sums at run-time in Elemwise.grad results" to "Compute dimension sums in Elemwise.grad at run-time" Jan 30, 2023
Labels: bug, important, Op implementation, shape inference
Projects: none yet
Linked issues: none yet
2 participants