
Make GlobalPhase not differentiable #5620

Draft: wants to merge 12 commits into base: master

Conversation

@Tarun-Kumar07 (Contributor) commented May 2, 2024

Context:
When using the state preparation methods AmplitudeEmbedding, StatePrep, and MottonenStatePreparation with jax.jit and jax.grad, the error ValueError: need at least one array to stack was raised (a sketch of the failing pattern is shown below).
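A minimal sketch of the failing pattern, assuming a finite-shot default.qubit device; the wire count, shot number, and observable are illustrative, not taken from the original report:

import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=2, shots=1000)

@qml.qnode(dev, diff_method="parameter-shift", interface="jax")
def circuit(state):
    # State preparation that decomposes via GlobalPhase under the hood
    qml.AmplitudeEmbedding(state, wires=[0, 1], normalize=True)
    return qml.expval(qml.PauliZ(0))

state = jnp.array([0.5, 0.5, 0.5, 0.5])
jax.jit(jax.grad(circuit))(state)  # raised: ValueError: need at least one array to stack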

Description of the Change:
All state preparation strategies use GlobalPhase under the hood, which caused the above error. After this PR, GlobalPhase is no longer differentiable: its grad_method is set to None.
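A quick way to check the new behaviour from user code (a minimal sketch; the attribute name follows PennyLane's Operation convention, and the wire label is arbitrary):

import pennylane as qml

op = qml.GlobalPhase(0.123, wires=0)
print(op.grad_method)  # expected to print None after this change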

Benefits:

Possible Drawbacks:

Related GitHub Issues:
Fixes #5541 ([BUG] jax.grad + jax.jit does not work with AmplitudeEmbedding and finite shots).

@albi3ro (Contributor) commented May 2, 2024

Thanks for this @Tarun-Kumar07

For the failures due to errors:

ValueError: Computing the gradient of circuits that return the state with the parameter-shift rule gradient transform is not supported, as it is a hardware-compatible method.

That is expected, and we should switch the measurement in those tests to expectation values, for example as sketched below.
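A sketch of the kind of change meant here (the circuits and observables in the affected tests differ; this only illustrates replacing a state return with a hardware-compatible measurement):

import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(state_vector):
    qml.MottonenStatePreparation(state_vector, wires=[0, 1])
    # return qml.state()               # not differentiable with parameter-shift
    return qml.expval(qml.PauliZ(0))   # expectation value instead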

For the failures due to:

FAILED tests/templates/test_state_preparations/test_mottonen_state_prep.py::test_jacobian_with_and_without_jit_has_same_output_with_high_shots[StatePrep] - AssertionError: assert Array(False, dtype=bool)
 +  where Array(False, dtype=bool) = <function allclose at 0x7fdb77481940>(Array([-0.0003,  0.0153, -0.0153,  0.0003], dtype=float64), Array([ 1.0187, -0.9953, -1.0047,  0.9813], dtype=float64), atol=0.02)

Those are legitimately different results, so we can safely say we are getting wrong results in that case 😢 I'll investigate.

@dwierichs (Contributor) left a comment


I left a couple of small comments and one major suggestion: Could we set GlobalPhase.grad_method = "F"? This will produce unnecessary shifted tapes for expectation values and probabilities, but it will avoid wrong results when differentiating qml.state with finite_diff and param_shift.
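A sketch of what that suggestion would look like, assuming grad_method can simply be overridden as a class attribute; the real GlobalPhase in pennylane/ops/identity.py has more to its class body than shown here:

from pennylane.operation import Operation

class GlobalPhase(Operation):
    # Excerpt-style sketch only, not the actual class definition.
    # "F" tells gradient transforms to fall back to numerical (finite-difference)
    # differentiation for this operation instead of an analytic shift rule.
    grad_method = "F"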

Review comments were left on the following files (all marked outdated and resolved):
doc/releases/changelog-dev.md (two comments)
pennylane/ops/identity.py
pennylane/ops/op_math/controlled.py
tests/templates/test_embeddings/test_amplitude.py
@dwierichs (Contributor) commented

Those are legitimately different results, so we can safely say we are getting wrong results in that case 😢 I'll investigate.

@Tarun-Kumar07 @albi3ro Not sure you got to this yet, but it seems that the decompositions of those state preparation methods handle special parameter values differently from generic ones. This makes the derivative wrong at those special values, because param_shift is handed a tape that does not contain the general decomposition, so it will not shift all operations that need shifting. Since JITting does not allow such special-casing, we only make this mistake without JITting, hence the difference in the results within those tests.

Basically, the decomposition does something like the following (illustrated here for RZ):

import pennylane as qml

def compute_decomposition(theta, wires):
    # Special case: skip the gate entirely for a concrete theta equal to 0.
    if not qml.math.is_abstract(theta) and qml.math.isclose(theta, 0):
        return []
    return [qml.RZ(theta, wires)]

The decomposition itself is correct, but it does not yield the correct parameter-shift derivative at 0; see the sketch below.
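A minimal sketch of the mismatch, using PennyLane's tape and gradient-transform APIs; the circuit is illustrative and not taken from the actual state-preparation decompositions:

import pennylane as qml

dev = qml.device("default.qubit", wires=1)

# General decomposition at theta = 0: the RZ is still on the tape, so param_shift
# shifts it and recovers the true derivative d<Y>/dtheta = cos(0) = 1.
general = qml.tape.QuantumScript(
    [qml.Hadamard(0), qml.RZ(0.0, wires=0)], [qml.expval(qml.PauliY(0))]
)
shifted_tapes, postprocess = qml.gradients.param_shift(general)
print(postprocess(qml.execute(shifted_tapes, dev)))  # ~1.0

# Special-cased decomposition at theta = 0: the RZ was dropped entirely, so this
# tape no longer carries the parameter and the derivative silently comes out as 0.
special = qml.tape.QuantumScript([qml.Hadamard(0)], [qml.expval(qml.PauliY(0))])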

This looks like an independent bug to me, and like one that could be hiding across the codebase for other ops as well, theoretically.
