Keras-RL2 DQN agent fails to learn on some environments #393

Open
rajfly opened this issue Jul 10, 2023 · 0 comments

rajfly commented Jul 10, 2023

Why I am using Keras-RL2

I am using Keras-RL2 (v1.0.5) because my use case requires the AdamW optimiser, which is currently only actively supported in TensorFlow 2 and is therefore incompatible with Keras-RL, which is based on TensorFlow 1. I understand this repo is not Keras-RL2, but that repo is archived (so I can't post an issue there). This is why I am posting the issue here, so please do not remove it.

Problem

The results from the DQN agent trained with Keras-RL2 show that two Atari environments (Pong and Boxing) completely fail to learn, while one environment (Freeway) learns correctly. However, I tested the same DQN configuration (i.e., the same hyperparameters) across multiple other RL frameworks, such as Stable Baselines 3 and RLlib, and observed that all three environments learned correctly. This leads me to think that there is either a bug in the DQN agent or a mistake in how I configured the DQN agent in Keras-RL2.

Keras-RL2 reward graph for comparison:

[screenshot: Keras-RL2 reward curves]

Stable Baselines 3 reward graph for comparison:

[screenshot: Stable Baselines 3 reward curves]

The DQN algorithm should conform to the one described in the Nature DQN paper. Listed below are the hyperparameters I set to be the same across all RL frameworks; a short illustrative sketch follows each list.

Gymnasium Environment Configuration

  • Max episode frames: 108k frames (default for ALE/*-v5 envs)
  • Mode: Default
  • Difficulty: Default
  • Obs type: Grayscale
  • Frameskip (w/ max pooling): 4
  • Repeat action probability: 0.25
  • Full action space: False
  • Noop reset: 0
  • Terminal on life loss: False
  • Resize: 84 x 84
  • Scale observation: [0,1)
  • Reward clipped: [-1, 1]
  • Frame stack: 4
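
For context, a minimal sketch of how the settings above translate to Gymnasium/ALE (the -v5 env defaults already cover the frameskip with max pooling, the 0.25 repeat action probability, and the 108k-frame limit; frame stacking is left to keras-rl2's window_length, as in the script further below):

import numpy as np
import gym
from gym.wrappers import AtariPreprocessing, TransformReward

env = gym.make("ALE/Pong-v5")    # v5 defaults: frameskip 4 w/ max pooling, sticky actions 0.25, 108k-frame cap
env = AtariPreprocessing(
    env,
    noop_max=0,                  # no no-op resets
    frame_skip=1,                # frameskip already applied by the v5 env
    screen_size=84,              # resize to 84 x 84 (grayscale by default)
    terminal_on_life_loss=False,
    scale_obs=True,              # observations scaled to [0, 1)
)
env = TransformReward(env, lambda r: np.clip(r, -1, 1))  # clip rewards to [-1, 1]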

Network Configuration

  • Layer 1: Conv2D (in: 4, out: 32, kernel: 8, stride: 4, padding: valid, activation: relu, bias: True, dilation: 1, groups: 1)
  • Layer 2: Conv2D (in: 32, out: 64, kernel: 4, stride: 2, padding: valid, activation: relu, bias: True, dilation: 1, groups: 1)
  • Layer 3: Conv2D (in: 64, out: 64, kernel: 3, stride: 1, padding: valid, activation: relu, bias: True, dilation: 1, groups: 1)
  • Layer 4: Flatten
  • Layer 5: LazyLinear(out: 512, activation: relu, bias: True)
  • Layer 6: Linear(in: 512, out: n_action, bias: True)
  • Layer initializers (kernel and bias) were not specified, i.e., the defaults for PyTorch, TensorFlow, and JAX were used.
  • Optimizer: AdamW(lr=1e-4, amsgrad=True (False for JAX since it does not have that option), betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)
  • Loss: Huber (reduction: mean (might be sum for Keras-RL2, none for TF-Agents), delta: 1.0)
  • Target network update type: Hard
  • Gradient clipping: None
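
On the Huber reduction caveat above: whether per-sample losses are averaged or summed scales the loss (and hence the gradient magnitude) by roughly the batch size, so it is worth checking when comparing frameworks. Below is a minimal sketch using tf.keras's built-in Huber loss purely to illustrate the scaling (keras-rl2 implements its own Huber loss internally, so this is not its exact code path):

import tensorflow as tf

y_true = tf.random.uniform((32, 6))  # batch of 32 transitions, 6 actions
y_pred = tf.random.uniform((32, 6))

huber_mean = tf.keras.losses.Huber(delta=1.0, reduction="sum_over_batch_size")
huber_sum = tf.keras.losses.Huber(delta=1.0, reduction="sum")

# The summed loss is roughly batch_size times the averaged one.
print(float(huber_mean(y_true, y_pred)), float(huber_sum(y_true, y_pred)))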

Algorithm Configuration

  • Gamma: 0.99
  • Gradient steps: 1
  • N step: 1
  • Minibatch size: 32
  • Eps start: 1
  • Eps end: 0.01
  • Eps steps: 1e6
  • Replay memory capacity: 100e3
  • Replay memory init capacity: 50e3
  • Total train steps: 5e6
  • Policy update freq: 4
  • Target update freq: (1e3 for Pong, Boxing), (1e4 for Freeway)
  • Eval episodes: 100
  • Eval eps: 0.001

Below is my code for Keras-RL2

import os
import uuid
import argparse
import numpy as np
import gym
from gym.wrappers import AtariPreprocessing, TransformReward, FrameStack

from rl.agents.dqn import DQNAgent
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.callbacks import Callback

import tensorflow as tf
import tensorflow_addons as tfa
from tensorboardX import SummaryWriter

def train_eval(config):
    # set gpu, path, and writer
    os.environ["CUDA_VISIBLE_DEVICES"] = f'{config.gpu}'
    path = os.path.join(os.getcwd(), 'runs', 'krl2', f'train_eval_{config.env}_{uuid.uuid4()}')
    writer = SummaryWriter(path)

    # framestack done internally with window_length
    env = gym.make(f"ALE/{config.env}-v5")
    env = AtariPreprocessing(env, noop_max=0, frame_skip=1, scale_obs=True)
    env = TransformReward(env, lambda x: np.clip(x, -1, 1))
    num_outputs = env.action_space.n

    # loss huber by default with delta_clip=1 in DQNAgent
    # target network update is hard by default if target_model_update > 1
    # gradient clipping not enabled by default
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(32, 8, 4, 'valid', 'channels_first', activation='relu', input_shape=(4,84,84)))
    model.add(tf.keras.layers.Conv2D(64, 4, 2, 'valid', 'channels_first', activation='relu'))
    model.add(tf.keras.layers.Conv2D(64, 3, 1, 'valid', 'channels_first', activation='relu'))
    model.add(tf.keras.layers.Flatten('channels_first'))
    model.add(tf.keras.layers.Dense(512, 'relu'))
    model.add(tf.keras.layers.Dense(num_outputs))
    optim = tfa.optimizers.AdamW(learning_rate=1e-4, epsilon=1e-8, weight_decay=0.01, amsgrad=True)

    # circular replay buffer with random sample
    memory = SequentialMemory(limit=100000, window_length=4, ignore_episode_boundaries=False) 
    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1.0, value_min=0.01, value_test=0.001, nb_steps=1000000+50000)
    test_policy = EpsGreedyQPolicy(0.001)

    # g_step is 1 by default
    # n_step is 1 by default
    dqn = DQNAgent(
        nb_actions=num_outputs, 
        memory=memory, 
        gamma=.99, 
        batch_size=32,
        nb_steps_warmup=50000, 
        train_interval=4, 
        memory_interval=1,
        target_model_update=10000 if config.env == 'Freeway' else 1000, 
        delta_range=None,
        delta_clip=1,
        model=model, 
        policy=policy, 
        test_policy=test_policy,
        enable_double_dqn=False,
        enable_dueling_network=False,
    )

    dqn.compile(optim)
    class TrainCallback(Callback):
        def on_episode_end(self, episode, logs):
            writer.add_scalar('Timestep/reward', logs['episode_reward'], logs['nb_steps'])

    class TestCallback(Callback):
        def __init__(self):
            self.episodes = 0
            self.total_reward = 0
        def on_episode_end(self, episode, logs):
            self.episodes += 1
            self.total_reward += logs['episode_reward']
            if self.episodes == 100:
                score = self.total_reward/self.episodes
                np.save(os.path.join(path, 'score.npy'), [score])

    dqn.fit(
        env, 
        nb_steps=5000000, 
        action_repetition=1,
        callbacks=[TrainCallback()],
        verbose=0,
        visualize=False,
        nb_max_start_steps=0,
        start_step_policy=None,
        nb_max_episode_steps=None,
    )
    writer.close()

    print('train done')

    dqn.test(
        env, 
        nb_episodes=100, 
        action_repetition=1,
        callbacks=[TestCallback()],
        visualize=False,
        nb_max_episode_steps=None,
        nb_max_start_steps=0,
        start_step_policy=None,
        verbose=0,
    )

    print('eval done')

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, help='Specify GPU index', required=True)
    parser.add_argument('--env', type=str, help='Specify gym environment to use', required=True)
    args = parser.parse_args()
    train_eval(args)
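
For reference, each run is launched per environment, e.g. (assuming the script above is saved as train_eval_krl2.py; the filename is arbitrary):

python train_eval_krl2.py --gpu 0 --env Pong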

I would appreciate any thoughts/comments on this matter. Thanks!
