Converting Tensorflow Dataset to iterator does not sync well with client #293
Comments
Another quick fix that I use is to wrap the reverb dataset inside the following class:

```python
class RefreshIterator:
    """tf.data.Dataset fix for slow reverb client synchronization. Wrap around reverb-dataset."""

    __slots__ = ["_iterable"]

    def __init__(self, iterable):
        self._iterable = iterable

    def __iter__(self):
        return self

    def __next__(self):
        return next(iter(self._iterable))

    def next(self):
        return self.__next__()
```

Use:

```python
dataset = datasets.make_reverb_dataset(
    table=my_table.name,
    server_address=reverb_client.server_address,
    batch_size=...,
    ...
)
jax_dataset = utils.multi_device_put(_NumpyIterator(RefreshIterator(dataset)), ...)
```
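To see why re-creating the underlying iterator on every `next()` call picks up new priorities, here is a minimal pure-Python sketch of the same pattern. It needs neither reverb nor TensorFlow; `MutableSource` is a hypothetical stand-in for a dataset whose `__iter__` re-reads mutable state, the way a freshly created reverb-dataset iterator reflects the current table priorities:

```python
class MutableSource:
    """Hypothetical stand-in for the reverb dataset (illustration only)."""

    def __init__(self):
        self.values = [1, 2, 3]

    def __iter__(self):
        # A fresh iterator always sees the current contents.
        return iter(list(self.values))


class RefreshIterator:
    """Re-create the underlying iterator on every next() call."""

    __slots__ = ["_iterable"]

    def __init__(self, iterable):
        self._iterable = iterable

    def __iter__(self):
        return self

    def __next__(self):
        # Build a brand-new iterator each time, so mutations to the
        # source are visible on the very next call.
        return next(iter(self._iterable))


src = MutableSource()
it = RefreshIterator(src)
first = next(it)          # 1
src.values = [42, 2, 3]   # mutate the source, like mutate_priorities
second = next(it)         # 42 -- the refresh picks up the mutation
```

Note that each call reads only the first element of a fresh iterator; for reverb this is fine because, as the issue notes, sampling is driven by the server and re-creating the iterator has no side effects.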
Hi, I accidentally stumbled upon a problem within the tutorial notebook when playing around with the acme and reverb API that causes a weird synchronization behaviour between sampling from the reverb table and updating priorities. Another artifact of this that I encountered is that the very first transition would be consistently repeated until some hidden tensorflow buffer was flushed.

What I found is that when I mutate the priorities in a reverb table using `client.mutate_priorities(table_name, my_dict)` and then create an iterator from the `tf.data.Dataset` object, the priorities update only after flushing a large number of samples. In contrast, if I don't convert the `tf.data.Dataset` to an iterator and instead use the `dataset.batch(n); dataset.take(n)` interface, it syncs with the new priorities immediately.

It seems to me that the problem lies with the implementation of `__iter__` in `tf.data.Dataset`, but I posted this issue here since the Colab makes a call to `as_numpy_iterator()` on the dataset object, and this is also how the D4PG jax agent is implemented. Since this is a silent and obscure bug, it effectively eliminates the possibility of changing the baseline D4PG agent to utilize Prioritized Experience Replay.

Minimal Reproducible example:
Output:
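Because the actual reproduction requires a running reverb server, the symptom can also be illustrated with a pure-Python analogue. `PriorityTable` below is a hypothetical stand-in for the reverb table, not reverb's real API; its `__iter__` snapshots state once at iterator creation, mirroring the long-lived `tf.data` iterator described above:

```python
class PriorityTable:
    """Hypothetical stand-in for a reverb table (illustration only)."""

    def __init__(self, priorities):
        self.priorities = dict(priorities)

    def mutate_priorities(self, updates):
        # Analogue of client.mutate_priorities(table_name, my_dict).
        self.priorities.update(updates)

    def __iter__(self):
        # Like the tf.data iterator in the report: the state is
        # captured once, at iterator-creation time.
        return iter(list(self.priorities.items()))


table = PriorityTable({"a": 1.0})
stale_iter = iter(table)            # long-lived iterator: snapshot taken here
table.mutate_priorities({"a": 9.0})

stale_value = next(stale_iter)      # still the old priority: ("a", 1.0)
fresh_value = next(iter(table))     # a fresh iterator sees ("a", 9.0)
```

The long-lived iterator keeps serving the priorities it captured at creation, while re-creating the iterator per call stays in sync, which is exactly the contrast between `as_numpy_iterator()` and the `batch`/`take` interface reported above.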
Proposed Solution
The problem is immediately solved if `iter(dataset)` is called at each call to `next`. Because of this, I wasn't sure whether to post this issue here or in the tensorflow GitHub, since the problem is with `tf.data.Dataset`. Personally, I would suggest creating a wrapper around `tf.data.Dataset` that either makes use of the `take` and `batch` API, or reinitializes the `iter` at every call. Because of how `reverb` implements sampling, reinitializing the dataset iterator should have no side effects.

Example solution:
Output: (priorities are updated after every call, which is what we expected).