Convert Eye to a COp or implement it in terms of existing COps #1177

Open
brandonwillard opened this issue Sep 13, 2022 · 10 comments · May be fixed by #1217
Labels
C-backend, help wanted, Op implementation, performance concern

Comments

@brandonwillard
Member

Eye doesn't have a C implementation, but adding one should be straightforward.

@brandonwillard added the labels help wanted, C-backend, performance concern, and Op implementation on Sep 13, 2022
@ricardoV94
Contributor

We can also try to use an OpFromGraph instead, it's just zeros + set_subtensor, no?

@brandonwillard changed the title from "Convert Eye to a COp" to "Convert Eye to a COp or implement it in terms of existing COps" on Sep 13, 2022
@brandonwillard
Member Author

> We can also try to use an OpFromGraph instead, it's just zeros + set_subtensor, no?

That might not be as performant as a COp, but I like it a lot more. Aside from being easier to implement, it could work well with our other transpilation targets (e.g. the result could be about as fast as direct use of np.eye in Numba and/or JAX).

@jessegrabowski
Contributor

jessegrabowski commented Sep 27, 2022

I believe the following is an implementation using OpFromGraph:

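import aesara
import aesara.tensor as at

n = at.iscalar('n')
m = at.iscalar('m')
k = at.iscalar('k')

# Flat index of the first diagonal entry: column k of row 0, or column 0 of row -k
i = at.switch(k >= 0, k, -k * m)
eye = at.zeros(n * m)
# Diagonal entries sit m + 1 apart in the flattened (row-major) array
eye = at.set_subtensor(eye[i::m+1], 1).reshape((n, m))
# Clear the rows that only hold entries from the stride wrapping around
eye = at.set_subtensor(eye[m-k:, :], 0)

Eye = aesara.compile.builders.OpFromGraph([n, m, k], [eye])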

The existing eye function in aesara.tensor could then be modified as follows:

def eye(n, m=None, k=0, dtype=None):
    """Return a 2-D array with ones on the diagonal and zeros elsewhere.

    Parameters
    ----------
    n : int
        Number of rows in the output.
    m : int, optional
        Number of columns in the output. If None, defaults to `n`.
    k : int, optional
        Index of the diagonal: 0 (the default) refers to the main diagonal,
        a positive value refers to an upper diagonal, and a negative value
        to a lower diagonal.
    dtype : data-type, optional
        Data-type of the returned array.

    Returns
    -------
    ndarray of shape (n, m)
        An array where all elements are equal to zero, except for the `k`-th
        diagonal, whose values are equal to one.
    """
    if dtype is None:
        dtype = aesara.config.floatX
    if m is None:
        m = n
    return Eye(n, m, k).astype(dtype)

I can add some tests to check corner cases and submit this as a pull request if it looks like I'm barking up the right tree? Where would I add the code to create the OpFromGraph, just floating inside aesara.tensor above def eye, or is there a more organized place to put it?
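For the corner-case checks, one option is to model the construction outside Aesara first. The sketch below is a pure-Python stand-in for the zeros + strided `set_subtensor` graph above, compared against a reference built directly from the definition of the `k`-th diagonal. The names `eye_flat` and `eye_reference` are hypothetical helpers introduced only for this illustration; they are not part of Aesara.

```python
def eye_flat(n, m, k=0):
    """Pure-Python model of the zeros + strided-assignment construction."""
    # Start of the diagonal in the flattened (row-major) n * m buffer:
    # column k of row 0 when k >= 0, column 0 of row -k otherwise.
    start = k if k >= 0 else -k * m
    flat = [0] * (n * m)
    # Consecutive diagonal entries are m + 1 apart in the flat buffer.
    for idx in range(start, n * m, m + 1):
        flat[idx] = 1
    # Equivalent of reshape((n, m)).
    rows = [flat[r * m:(r + 1) * m] for r in range(n)]
    # Same slice semantics as eye[m - k:, :] = 0, including a negative start.
    for r in range(n)[m - k:]:
        rows[r] = [0] * m
    return rows


def eye_reference(n, m, k=0):
    """Reference: entry (r, c) is 1 exactly when c - r == k."""
    return [[1 if c - r == k else 0 for c in range(m)] for r in range(n)]


if __name__ == "__main__":
    for row in eye_flat(4, 3, k=1):
        print(row)
```

In this model the construction agrees with the reference for every `k` up to `m`. For `k > m` the slice start `m - k` becomes negative, so the clearing step only touches the last rows and wrapped entries can survive (for example `n=5, m=3, k=4`). Clamping the slice start at zero would be one way to cover that case, if such offsets are meant to be supported.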

@ricardoV94
Contributor

Looks about right. About where to put it... good question. Floating sounds about right, maybe add an underscore prefix to those intermediate variables?

Another thing worth checking is whether any rewrites currently target the Eye Op and, if so, whether they still work.

@jessegrabowski
Contributor

How would I check for rewrites that target Eye? I did ctrl+f on all the files in aesara.tensor.rewriting for Eye (and @node_rewriter([Eye]) ) and came up with nothing. This doesn't strike me as a very sophisticated way to check, though.

@ricardoV94
Contributor

ricardoV94 commented Sep 27, 2022

> How would I check for rewrites that target Eye? I did ctrl+f on all the files in aesara.tensor.rewriting for Eye (and @node_rewriter([Eye]) ) and came up with nothing. This doesn't strike me as a very sophisticated way to check, though.

Sounds about right. There might not be any in which case you are in luck ;)

Edit: I didn't find anything either
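The manual search described above can be scripted so it is easy to repeat over other directories. This is only a rough sketch using the standard library; `find_mentions` is a hypothetical helper, and the directory passed to it is an assumption about how the local checkout is laid out.

```python
import re
from pathlib import Path


def find_mentions(root, pattern=r"\bEye\b", suffix=".py"):
    """Return (path, line number, stripped line) for each matching line under root."""
    regex = re.compile(pattern)
    matches = []
    for path in sorted(Path(root).rglob(f"*{suffix}")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if regex.search(line):
                matches.append((str(path), lineno, line.strip()))
    return matches


if __name__ == "__main__":
    # Assumed path: run from the root of a source checkout.
    for path, lineno, line in find_mentions("aesara/tensor/rewriting"):
        print(f"{path}:{lineno}: {line}")
```

A text search like this only finds places where the name appears literally, so an empty result is a good sign rather than a guarantee; running the test suite after the change is still the more reliable check.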

@ricardoV94
Contributor

ricardoV94 commented Sep 27, 2022

You can also get rid of the Numba and JAX dispatch (I assume a dispatch for OpFromGraph has already been implemented)

Edit: It seems to be only for Numba...

@ricardoV94
Contributor

ricardoV94 commented Sep 27, 2022

Actually we might not even need an OpFromGraph, if we don't need to easily target it in rewrites and if we are not overriding the grad. You can just make eye a function that returns the correct Aesara symbolic expression. This also makes the dtype of the inputs more flexible, instead of constraining them to int32 in your example.

@ricardoV94
Contributor

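import aesara
import aesara.tensor as at


def eye_new(n, m=None, k=0, dtype=None):
    if m is None:
        m = n
    if dtype is None:
        dtype = aesara.config.floatX

    # Accept Python ints, NumPy scalars, or symbolic variables of any integer dtype
    n = at.as_tensor_variable(n)
    m = at.as_tensor_variable(m)
    k = at.as_tensor_variable(k)

    # Flat index of the first diagonal entry: column k of row 0, or column 0 of row -k
    i = at.switch(k >= 0, k, -k * m)
    eye = at.zeros(n * m, dtype=dtype)
    # Diagonal entries sit m + 1 apart in the flattened (row-major) array
    eye = at.set_subtensor(eye[i::m + 1], 1).reshape((n, m))
    # Clear the rows that only hold entries from the stride wrapping around
    eye = at.set_subtensor(eye[m - k:, :], 0)
    return eye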

Seems to do alright

@jessegrabowski
Copy link
Contributor

I'll make a pull request for this in a minute, I'm just fumbling around with git at the moment.

@jessegrabowski linked a pull request on Sep 27, 2022 that will close this issue