Reconstructing scene from qpos and qvel #345

Open
namheegordonkim opened this issue Apr 24, 2023 · 14 comments

Comments

@namheegordonkim
Contributor

I've mentioned this briefly in #157, but I believe this deserves a separate thread.

In certain applications like Go-Explore (https://www.nature.com/articles/s41586-020-03157-9), it's critical for the user to be able to reset the environment to the state corresponding to the gathered observation, i.e., qpos and qvel.

In my experiments with Brax v2, I haven't been able to reconstruct the simulator state from qpos and qvel alone--I suspect the 2nd order information is also necessary. According to the mujoco-py docs:

class mujoco_py.MjSimState

Represents a snapshot of the simulator’s state.

This includes time, qpos, qvel, act, and udd_state.


Ignoring time and udd_state, act needs to be taken into account. Not sure by what mechanism MuJoCo does this, but we can gather that qpos, qvel, and act should be necessary and sufficient for reconstructing the state.

The current workaround is to store the full states gathered during exploration (i.e., in a batched instance of the State class) and pass them as the argument to step(). But this comes with a huge memory overhead.

Any idea whether this will be made possible?

@erikfrey
Collaborator

Hello! Yes, q and qd (or qpos and qvel in mujoco terms) are all you need to recreate simulator state. That said, for performance reasons, Brax stores other data that it reuses from one step to the next. Knowing this, you have two options:

  1. If you wish to store only the minimal state, q and qd, you can always recreate the full state by calling pipeline.init to recreate the desired brax state, then call pipeline.step afterwards.

  2. If you can afford to save the full state, then you can skip calling pipeline.init to recreate the state.
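Option 1 can be sketched roughly as follows. This is a minimal sketch rather than code from this thread: `pipeline` is assumed to be a Brax v2 pipeline module (for example `brax.generalized.pipeline`) exposing `init(sys, q, qd)` and `step(sys, state, act)`, and the helper name `restore_and_step` is hypothetical.

```python
def restore_and_step(pipeline, sys, q, qd, act):
    """Rebuild a full pipeline state from (q, qd), then advance one step.

    Assumption: `pipeline` exposes init(sys, q, qd) and step(sys, state, act).
    """
    # Recompute every derived quantity from the minimal state.
    state = pipeline.init(sys, q, qd)
    # Then step exactly as with a state that came from a previous step.
    return pipeline.step(sys, state, act)
```

With Brax installed, usage would presumably look like `restore_and_step(pipeline, env.sys, q, qd, action)` after `from brax.generalized import pipeline`. Option 2 skips the `init` call and passes the saved state to `step` directly.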

@namheegordonkim
Contributor Author

Thanks for the fast response, @erikfrey.

> If you wish to store only the minimal state, q and qd, you can always recreate the full state by calling pipeline.init to recreate the desired brax state, then call pipeline.step afterwards.

This has not been the case for me unfortunately--at least with the positional backend, using pipeline.init() with q and qd did not result in the same matrices; I checked this by saving tuples of (State, (q, qd)) and comparing entries inside State vs. pipeline.init(q, qd). Would be happy to provide a minimal example when I get the chance.

@erikfrey
Collaborator

OK! It's our expectation that this would work, so please do share a repro and we'll take a look.

@namheegordonkim
Contributor Author

namheegordonkim commented Apr 25, 2023

Here's the minimal example I promised: https://colab.research.google.com/drive/1lAqGR4mpd4EX4eeXhrlR4CzQEoQncbZk?usp=sharing

You will see that state2 is a state reconstructed from state1's q and qd values, but after the exact same actions are passed to step(), the two trajectories diverge, as visible in the rendered output. backend="positional" has the same issue.

@erwincoumans

Can you provide info on exactly what data is reused?

> Brax stores other data that it reuses from one step to the next.

@namheegordonkim
Contributor Author

Not to hijack Erwin's question--I further narrowed it down to kinematics.forward().

Repro: https://colab.research.google.com/drive/1P3mtOCNTngZ6A_CJtjJdWzn4zX4TPcda?usp=sharing

It seems that using anything but the default q and qd makes the results of kinematics.forward() irreproducible.

@erikfrey
Collaborator

@namheegordonkim I think you're onto something here. I agree the assert in your colab shouldn't fire. We'll take a closer look.

@erwincoumans it varies by pipeline. In generalized, for example, we iteratively update the inverse mass matrix from the previous step instead of recalculating it from scratch. We can certainly improve the documentation to make this clearer.

@namheegordonkim
Contributor Author

namheegordonkim commented Apr 28, 2023

Great to hear that you'll be looking into this. I've wrestled with it for a few days and identified some sus lines. I hope these help!

brax/brax/kinematics.py

Lines 88 to 105 in c2cd14c

def world(parent, j, jd):
  """Convert transform/motion from joint frame to world frame."""
  if parent is None:
    return j, jd
  x, xd = parent
  # TODO: determine why the motion `do` is inverted
  x = x.vmap().do(j)
  xd = xd + Motion(
      ang=jax.vmap(math.rotate)(jd.ang, x.rot),
      vel=jax.vmap(math.rotate)(
          jd.vel + jax.vmap(jp.cross)(x.pos, jd.ang), x.rot
      ),
  )
  return x, xd

x, xd = scan.tree(sys, world, 'll', j, jd)
x = x.replace(rot=jax.vmap(math.normalize)(x.rot)[0])

Here, the parent link and the child link are held together by the joint. Initially x is the world coordinate position of the parent link COM, but using x.vmap().do(j) transforms it to the world coordinate position of the child link. The child link inherits the parent's world coordinate angles and velocities, so all that's left (theoretically) is to add the contributions of the angular velocity of the joint.

However, jax.vmap(jp.cross)(x.pos, jd.ang) plainly seems wrong here: the linear velocity contribution from angular velocity should be the cross product of the angular velocity and the moment arm. First, x.pos doesn't give you the local coordinate moment arm; sys.link.joint.pos does. Next, the cross product isn't commutative, so shouldn't the order of the two operands be switched?
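The operand-order point can be checked with a small worked example. This is a pure-Python sketch with made-up values for the angular velocity and moment arm; it illustrates only the general rigid-body identity v = omega x r. Whether the Brax expression is actually wrong depends on its frame and sign conventions, which this sketch does not model.

```python
def cross(a, b):
    """Cross product of two 3-vectors given as tuples."""
    ax, ay, az = a
    bx, by, bz = b
    return (ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx)


omega = (0.0, 0.0, 1.0)  # angular velocity: 1 rad/s about +z (made-up value)
r = (1.0, 0.0, 0.0)      # moment arm: 1 unit along +x (made-up value)

v = cross(omega, r)          # omega x r: linear velocity of the point
v_swapped = cross(r, omega)  # r x omega: same magnitude, opposite sign
print(v, v_swapped)
```

Swapping the operands flips the sign of every component, so the two orderings only agree if a compensating sign appears elsewhere.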

I did fool around with passing sys.link.joint into the scan.tree call as another argument, but haven't had success in reproducing the correct linear velocities.

@erikfrey
Collaborator

Oh whoops, you know, I lied to you. For spring and pbd pipelines, q and qd are not sufficient to reconstruct physics state. Those two pipelines can produce states with joint constraint violations which cannot be expressed in reduced coordinates. That is why everything lines up for you at init, but after a step you start to see the error - no joint constraints are violated at init, but are after a step.

For spring and pbd, you would want to use either x and xd or x_i and xd_i to reconstruct the rest of the state. For generalized, you should be able to rely on q and qd.
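A toy illustration of why a constraint violation cannot be expressed in reduced coordinates: a single pendulum whose only reduced coordinate is the joint angle. This is a hypothetical pure-Python sketch, not Brax code; `LENGTH`, `forward`, and `to_reduced` are names invented for the example.

```python
import math

LENGTH = 1.0  # rigid rod length (hypothetical)


def forward(q):
    """Reduced -> maximal coordinates: bob position for joint angle q."""
    return (LENGTH * math.sin(q), -LENGTH * math.cos(q))


def to_reduced(x):
    """Maximal -> reduced coordinates: keeps the angle, drops the radius."""
    return math.atan2(x[0], -x[1])


# A state whose joint constraint is violated: the bob sits 10% too far out.
x_violated = (1.1 * math.sin(0.3), -1.1 * math.cos(0.3))

q = to_reduced(x_violated)              # the angle survives the round trip
x_rebuilt = forward(q)                  # the stretch does not
gap = math.dist(x_violated, x_rebuilt)  # distance lost by going through q
print(q, gap)
```

The rebuilt position lies back on the constraint, so a state saved as q alone differs from the violated state the simulator actually held. Saving the maximal coordinates (x and xd in the comment above) avoids that loss.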

Sorry for leading you down a rabbit hole. Please let me know if some part of simulation still doesn't make sense for you.

@namheegordonkim
Contributor Author

> I lied to you. For spring and pbd pipelines, q and qd are not sufficient to reconstruct physics state.
> For generalized, you should be able to rely on q and qd.

Well, FWIW the demo I shared is using generalized.

According to your comment, if I disabled gravity, hung the character up in the air with absolutely no collision possibilities, q and qd should be sufficient for reconstructing, but this isn't the case either :( I do think kinematics.forward() has a bug in it as mentioned above.

@erikfrey
Collaborator

erikfrey commented May 1, 2023

Hi @namheegordonkim - no problem, let's get to the bottom of this. First off, just so I understand, where you say:

> Well, FWIW the demo I shared is using generalized.

You mean this colab? As far as I can tell you're using the positional backend, right? Just to be sure I fired up your colab and switched it to 'generalized' and sure enough, the assert goes away:

# using kinematics.forward(env.sys, q1, qd1) should yield x1, xd1.
x3, xd3 = jax.vmap(kinematics.forward, in_axes=(None, 0, 0,))(env.sys, q1, qd1)

tree_map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-3), x1, x3)

That no longer throws after switching to generalized.

@namheegordonkim
Contributor Author

@erikfrey

I shared two colabs.

> Here's the minimal example I promised: https://colab.research.google.com/drive/1lAqGR4mpd4EX4eeXhrlR4CzQEoQncbZk?usp=sharing

> Not to hijack Erwin's question--I further narrowed it down to kinematics.forward().
> Repro: https://colab.research.google.com/drive/1P3mtOCNTngZ6A_CJtjJdWzn4zX4TPcda?usp=sharing

I was specifically referring to the first one.

@erikfrey
Collaborator

erikfrey commented May 3, 2023

Oh got it, thanks for clearing that up. This is helping me sharpen how I talk about Brax. So the behavior in the generalized colab is expected (by me at least, haha!). We have not done any work to guarantee Brax is deterministic, that it will produce the same trajectory across diverse hardware (the typical case), or the exact same trajectory across states that were created via step vs. via init.

I agree this would be a nice property for Brax, and some engines go through the trouble of making explicit claims here, so we'll add that to our TODO.

One thing that I am certain of is that init vs step will produce slightly different mass matrices. If you would like to remove this difference, you can try setting matrix_inv_iterations to zero, as this will force brax to use the same matrix inverse operation for both step and init. This will slow down simulation though!
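To see how a warm-started iterative inverse can differ slightly from one computed from scratch, here is a toy 2x2 sketch. It uses Newton-Schulz refinement as an assumed stand-in, since the thread does not say which iterative scheme Brax uses; all names and matrix values are hypothetical.

```python
def matmul(a, b):
    """Product of two 2x2 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def inv_exact(a):
    """Closed-form inverse of a 2x2 matrix."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [[a[1][1] / det, -a[0][1] / det],
            [-a[1][0] / det, a[0][0] / det]]


def inv_warm(a, x0, iterations):
    """Newton-Schulz refinement X <- X (2I - A X), starting from guess x0."""
    x = x0
    for _ in range(iterations):
        ax = matmul(a, x)
        correction = [[(2.0 if i == j else 0.0) - ax[i][j] for j in range(2)]
                      for i in range(2)]
        x = matmul(x, correction)
    return x


def max_abs_diff(a, b):
    return max(abs(a[i][j] - b[i][j]) for i in range(2) for j in range(2))


m_prev = [[2.0, 0.0], [0.0, 4.0]]  # mass matrix at the previous step
m_now = [[2.1, 0.1], [0.1, 4.0]]   # slightly changed mass matrix

from_scratch = inv_exact(m_now)  # what a fresh init would compute
warm_one = inv_warm(m_now, inv_exact(m_prev), iterations=1)
warm_many = inv_warm(m_now, inv_exact(m_prev), iterations=6)

err_one = max_abs_diff(warm_one, from_scratch)
err_many = max_abs_diff(warm_many, from_scratch)
print(err_one, err_many)
```

With few iterations the warm-started result is close to, but not equal to, the from-scratch inverse, which is the kind of small mismatch described above. Note that zero iterations in this toy simply returns the stale guess, which is not what setting matrix_inv_iterations to zero is described as doing.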

For now I think your best bet may be to store the entire State struct.

@erikfrey
Collaborator

erikfrey commented May 3, 2023

Also please do let me know if you have a repro for forward kinematics that shows it's broken! It's pretty well tested so I'd be surprised if it's incorrect. But I love surprises!
