[BUG] Some decompositions/transforms do not preserve derivatives #5715

Open
dwierichs opened this issue May 21, 2024 · 1 comment
Labels
bug 🐛 Something isn't working

dwierichs commented May 21, 2024

Expected behavior

Using decompositions and transforms should not change the derivative of the overall workflow.

Actual behavior

Some decompositions/transforms only reproduce the function, but not its derivative. I found this in the following parts of the codebase:

  • merge_rotations: Some rotation gates are skipped for zero angles
  • single_qubit_fusion: Some rotation gates are skipped for zero angles
  • MottonenStatePreparation: Depending on the input state, gates are skipped, which leads to errors when JITting (no gradient entries to stack) or to nan values.
  • fuse_rot_angles: Used in merge_rotations and single_qubit_fusion; introduces a second bug in each of these functions.
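The zero-angle skipping from the first two bullets can be reproduced without PennyLane. The following is a toy sketch, not PennyLane's implementation: `merged_angles` is a hypothetical stand-in for merge_rotations acting on `RX(x) RX(x)`, and the one-qubit simulation is written by hand in JAX. Because the skip decision is a Python branch on the concrete value, `jax.grad` only differentiates the branch that was taken, so at `x = 0` it sees a circuit without any gate.

```python
import jax
import jax.numpy as jnp

# Pauli-Y observable on a single qubit.
PAULI_Y = jnp.array([[0.0, -1.0j], [1.0j, 0.0]])


def rx(theta):
    """Matrix of RX(theta) = exp(-i * theta * X / 2)."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def merged_angles(x, skip_zero):
    """Hypothetical stand-in for merging RX(x) RX(x) into RX(2x).

    With skip_zero=True the merged gate is dropped when its angle is
    (numerically) zero, mimicking the behaviour reported in this issue.
    """
    angle = 2 * x
    # Python-level branch on the value: allowed under jax.grad (no jit),
    # but the branch taken decides which function gets differentiated.
    if skip_zero and jnp.isclose(angle, 0.0):
        return []
    return [angle]


def expval_y(x, skip_zero):
    """<Y> after applying the merged rotations to |0>; equals -sin(2x)."""
    state = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    for angle in merged_angles(x, skip_zero):
        state = rx(angle) @ state
    return jnp.vdot(state, PAULI_Y @ state).real


grad_skip = jax.grad(expval_y)(0.0, True)   # gate dropped: derivative 0.0
grad_keep = jax.grad(expval_y)(0.0, False)  # gate kept: derivative -2.0
grad_near = jax.grad(expval_y)(1e-3, True)  # angle not treated as zero: about -2.0
```

The jump between `grad_skip` and `grad_near` mirrors the difference between the "at 0" and "close to 0" outputs reported for merge_rotations.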

Additional information

Note that JITting usually avoids the source of the error (except for MottonenStatePreparation), since in all of the examples above the codebase has special logic for JITting.
As a consequence, JITted derivatives tend to be unaffected by the type of bug observed in the transforms.
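The JIT behaviour can be illustrated with the same kind of toy model (again hypothetical, not PennyLane's code). Under `jax.jit` the angle is abstract, so turning the comparison into a Python `bool` raises a `jax.errors.ConcretizationTypeError`, and the sketch falls back to keeping the gate. In this toy, that fallback is what makes the jitted derivative correct while the eager one is not.

```python
import jax
import jax.numpy as jnp

PAULI_Y = jnp.array([[0.0, -1.0j], [1.0j, 0.0]])


def rx(theta):
    """Matrix of RX(theta) = exp(-i * theta * X / 2)."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def merged_angles(x):
    """Hypothetical merge of RX(x) RX(x) into RX(2x) with a jit fallback.

    A concrete angle that is (numerically) zero is dropped. Under jax.jit
    the angle is abstract, the Python bool conversion raises, and the gate
    is kept unconditionally.
    """
    angle = 2 * x
    try:
        skip = bool(jnp.isclose(angle, 0.0))
    except jax.errors.ConcretizationTypeError:
        skip = False  # abstract value: cannot decide, so keep the gate
    return [] if skip else [angle]


def expval_y(x):
    """<Y> after applying the merged rotations to |0>."""
    state = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    for angle in merged_angles(x):
        state = rx(angle) @ state
    return jnp.vdot(state, PAULI_Y @ state).real


grad_eager = jax.grad(expval_y)(0.0)         # 0.0: gate was dropped
grad_jit = jax.jit(jax.grad(expval_y))(0.0)  # -2.0: gate always kept
```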

Under the hood, this seems similar to #5541, which concerns AmplitudeEmbedding and is being solved in #5620 by modifying the diff method of GlobalPhase. However, the bug described here has a different origin and was encountered while finalizing the tests of #5620 for MottonenStatePreparation.

Source code

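from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

# The outputs below are float64 values, so enable 64-bit precision in JAX.
jax.config.update("jax_enable_x64", True)

#### BUG caused by merge_rotations itself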

@qml.transforms.merge_rotations
def _node(x):
    qml.RX(x, 1)
    qml.RX(x, 1)
    return qml.expval(qml.Y(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:    
    print(jax.jacobian(node_)(0.))

>>> Derivatives at 0:
... 0.0
... -2.0
... 0.0
... -1.9999999999999996

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-8))

>>> Derivatives close to 0:
... -1.9999999999999993
... -1.9999999999999993
... -1.9999999999999993
... -1.9999999999999993

#### BUG caused by fuse_rot_angles via merge_rotations

@qml.transforms.merge_rotations
def _node(x):
    qml.Rot(x, x, x, 1)
    qml.Rot(x, x, x, 1)
    return qml.expval(qml.X(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:    
    print(jax.jacobian(node_)(0.))

>>> Derivatives at 0:
... 0.0
... 2.0
... 0.0
... 1.9999999999999996

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-6))

>>> Derivatives close to 0:
... 2.0000221220668224
... 1.9999999999840001
... 2.0000221220668215
... 1.9999999999839995


#### BUGS in single_qubit_fusion, one in the function itself, one from fuse_rot_angles
@partial(qml.transforms.single_qubit_fusion, atol=1e-6)
def _node(x):
    qml.RX(x, 1)
    qml.RX(x, 1)
    return qml.expval(qml.Y(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:    
    print(jax.jacobian(node_)(0.))

>>> Derivatives at 0:
... 0.0
... 0.0
... 0.0
... 0.0

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-7))

>>> Derivatives close to 0:
... 0.0
... -2.000799757290469
... 0.0
... -2.000799757290469

print("Derivatives less close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-5))

>>> Derivatives less close to 0:
... -1.9999999168263007
... -1.9999999168263007
... -1.9999999168263003
... -1.9999999168263003

#### BUGS with MottonenStatePreparation
def _node(x):
    qml.MottonenStatePreparation(x, wires=[0, 1])
    return qml.probs()

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

x1 = jnp.array([1, 1, 0, 1]) / np.sqrt(3)

for node_ in [node_ps, node_ad, node_ps_jit, node_ad_jit]: # Fails with JITted nodes   
    print(jax.jacobian(node_)(x1))
>>> [[ 7.69800359e-01 -3.84900179e-01             nan             nan]
...  [-3.84900179e-01  7.69800359e-01             nan             nan]
...  [-4.80740672e-17  4.80740672e-17             nan             nan]
...  [-3.84900179e-01 -3.84900179e-01             nan             nan]]
>>> [[ 7.69800359e-01 -3.84900179e-01             nan             nan]
...  [-3.84900179e-01  7.69800359e-01             nan             nan]
...  [-4.80740672e-17  4.80740672e-17             nan             nan]
...  [-3.84900179e-01 -3.84900179e-01             nan             nan]]

x2 = jnp.array([1, 0, 0, 1]) / np.sqrt(2)
for node_ in [node_ps, node_ad, node_ps_jit, node_ad_jit]: # Fails with JITted nodes   
    print(jax.jacobian(node_)(x2))

>>> [[nan nan nan nan]
...  [nan nan nan nan]
...  [nan nan nan nan]
...  [nan nan nan nan]]
>>> [[nan nan nan nan]
...  [nan nan nan nan]
...  [nan nan nan nan]
...  [nan nan nan nan]]

Tracebacks

No response

System information

PennyLane dev

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
dwierichs added the bug 🐛 label May 21, 2024
dwierichs self-assigned this May 30, 2024
dwierichs commented:

While trying to fix this, I noticed that fuse_rot_angles uses a function that, as it stands, is not differentiable everywhere. At those singular points, we return wrong derivatives in yet another way :/
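At such a singular point, autodiff returns whatever the chain rule happens to produce there. As a generic JAX illustration (not taken from fuse_rot_angles, whose internals are not shown here), differentiating the Euclidean norm at the origin yields `nan`, because the infinite derivative of `sqrt` at 0 is multiplied by a zero inner derivative. The guarded variant uses the "double `jnp.where`" pattern described in the JAX FAQ and returns finite values, although at a genuinely non-differentiable point any finite value is a convention rather than the derivative.

```python
import jax
import jax.numpy as jnp


def norm(v):
    """Euclidean norm; not differentiable at v = 0."""
    return jnp.sqrt(jnp.sum(v**2))


def safe_norm(v):
    """Norm guarded with a double jnp.where: the sqrt that is differentiated
    never receives 0, so no inf * 0 = nan appears in the gradient."""
    sq = jnp.sum(v**2)
    is_zero = sq == 0.0
    safe_sq = jnp.where(is_zero, 1.0, sq)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))


origin = jnp.zeros(2)
grad_naive = jax.grad(norm)(origin)         # [nan, nan]
grad_guarded = jax.grad(safe_norm)(origin)  # [0., 0.]
```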

dwierichs added a commit that referenced this issue Jun 13, 2024
**Context:**
The decomposition of `MottonenStatePreparation` skips some gates for
special parameter values/input states.
See the linked issue for details.

**Description of the Change:**
This PR introduces a check for differentiability so that gates are only
skipped when no derivatives are being computed.
Note that this does *not* fix the non-differentiability at other special
parameter points, which is also referenced in #5715 and which the docs
already warn against.
Also, the linked issue is about multiple operations, and here we only
address `MottonenStatePreparation`.

**Benefits:**
Fixes parts of #5715. Unblocks #5620.

**Possible Drawbacks:**

**Related GitHub Issues:**
#5715
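The idea of skipping gates only when no derivatives are being computed can be sketched on the RX toy model from the issue. This is not MottonenStatePreparation and not the check the PR actually uses: the hypothetical helper `is_traced` assumes JAX autodiff and treats any JAX tracer as a sign that derivatives may be requested, so the gate is kept. With a plain Python float, the identity gate is still removed.

```python
import jax
import jax.numpy as jnp

PAULI_Y = jnp.array([[0.0, -1.0j], [1.0j, 0.0]])


def rx(theta):
    """Matrix of RX(theta) = exp(-i * theta * X / 2)."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def is_traced(x):
    """Hypothetical check: True if x is a JAX tracer (grad, jit, vmap, ...)."""
    return isinstance(x, jax.core.Tracer)


def merged_angles(x):
    """Merge RX(x) RX(x) into RX(2x); drop it only for a concrete zero angle."""
    angle = 2 * x
    if not is_traced(angle) and jnp.isclose(angle, 0.0):
        return []  # plain evaluation: removing the identity gate is harmless
    return [angle]  # traced (possibly differentiated): always keep the gate


def expval_y(x):
    """<Y> after applying the merged rotations to |0>."""
    state = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    for angle in merged_angles(x):
        state = rx(angle) @ state
    return jnp.vdot(state, PAULI_Y @ state).real


n_gates_plain = len(merged_angles(0.0))  # 0: gate removed
value_plain = expval_y(0.0)              # 0.0
grad_at_zero = jax.grad(expval_y)(0.0)   # -2.0: gate kept while differentiating
```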