[Algorithm] CrossQ #2033

Status: Open. Wants to merge 11 commits into base: main.
Changes from 8 commits.

12 changes: 12 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -166,6 +166,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
logger.backend=
# logger.record_video=True \
# logger.record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/crossq/crossq.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
57 changes: 57 additions & 0 deletions sota-implementations/crossq/config.yaml
@@ -0,0 +1,57 @@
# environment and task
env:
  name: HalfCheetah-v4
  task: ""
  library: gym
  max_episode_steps: 1000
  seed: 42

# collector
collector:
  total_frames: 1_000_000
  init_random_frames: 25000
  frames_per_batch: 1000
  init_env_steps: 1000
  device: cpu
  env_per_collector: 1
  reset_at_each_iter: False

# replay buffer
replay_buffer:
  size: 1000000
  prb: 0 # use prioritized experience replay
  scratch_dir: null

# optim
optim:
  utd_ratio: 1.0
  policy_update_delay: 3
  gamma: 0.99
  loss_function: l2
  lr: 3.0e-4
  weight_decay: 0.0
  batch_size: 256
  alpha_init: 1.0
  # Adam β1 = 0.5
  adam_eps: 1.0e-8

# network
network:
  batch_norm_momentum: 0.01
  # warmup_steps: 100000 # 10^5
  critic_hidden_sizes: [2048, 2048]
  actor_hidden_sizes: [256, 256]
  critic_activation: tanh
  actor_activation: relu
  default_policy_scale: 1.0
  scale_lb: 0.1
  device: "cuda:0"

# logging
logger:
  backend: wandb
  project_name: torchrl_example_crossQ
  group_name: null
  exp_name: ${env.name}_CrossQ
  mode: online
  eval_iter: 25000
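
Editor's note: the file above is plain OmegaConf/Hydra YAML, so it can be inspected outside the training script. A minimal sketch, assuming only that omegaconf (which ships with Hydra) is installed; the path is the file added by this PR:

from omegaconf import OmegaConf

# Load the example config and check that the ${env.name} interpolation resolves.
cfg = OmegaConf.load("sota-implementations/crossq/config.yaml")
print(cfg.logger.exp_name)                  # HalfCheetah-v4_CrossQ
print(cfg.optim.lr, cfg.optim.batch_size)   # 0.0003 256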
225 changes: 225 additions & 0 deletions sota-implementations/crossq/crossq.py
@@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""CrossQ Example.

This is a simple self-contained example of a CrossQ training script.

It supports state-based environments such as MuJoCo.

The helper functions are defined in the utils.py file next to this script.
"""
import time

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    log_metrics,
    make_collector,
    make_crossQ_agent,
    make_crossQ_optimizer,
    make_environment,
    make_loss_module,
    make_replay_buffer,
)


@hydra.main(version_base="1.1", config_path=".", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
    device = torch.device(cfg.network.device)
Contributor review comment (suggested change): add a fallback when no device is configured:

    device = torch.device(cfg.network.device)
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create logger
    exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="crossq_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create environments
    train_env, eval_env = make_environment(cfg)

    # Create agent
    model, exploration_policy = make_crossQ_agent(cfg, train_env, eval_env, device)

    # Create CrossQ loss
    loss_module = make_loss_module(cfg, model)

    # Create off-policy collector
    collector = make_collector(cfg, train_env, exploration_policy.eval())

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        scratch_dir=cfg.replay_buffer.scratch_dir,
        device="cpu",
    )

    # Create optimizers
    (
        optimizer_actor,
        optimizer_critic,
        optimizer_alpha,
    ) = make_crossQ_optimizer(cfg, loss_module)

    # Main loop
    start_time = time.time()
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(
        cfg.collector.env_per_collector
        * cfg.collector.frames_per_batch
        * cfg.optim.utd_ratio
    )
    prb = cfg.replay_buffer.prb
    eval_iter = cfg.logger.eval_iter
    frames_per_batch = cfg.collector.frames_per_batch
    eval_rollout_steps = cfg.env.max_episode_steps

    sampling_start = time.time()
    update_counter = 0
    delayed_updates = cfg.optim.policy_update_delay
    for _, tensordict in enumerate(collector):
        sampling_time = time.time() - sampling_start

        # Update weights of the inference policy
        collector.update_policy_weights_()

        pbar.update(tensordict.numel())

        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        collected_frames += current_frames

        # Optimization steps
        training_start = time.time()
        if collected_frames >= init_random_frames:
            (
                actor_losses,
                alpha_losses,
                q_losses,
            ) = ([], [], [])
            for _ in range(num_updates):

                # Update actor every delayed_updates
                update_counter += 1
                update_actor = update_counter % delayed_updates == 0
                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()
                if sampled_tensordict.device != device:
                    sampled_tensordict = sampled_tensordict.to(
                        device, non_blocking=True
                    )
                else:
                    sampled_tensordict = sampled_tensordict.clone()

                # Compute loss
                q_loss, *_ = loss_module._qvalue_loss(sampled_tensordict)
Contributor review comment: we should not use private attributes in examples. Let's make qvalue_loss a public method if that is needed.
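For illustration, one shape a public-API version of the update could take. This is only a sketch: it reuses the script's variable names and assumes the CrossQ loss follows the SACLoss convention of returning mean-reduced "loss_qvalue", "loss_actor" and "loss_alpha" entries from its forward() call, which would need to be confirmed against the final loss implementation.

# Sketch only (assumes SACLoss-style "loss_*" output keys on the CrossQ loss):
# a single forward() call replaces the private _qvalue_loss/_actor_loss/_alpha_loss helpers.
loss_td = loss_module(sampled_tensordict)

optimizer_critic.zero_grad()
loss_td["loss_qvalue"].backward(retain_graph=True)
optimizer_critic.step()

if update_actor:
    optimizer_actor.zero_grad()
    loss_td["loss_actor"].backward(retain_graph=True)
    optimizer_actor.step()

    optimizer_alpha.zero_grad()
    loss_td["loss_alpha"].backward()
    optimizer_alpha.step()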

                q_loss = q_loss.mean()
                # Update critic
                optimizer_critic.zero_grad()
                q_loss.backward()
                optimizer_critic.step()
                q_losses.append(q_loss.detach().item())

                if update_actor:
                    actor_loss, metadata_actor = loss_module._actor_loss(
                        sampled_tensordict
                    )
                    actor_loss = actor_loss.mean()
                    alpha_loss = loss_module._alpha_loss(
Contributor review comment: ditto (the same concern about private methods applies here).

                        log_prob=metadata_actor["log_prob"]
Contributor review comment: the fact that the example requires that much knowledge about the way the loss works is a bit worrying; the script should be immediate. Is there a version of this where alpha_loss just takes the metadata dict?
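To make that concrete, the wrapper the comment seems to ask for could look like the hypothetical sketch below; the method name and signature are illustrative only and not an existing TorchRL API.

# Hypothetical public method on the loss module (illustrative only):
# it accepts the metadata dict produced by the actor-loss step and extracts
# what it needs, so the training script never touches loss internals.
def alpha_loss(self, metadata: dict) -> torch.Tensor:
    return self._alpha_loss(log_prob=metadata["log_prob"])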

                    )
                    alpha_loss = alpha_loss.mean()
                    # Update actor
                    optimizer_actor.zero_grad()
                    actor_loss.backward()
                    optimizer_actor.step()

                    # Update alpha
                    optimizer_alpha.zero_grad()
                    alpha_loss.backward()
                    optimizer_alpha.step()

                    actor_losses.append(actor_loss.detach().item())
                    alpha_losses.append(alpha_loss.detach().item())

                # Update priority
                if prb:
                    replay_buffer.update_priority(sampled_tensordict)

        training_time = time.time() - training_start
        episode_end = (
            tensordict["next", "done"]
            if tensordict["next", "done"].any()
            else tensordict["next", "truncated"]
        )
        episode_rewards = tensordict["next", "episode_reward"][episode_end]

        # Logging
        metrics_to_log = {}
        if len(episode_rewards) > 0:
            episode_length = tensordict["next", "step_count"][episode_end]
            metrics_to_log["train/reward"] = episode_rewards.mean().item()
            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                episode_length
            )
        if collected_frames >= init_random_frames:
            metrics_to_log["train/q_loss"] = np.mean(q_losses).item()
            metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item()
            metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item()
            metrics_to_log["train/sampling_time"] = sampling_time
            metrics_to_log["train/training_time"] = training_time

        # Evaluation
        if abs(collected_frames % eval_iter) < frames_per_batch:
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_start = time.time()
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model[0],
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_time = time.time() - eval_start
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                metrics_to_log["eval/reward"] = eval_reward
                metrics_to_log["eval/time"] = eval_time
        if logger is not None:
            log_metrics(logger, metrics_to_log, collected_frames)
        sampling_start = time.time()

    collector.shutdown()
    end_time = time.time()
    execution_time = end_time - start_time
    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
    main()