einops-like packing notation #180

Open
MilesCranmer opened this issue Feb 20, 2024 · 4 comments
Labels
feature New feature

Comments

MilesCranmer commented Feb 20, 2024

Hey @patrick-kidger,

I'm wondering how hard it would be to have einops-like notation for packed axes? For example,

Float[Array, "B C (H W)"]

would indicate that the last axis is a flattened version of the height and width axes.

This means that if you have the full signature as:

def unpack(x: Float[Array, "B C (H W)"]) -> Float[Array, "B C H W"]:
    ... # magic unpacking
    return y

then jaxtyping would check that y.shape[2] * y.shape[3] == x.shape[2].

Note that in many cases it would not be able to confirm H and W individually. I think that is okay; they are just free variables. But when it can determine the individual sizes, it can check those too.
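To make the proposed check concrete, here is a minimal sketch in plain Python that operates on shape tuples only. It is not jaxtyping code; `is_valid_unpacking` and its `n_packed` parameter are hypothetical names chosen for illustration.

```python
from math import prod


def is_valid_unpacking(packed_shape, unpacked_shape, n_packed=2):
    """Return True if `unpacked_shape` is a valid unpacking of `packed_shape`.

    The last axis of `packed_shape` must equal the product of the last
    `n_packed` axes of `unpacked_shape`, and all leading axes must match.
    """
    lead = len(packed_shape) - 1  # number of leading (unpacked) axes
    if len(unpacked_shape) != lead + n_packed:
        return False
    if tuple(packed_shape[:lead]) != tuple(unpacked_shape[:lead]):
        return False
    return prod(unpacked_shape[lead:]) == packed_shape[-1]
```

With x of shape (2, 3, 20), both (2, 3, 4, 5) and (2, 3, 2, 10) are accepted, which is the "free variables" point above: only the product H * W is pinned down, not H and W individually.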

What do you think? Does this make sense?

@patrick-kidger
Owner

patrick-kidger commented Feb 20, 2024

What you can do today

So it's a little less elegant, but you can do this today via

def unpack(x: Float[Array, "B C H_W"]) -> Float[Array, "B C H (H_W//H)"]:
    ... # magic unpacking
    return y

where the ( ) brackets are optional and only there for readability. And this is already runtime checkable!

The reverse direction is a bit neater, as the logic (which should typically go on the RHS) is easier to read:

def pack(x: Float[Array, "B C H W"]) -> Float[Array, "B C H*W"]:
    ... # magic packing
    return y

Note that this works because in each case the output shape is a function of the input shape -- rather than the other way around!
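The reason both signatures work can be sketched with a toy single-pass checker in plain Python. This is an illustration of the idea only, not jaxtyping's implementation: `check` is a hypothetical helper, it only understands space-separated tokens (so an expression must not contain spaces), and it evaluates expressions with `eval` over the names seen so far.

```python
import re

_NAME = re.compile(r"[A-Za-z_]\w*")


def check(spec, shape, env):
    """Check `shape` against a space-separated `spec`, binding names into `env`.

    A token that is a plain, not-yet-seen name is bound to that axis size.
    Any other token is evaluated as arithmetic over names already in `env`
    and compared against the axis size.
    """
    tokens = spec.split()
    if len(tokens) != len(shape):
        raise ValueError(f"expected {len(tokens)} axes, got {len(shape)}")
    for token, size in zip(tokens, shape):
        if _NAME.fullmatch(token) and token not in env:
            env[token] = size  # first sighting of this name: bind it
            continue
        try:
            expected = eval(token, {"__builtins__": {}}, env)
        except NameError as err:
            raise ValueError(f"{token!r} uses a dimension not seen yet") from err
        if expected != size:
            raise ValueError(f"{token!r} evaluates to {expected}, got {size}")


# unpack: arguments first, then the return value.
unpack_env = {}
check("B C H_W", (2, 3, 20), unpack_env)           # binds B, C, H_W
check("B C H (H_W//H)", (2, 3, 4, 5), unpack_env)  # binds H, checks 20 // 4 == 5

# pack: the output expression only uses names bound by the input.
pack_env = {}
check("B C H W", (2, 3, 4, 5), pack_env)
check("B C H*W", (2, 3, 20), pack_env)
```

One caveat that shows up in this sketch: because `//` is floor division, an output of shape (2, 3, 3, 6) also passes against "B C H (H_W//H)" with H_W = 20, even though 3 * 6 != 20. Whether that matters depends on how strict you want the annotation to be.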

What we could do tomorrow

First of all, just on the syntax: I think if we were to support something like this, then I'd probably suggest using the syntax H*W rather than (H W). This is because it's the syntax we already have!

And in fact, we can actually already write this:

def unpack(x: Float[Array, "B C H*W"]) -> Float[Array, "B C H W"]:
    ... # magic unpacking
    return y

but this would raise an error if you were to do runtime type-checking, as it won't have seen H and W when you first call the function.

So if we were to change anything, I think it would probably be to (a) allow such "incomplete" annotations when checking the arguments, and then to (b) go back and check them all again after the function has finished running and we have its output.
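A rough sketch of what (a) and (b) could look like, again in plain Python on shape tuples rather than in terms of jaxtyping internals. All names here (`bind_or_defer`, `resolve`, `check_call`) are hypothetical, and for simplicity the sketch defers every expression to the end instead of evaluating the ones that are already resolvable.

```python
import re

_NAME = re.compile(r"[A-Za-z_]\w*")


def bind_or_defer(spec, shape, env, deferred):
    """Bind plain names from `spec` into `env`; queue expressions for later."""
    tokens = spec.split()
    if len(tokens) != len(shape):
        raise ValueError(f"expected {len(tokens)} axes, got {len(shape)}")
    for token, size in zip(tokens, shape):
        if _NAME.fullmatch(token):
            # Bind on first sighting, otherwise require consistency.
            if env.setdefault(token, size) != size:
                raise ValueError(f"{token} was {env[token]}, now {size}")
        else:
            deferred.append((token, size))  # (a) tolerate "incomplete" axes


def resolve(env, deferred):
    """(b) Re-check every deferred expression once all names are known."""
    for expr, size in deferred:
        try:
            expected = eval(expr, {"__builtins__": {}}, env)
        except NameError as err:
            raise ValueError(f"{expr!r} is still unresolved") from err
        if expected != size:
            raise ValueError(f"{expr!r} evaluates to {expected}, got {size}")


def check_call(arg_spec, arg_shape, ret_spec, ret_shape):
    """Check one argument and one return value, with deferred expressions."""
    env, deferred = {}, []
    bind_or_defer(arg_spec, arg_shape, env, deferred)  # before the call
    bind_or_defer(ret_spec, ret_shape, env, deferred)  # after the call
    resolve(env, deferred)
    return env


# H*W is unknown when the argument is checked, but resolvable afterwards.
bindings = check_call("B C H*W", (2, 3, 20), "B C H W", (2, 3, 4, 5))
```

Here `bindings` ends up with B, C, H and W all bound, and an output of shape (2, 3, 4, 6) would be rejected in `resolve` because 4 * 6 != 20.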

WDYT?

@MilesCranmer
Author

Cool! Thanks, it's great that this is already possible! I think this will already be useful enough for me. The second option is perhaps a bit nicer if it is not too difficult to add, i.e., H*W -> H W is perhaps slightly better than H_W -> H (H_W//H). But if it increases code complexity too much, maybe it's not worth it?

@patrick-kidger
Owner

I'd be happy to add the second notation, but I'd have to ask for a PR on it, as I don't have the time to implement this myself :D

If you or anyone else feels strongly about this, then I'd be happy to explain how to tweak the jaxtyping internals to accomplish this.

@MilesCranmer
Author

For posterity: I'm also short on time at the moment due to teaching. (Anybody reading this thread: feel free to take a stab at this!)

@patrick-kidger added the "feature" (New feature) label Feb 21, 2024