Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add kron specialisation for sparse and dense mixes. #2391

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions qutip/core/data/kron.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
#cython: language_level=3

from qutip.core.data.csr cimport CSR
from qutip.core.data.dense cimport Dense
from qutip.core.data.dia cimport Dia
from qutip.core.data.base cimport Data

cpdef Dense kron_dense(Dense left, Dense right)
cpdef CSR kron_csr(CSR left, CSR right)
cpdef Dia kron_dia(Dia left, Dia right)
cpdef CSR kron_dense_csr_csr(Dense left, CSR right)
cpdef CSR kron_csr_dense_csr(CSR left, Dense right)
cpdef Dia kron_dia_dense_dia(Dia left, Dense right)
cpdef Dia kron_dense_dia_dia(Dense left, Dia right)

cpdef Data kron_transpose_data(Data left, Data right)
cpdef Dense kron_transpose_dense(Dense left, Dense right)
38 changes: 38 additions & 0 deletions qutip/core/data/kron.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import numpy

__all__ = [
'kron', 'kron_csr', 'kron_dense', 'kron_dia',
'kron_csr_dense_csr', 'kron_dense_csr_csr',
'kron_dia_dense_dia', 'kron_dense_dia_dia',
'kron_transpose', 'kron_transpose_dense', 'kron_transpose_data',
]

Expand Down Expand Up @@ -70,6 +72,22 @@ cpdef CSR kron_csr(CSR left, CSR right):
return out


cpdef CSR kron_csr_dense_csr(CSR left, Dense right):
# The dispatcher would use kron_dense, but the output is at least as sparse
# as the sparse input. Since the dispatcher does not have precise control
# on which function to use when the signature is missing. We add
# then like this.
return kron_csr(left, _to(CSR, right))


cpdef CSR kron_dense_csr_csr(Dense left, CSR right):
# The dispatcher would use kron_dense, but the output is at least as sparse
# as the sparse input. Since the dispatcher does not have precise control
# on which function to use when the signature is missing. We add
# then like this.
return kron_csr(_to(CSR, left), right)


cdef inline void _vec_kron(
double complex * ptr_l, double complex * ptr_r, double complex * ptr_out,
idxint size_l, idxint size_r, idxint step
Expand Down Expand Up @@ -159,6 +177,22 @@ cpdef Dia kron_dia(Dia left, Dia right):
return out


cpdef Dia kron_dia_dense_dia(Dia left, Dense right):
# The dispatcher would use kron_dense, but the output is at least as sparse
# as the sparse input. Since the dispatcher does not have precise control
# on which function to use when the signature is missing. We add
# then like this.
return kron_dia(left, _to(Dia, right))


cpdef Dia kron_dense_dia_dia(Dense left, Dia right):
# The dispatcher would use kron_dense, but the output is at least as sparse
# as the sparse input. Since the dispatcher does not have precise control
# on which function to use when the signature is missing. We add
# then like this.
return kron_dia(_to(Dia, left), right)


from .dispatch import Dispatcher as _Dispatcher
import inspect as _inspect

Expand All @@ -181,6 +215,10 @@ kron.add_specialisations([
(CSR, CSR, CSR, kron_csr),
(Dense, Dense, Dense, kron_dense),
(Dia, Dia, Dia, kron_dia),
(CSR, Dense, CSR, kron_csr_dense_csr),
(Dense, CSR, CSR, kron_dense_csr_csr),
(Dia, Dense, Dia, kron_dia_dense_dia),
(Dense, Dia, Dia, kron_dense_dia_dia),
], _defer=True)


Expand Down
2 changes: 1 addition & 1 deletion qutip/core/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,4 +467,4 @@ def expand_operator(oper, dims, targets, dtype=None):
for i, ind in enumerate(rest_pos):
new_order[ind] = rest_qubits[i]
id_list = [identity(dims[i]) for i in rest_pos]
return tensor([oper] + id_list).permute(new_order)
return tensor([oper] + id_list).permute(new_order).to(dtype)
4 changes: 4 additions & 0 deletions qutip/tests/core/data/test_mathematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,10 @@ def op_numpy(self, left, right):
pytest.param(data.kron_csr, CSR, CSR, CSR),
pytest.param(data.kron_dense, Dense, Dense, Dense),
pytest.param(data.kron_dia, Dia, Dia, Dia),
pytest.param(data.kron_dense_csr_csr, Dense, CSR, CSR),
pytest.param(data.kron_csr_dense_csr, CSR, Dense, CSR),
pytest.param(data.kron_dense_dia_dia, Dense, Dia, Dia),
pytest.param(data.kron_dia_dense_dia, Dia, Dense, Dia),
]


Expand Down