
Converting Tensorflow Dataset to iterator does not sync well with client #293

joeryjoery opened this issue Mar 23, 2023 · 1 comment

Hi, while playing around with the acme and reverb APIs I accidentally stumbled upon a problem in the tutorial notebook that causes strange synchronization behaviour between sampling from the reverb table and updating priorities. A related artifact I encountered is that the very first transition is consistently repeated until some hidden TensorFlow buffer is flushed.

What I found is that when I mutate the priorities in a reverb table using client.mutate_priorities(table_name, my_dict) and then sample through an iterator created from the tf.data.Dataset object, the priorities update only after a large number of samples have been flushed. In contrast, if I don't convert the tf.data.Dataset to an iterator and instead use the dataset.batch(n); dataset.take(n) interface, it syncs with the new priorities immediately.

It seems to me that the problem lies in the implementation of __iter__ in tf.data.Dataset, but I am posting this issue here because the Colab calls as_numpy_iterator() on the dataset object, and the D4PG JAX agent is implemented the same way. Since this is a silent and obscure bug, it effectively rules out adapting the baseline D4PG agent to use Prioritized Experience Replay.
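
For reference, the two access patterns look like this (a minimal sketch; dataset is the reverb dataset constructed in the repro below):

# Pattern 1: eager iteration via take() -- priority mutations sync immediately.
for sample in dataset.take(1):
    ...  # use sample

# Pattern 2: a persistent iterator, as in the Colab and the D4PG JAX agent --
# this is where the stale priorities show up.
it = iter(dataset)  # or dataset.as_numpy_iterator()
sample = next(it)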

Minimal Reproducible example:

import warnings
warnings.filterwarnings('ignore')

import acme

from acme import wrappers
from acme.datasets import reverb as datasets
from acme.adders.reverb import sequence
from acme.jax import utils

import tree
import reverb
import jax

import numpy as np

from dm_control import suite


# Create a dummy environment with short episodes to easily distinguish samples
env = suite.load('cartpole', 'balance')
env = wrappers.step_limit.StepLimitWrapper(env, step_limit=5)
spec = acme.make_environment_spec(env)

# Danger: constructing reverb.Table more than once can crash the kernel
table = reverb.Table(
    name='priority_table',
    sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
    remover=reverb.selectors.Fifo(),
    max_size=10_000,
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=sequence.SequenceAdder.signature(spec)
)

server = reverb.Server([table], port=None)
client = reverb.Client(f'localhost:{server.port}')

# Construct adder such that only 1 sample is added to table after an episode.
adder = sequence.SequenceAdder(client, sequence_length=6, period=5)


def new_dataset():
    # Clear old data
    client.reset(table.name)
    return datasets.make_reverb_dataset(
        table=table.name, server_address=client.server_address, batch_size=3
    )


def fill_dataset():
    step = env.reset()
    adder.add_first(step)

    action = env.action_spec().generate_value()
    i = 0
    while (not step.last()) and i < 10:
        step = env.step(action)
        adder.add(action, step) 
        i += 1   

    env.close()
    adder.reset()
    
    
### Example of expected behaviour
dataset = new_dataset()
fill_dataset()

print('before mutation')
for s in dataset.take(1):
    k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
    
    print(s.data.action.numpy().reshape(3, -1))  # (B, T, 1) -> (B, T)
    print('sample priority:', p)
    
    # Iteratively halve the priorities
    new_priorities = dict(zip(k, p * 0.5))
    client.mutate_priorities(table.name, new_priorities)
    
print()
    
print('after mutation')
for s in dataset.take(1):
    # Priorities have been updated --> all probabilities should now be adjusted.
    
    print(s.data.action.numpy().reshape(3, -1))  # (B, T, 1) -> (B, T)
    print('sample priority:', s.info.priority.numpy())

    
### Test-cases

print('\nUsing dataset.take')
dataset = new_dataset()
fill_dataset()

# This runs fine
for repeat in range(5):
    for i in range(30): # Flush count guess
        for s in dataset.take(1):
            k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()

            # Exponentially decay the priorities
            new_priorities = dict(zip(k, p * 0.999))
            client.mutate_priorities(table.name, new_priorities)

        for s in dataset.take(1):
            new_p = s.info.priority.numpy().ravel()
            assert not np.isclose(new_p, p).any(), "priorities did not update!"
    else:
        # No break in for loop
        print('No errors!')
          
        
print('\nUsing next on iter(dataset) - Problems start here.')
dataset = new_dataset()
fill_dataset()
it = iter(dataset)

# Repeat the test-loop as behaviour strangely changes periodically
for repeat in range(5):
    
    for i in range(30): # Flush count guess
        s = next(it)
        k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()

        # Exponentially decay the priorities
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)

        s = next(it)
        new_p = s.info.priority.numpy().ravel()

        # Priority mutations now sync extremely slowly
        if not np.isclose(p, new_p).all():
            print(f'Priorities updated at flush-step {i}')
            break
    else:
        # No break in for loop : not reached
        print('No errors!')        

Output:

before mutation
[[-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]]
sample priority: [1. 1. 1.]

after mutation
[[-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]]
sample priority: [0.5 0.5 0.5]

Using dataset.take
No errors!
No errors!
No errors!
No errors!
No errors!

Using next on iter(dataset) - Problems start here.
Priorities updated at flush-step 24
Priorities updated at flush-step 5
Priorities updated at flush-step 18
Priorities updated at flush-step 5
Priorities updated at flush-step 18

Proposed Solution

The problem is immediately solved if iter(dataset) is called anew at each call to next. Because of this, I wasn't sure whether to post this issue here or on the TensorFlow GitHub, since the problem is with tf.data.Dataset. Personally I would suggest creating a wrapper around tf.data.Dataset that either makes use of the take and batch API, or reinitializes the iterator at every call. Because of how reverb implements sampling, reinitializing the dataset iterator should have no side effects.

Example solution:

        
print('\nReinitializing iter on every next call - Problem Solved.')
dataset = new_dataset()
fill_dataset()
it = iter(dataset)  # Ignore this iterator

# Repeat the test-loop as behaviour strangely changes periodically
for repeat in range(5):
    
    for i in range(30): # Flush count guess
        s = next(iter(dataset))  # CHANGE: call iter(dataset) every time `next` is called
        k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()

        # Exponentially decay the priorities
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)

        s = next(iter(dataset))  # CHANGE: call iter(dataset) every time `next` is called
        new_p = s.info.priority.numpy().ravel()

        # Priority mutations now sync immediately
        if not np.isclose(p, new_p).all():
            print(f'Priorities updated at flush-step {i}')
            break
    else:
        # No break in for loop : not reached
        print('No errors!')

Output (priorities are updated after every call, which is what we expect):

Reinitializing iter on every next call - Problem Solved.
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0

joeryjoery commented Mar 23, 2023

Another quick fix that I use is to wrap the reverb dataset inside the following class:

class RefreshIterator:
    """tf.data.Dataset fix for slow reverb client synchronization. Wrap around reverb-dataset."""
    
    __slots__ = ["_iterable"]
    
    def __init__(self, iterable):
        self._iterable = iterable
    
    def __iter__(self):
        return self

    def __next__(self):
        return next(iter(self._iterable))
    
    def next(self):
        return self.__next__()

Use:

dataset = datasets.make_reverb_dataset(
    table=my_table.name, server_address=reverb_client.server_address, batch_size=..., ...
)

jax_dataset = utils.multi_device_put(_NumpyIterator(RefreshIterator(dataset)), ...)

Unfortunately, _NumpyIterator is a private class in TensorFlow's dataset_ops module.
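
If you want to avoid depending on that private class, one alternative is to fold the NumPy conversion into the wrapper itself with tree.map_structure. This is a sketch under the assumption that each sample is a nested structure of eager tensors, which is what make_reverb_dataset yields; RefreshNumpyIterator is a hypothetical name, not part of acme or reverb:

class RefreshNumpyIterator:
    """RefreshIterator variant that also converts each sample to NumPy,
    removing the need for the private _NumpyIterator class."""

    __slots__ = ["_dataset"]

    def __init__(self, dataset):
        self._dataset = dataset

    def __iter__(self):
        return self

    def __next__(self):
        # Fresh iterator per call, as in RefreshIterator above.
        sample = next(iter(self._dataset))
        # Convert every eager tensor in the nested sample to a NumPy array.
        return tree.map_structure(lambda t: t.numpy(), sample)

It can then be passed to utils.multi_device_put directly, e.g. jax_dataset = utils.multi_device_put(RefreshNumpyIterator(dataset), ...).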
