-
Notifications
You must be signed in to change notification settings - Fork 264
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Algorithm] CrossQ #2033
base: main
Are you sure you want to change the base?
[Algorithm] CrossQ #2033
Changes from all commits
0a23ae8
9bdee71
5086249
c3a927f
d1c9c34
e879b7c
75255e7
a7b79c3
be84f3f
2170ad8
75d4cee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# environment and task | ||
env: | ||
name: HalfCheetah-v4 | ||
task: "" | ||
library: gym | ||
max_episode_steps: 1000 | ||
seed: 42 | ||
|
||
# collector | ||
collector: | ||
total_frames: 1_000_000 | ||
init_random_frames: 25000 | ||
frames_per_batch: 1000 | ||
init_env_steps: 1000 | ||
device: cpu | ||
env_per_collector: 1 | ||
reset_at_each_iter: False | ||
|
||
# replay buffer | ||
replay_buffer: | ||
size: 1000000 | ||
prb: 0 # use prioritized experience replay | ||
scratch_dir: null | ||
|
||
# optim | ||
optim: | ||
utd_ratio: 1.0 | ||
policy_update_delay: 3 | ||
gamma: 0.99 | ||
loss_function: l2 | ||
lr: 3.0e-4 | ||
weight_decay: 0.0 | ||
batch_size: 256 | ||
alpha_init: 1.0 | ||
# Adam β1 = 0.5 | ||
adam_eps: 1.0e-8 | ||
|
||
# network | ||
network: | ||
batch_norm_momentum: 0.01 | ||
# warmup_steps: 100000 # 10^5 | ||
critic_hidden_sizes: [2048, 2048] | ||
actor_hidden_sizes: [256, 256] | ||
critic_activation: tanh | ||
actor_activation: relu | ||
default_policy_scale: 1.0 | ||
scale_lb: 0.1 | ||
device: "cuda:0" | ||
|
||
# logging | ||
logger: | ||
backend: wandb | ||
project_name: torchrl_example_crossQ | ||
group_name: null | ||
exp_name: ${env.name}_CrossQ | ||
mode: online | ||
eval_iter: 25000 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
"""CrossQ Example. | ||
|
||
This is a simple self-contained example of a CrossQ training script. | ||
|
||
It supports state environments like MuJoCo. | ||
|
||
The helper functions are coded in the utils.py associated with this script. | ||
""" | ||
import time | ||
|
||
import hydra | ||
|
||
import numpy as np | ||
import torch | ||
import torch.cuda | ||
import tqdm | ||
from torchrl._utils import logger as torchrl_logger | ||
from torchrl.envs.utils import ExplorationType, set_exploration_type | ||
|
||
from torchrl.record.loggers import generate_exp_name, get_logger | ||
from utils import ( | ||
log_metrics, | ||
make_collector, | ||
make_crossQ_agent, | ||
make_crossQ_optimizer, | ||
make_environment, | ||
make_loss_module, | ||
make_replay_buffer, | ||
) | ||
|
||
|
||
@hydra.main(version_base="1.1", config_path=".", config_name="config") | ||
def main(cfg: "DictConfig"): # noqa: F821 | ||
device = torch.device(cfg.network.device) | ||
|
||
# Create logger | ||
exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name) | ||
logger = None | ||
if cfg.logger.backend: | ||
logger = get_logger( | ||
logger_type=cfg.logger.backend, | ||
logger_name="crossq_logging", | ||
experiment_name=exp_name, | ||
wandb_kwargs={ | ||
"mode": cfg.logger.mode, | ||
"config": dict(cfg), | ||
"project": cfg.logger.project_name, | ||
"group": cfg.logger.group_name, | ||
}, | ||
) | ||
|
||
torch.manual_seed(cfg.env.seed) | ||
np.random.seed(cfg.env.seed) | ||
|
||
# Create environments | ||
train_env, eval_env = make_environment(cfg) | ||
|
||
# Create agent | ||
model, exploration_policy = make_crossQ_agent(cfg, train_env, eval_env, device) | ||
|
||
# Create CrossQ loss | ||
loss_module = make_loss_module(cfg, model) | ||
|
||
# Create off-policy collector | ||
collector = make_collector(cfg, train_env, exploration_policy.eval()) | ||
|
||
# Create replay buffer | ||
replay_buffer = make_replay_buffer( | ||
batch_size=cfg.optim.batch_size, | ||
prb=cfg.replay_buffer.prb, | ||
buffer_size=cfg.replay_buffer.size, | ||
scratch_dir=cfg.replay_buffer.scratch_dir, | ||
device="cpu", | ||
) | ||
|
||
# Create optimizers | ||
( | ||
optimizer_actor, | ||
optimizer_critic, | ||
optimizer_alpha, | ||
) = make_crossQ_optimizer(cfg, loss_module) | ||
|
||
# Main loop | ||
start_time = time.time() | ||
collected_frames = 0 | ||
pbar = tqdm.tqdm(total=cfg.collector.total_frames) | ||
|
||
init_random_frames = cfg.collector.init_random_frames | ||
num_updates = int( | ||
cfg.collector.env_per_collector | ||
* cfg.collector.frames_per_batch | ||
* cfg.optim.utd_ratio | ||
) | ||
prb = cfg.replay_buffer.prb | ||
eval_iter = cfg.logger.eval_iter | ||
frames_per_batch = cfg.collector.frames_per_batch | ||
eval_rollout_steps = cfg.env.max_episode_steps | ||
|
||
sampling_start = time.time() | ||
update_counter = 0 | ||
delayed_updates = cfg.optim.policy_update_delay | ||
for _, tensordict in enumerate(collector): | ||
sampling_time = time.time() - sampling_start | ||
|
||
# Update weights of the inference policy | ||
collector.update_policy_weights_() | ||
|
||
pbar.update(tensordict.numel()) | ||
|
||
tensordict = tensordict.reshape(-1) | ||
current_frames = tensordict.numel() | ||
# Add to replay buffer | ||
replay_buffer.extend(tensordict.cpu()) | ||
collected_frames += current_frames | ||
|
||
# Optimization steps | ||
training_start = time.time() | ||
if collected_frames >= init_random_frames: | ||
( | ||
actor_losses, | ||
alpha_losses, | ||
q_losses, | ||
) = ([], [], []) | ||
for _ in range(num_updates): | ||
|
||
# Update actor every delayed_updates | ||
update_counter += 1 | ||
update_actor = update_counter % delayed_updates == 0 | ||
# Sample from replay buffer | ||
sampled_tensordict = replay_buffer.sample() | ||
if sampled_tensordict.device != device: | ||
sampled_tensordict = sampled_tensordict.to( | ||
device, non_blocking=True | ||
) | ||
else: | ||
sampled_tensordict = sampled_tensordict.clone() | ||
|
||
# Compute loss | ||
q_loss, *_ = loss_module._qvalue_loss(sampled_tensordict) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should not use private attributes in examples. Let's make qvalue_loss a public method if that is needed |
||
q_loss = q_loss.mean() | ||
# Update critic | ||
optimizer_critic.zero_grad() | ||
q_loss.backward() | ||
optimizer_critic.step() | ||
q_losses.append(q_loss.detach().item()) | ||
|
||
if update_actor: | ||
actor_loss, metadata_actor = loss_module._actor_loss( | ||
sampled_tensordict | ||
) | ||
actor_loss = actor_loss.mean() | ||
alpha_loss = loss_module._alpha_loss( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
log_prob=metadata_actor["log_prob"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the fact that the example requires that much knowledge about the way the loss works is a bit worrying - the script should be immediate. |
||
) | ||
alpha_loss = alpha_loss.mean() | ||
# Update actor | ||
optimizer_actor.zero_grad() | ||
actor_loss.backward() | ||
optimizer_actor.step() | ||
|
||
# Update alpha | ||
optimizer_alpha.zero_grad() | ||
alpha_loss.backward() | ||
optimizer_alpha.step() | ||
|
||
actor_losses.append(actor_loss.detach().item()) | ||
alpha_losses.append(alpha_loss.detach().item()) | ||
|
||
# Update priority | ||
if prb: | ||
replay_buffer.update_priority(sampled_tensordict) | ||
|
||
training_time = time.time() - training_start | ||
episode_end = ( | ||
tensordict["next", "done"] | ||
if tensordict["next", "done"].any() | ||
else tensordict["next", "truncated"] | ||
) | ||
episode_rewards = tensordict["next", "episode_reward"][episode_end] | ||
|
||
# Logging | ||
metrics_to_log = {} | ||
if len(episode_rewards) > 0: | ||
episode_length = tensordict["next", "step_count"][episode_end] | ||
metrics_to_log["train/reward"] = episode_rewards.mean().item() | ||
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( | ||
episode_length | ||
) | ||
if collected_frames >= init_random_frames: | ||
metrics_to_log["train/q_loss"] = np.mean(q_losses).item() | ||
metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item() | ||
metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item() | ||
metrics_to_log["train/sampling_time"] = sampling_time | ||
metrics_to_log["train/training_time"] = training_time | ||
|
||
# Evaluation | ||
if abs(collected_frames % eval_iter) < frames_per_batch: | ||
with set_exploration_type(ExplorationType.MODE), torch.no_grad(): | ||
eval_start = time.time() | ||
eval_rollout = eval_env.rollout( | ||
eval_rollout_steps, | ||
model[0], | ||
auto_cast_to_device=True, | ||
break_when_any_done=True, | ||
) | ||
eval_time = time.time() - eval_start | ||
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() | ||
metrics_to_log["eval/reward"] = eval_reward | ||
metrics_to_log["eval/time"] = eval_time | ||
if logger is not None: | ||
log_metrics(logger, metrics_to_log, collected_frames) | ||
sampling_start = time.time() | ||
|
||
collector.shutdown() | ||
end_time = time.time() | ||
execution_time = end_time - start_time | ||
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.