Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a -1 non-broadcastable constraint encoding for static shapes #1280

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

brandonwillard
Copy link
Member

@brandonwillard brandonwillard commented Oct 31, 2022

This PR investigates the use of a -1 static shape value encoding to indicate that a dimension's shape value is not equal to one (per #1122).

Integers s are used to encode static shape constraints as follows:

  • s <= -2: the least amount of shape information (currently None)
  • s == -1: the shape is not equal to one
  • -1 < s: the most amount of shape information (i.e. the shape is strictly equal to s)

In its current form, this PR adds a TensorType.shape_encoded attribute and converts the old TensorType.shape into a computed property that simply converts the non-exact constraints (i.e. s < 0) to None.

  • Finish removing old uses of TensorType.broadcastable (especially when used as TensorType constructor arguments)
    The broadcastable argument to the TensorType constructor was removed and strict non-bool type checks were added in order to find all the places where True/False values are being used as shape values. This means that a lot of the tests will fail until those are all fixed.
    Basically, this PR is also serving as an investigation into all the logic that still relies on TensorType.broadcastable.
  • More direct tests of the new constraint
  • Propagate the new constraint in as many cases as possible/reasonable
  • Rebase on top of Replace some uses of TensorType.broadcastable with TensorType.shape #1297

aesara/tensor/type.py Outdated Show resolved Hide resolved
@brandonwillard brandonwillard force-pushed the add-non-one-Type-constraint branch 15 times, most recently from 1b025d6 to 68d4b8e Compare November 4, 2022 21:08
@codecov
Copy link

codecov bot commented Nov 4, 2022

Codecov Report

Merging #1280 (6454f54) into main (9ec71c8) will decrease coverage by 0.01%.
The diff coverage is 82.85%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1280      +/-   ##
==========================================
- Coverage   74.22%   74.21%   -0.02%     
==========================================
  Files         174      174              
  Lines       48731    48765      +34     
  Branches    10367    10381      +14     
==========================================
+ Hits        36170    36189      +19     
- Misses      10273    10285      +12     
- Partials     2288     2291       +3     
Impacted Files Coverage Δ
aesara/link/jax/dispatch/shape.py 78.04% <0.00%> (-8.44%) ⬇️
aesara/tensor/type.py 89.34% <86.48%> (-1.61%) ⬇️
aesara/tensor/shape.py 91.42% <92.30%> (-0.31%) ⬇️
aesara/link/numba/dispatch/basic.py 92.69% <100.00%> (+0.01%) ⬆️

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add simple broadcastability shape constraints to TensorType
2 participants