Merge pull request #280 from kaiks/checkpoints
Add storing and restoring RL agent checkpoints
daochenzha committed Apr 19, 2023
2 parents 1066e7d + 9ec4e9a commit f4ae4fc
Showing 3 changed files with 306 additions and 17 deletions.
48 changes: 35 additions & 13 deletions examples/run_rl.py
@@ -35,21 +35,32 @@ def train(args):
     # Initialize the agent and use random agents as opponents
     if args.algorithm == 'dqn':
         from rlcard.agents import DQNAgent
-        agent = DQNAgent(
-            num_actions=env.num_actions,
-            state_shape=env.state_shape[0],
-            mlp_layers=[64,64],
-            device=device,
-        )
+        if args.load_checkpoint_path != "":
+            agent = DQNAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
+        else:
+            agent = DQNAgent(
+                num_actions=env.num_actions,
+                state_shape=env.state_shape[0],
+                mlp_layers=[64,64],
+                device=device,
+                save_path=args.log_dir,
+                save_every=args.save_every
+            )
+
     elif args.algorithm == 'nfsp':
         from rlcard.agents import NFSPAgent
-        agent = NFSPAgent(
-            num_actions=env.num_actions,
-            state_shape=env.state_shape[0],
-            hidden_layers_sizes=[64,64],
-            q_mlp_layers=[64,64],
-            device=device,
-        )
+        if args.load_checkpoint_path != "":
+            agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
+        else:
+            agent = NFSPAgent(
+                num_actions=env.num_actions,
+                state_shape=env.state_shape[0],
+                hidden_layers_sizes=[64,64],
+                q_mlp_layers=[64,64],
+                device=device,
+                save_path=args.log_dir,
+                save_every=args.save_every
+            )
     agents = [agent]
     for _ in range(1, env.num_players):
         agents.append(RandomAgent(num_actions=env.num_actions))
@@ -152,6 +163,17 @@ def train(args):
         type=str,
         default='experiments/leduc_holdem_dqn_result/',
     )
+
+    parser.add_argument(
+        "--load_checkpoint_path",
+        type=str,
+        default="",
+    )
+
+    parser.add_argument(
+        "--save_every",
+        type=int,
+        default=-1)
 
     args = parser.parse_args()
 
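For orientation, a minimal sketch of the round trip these new arguments enable, using only the API introduced in this pull request (the environment id and file locations are illustrative assumptions, not taken from the diff):

    import torch
    import rlcard
    from rlcard.agents import DQNAgent

    env = rlcard.make('leduc-holdem')  # assumed environment id, for illustration only

    # Training: the agent writes a checkpoint to save_path every save_every training steps
    agent = DQNAgent(
        num_actions=env.num_actions,
        state_shape=env.state_shape[0],
        mlp_layers=[64, 64],
        save_path='experiments/leduc_holdem_dqn_result/',
        save_every=100,
    )

    # Resuming: rebuild the agent (network weights, optimizer state, replay memory, counters)
    checkpoint = torch.load('experiments/leduc_holdem_dqn_result/checkpoint_dqn.pt')
    agent = DQNAgent.from_checkpoint(checkpoint=checkpoint)

Through examples/run_rl.py, the same two paths are driven by --save_every (with --log_dir as the save location) and --load_checkpoint_path.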
139 changes: 138 additions & 1 deletion rlcard/agents/dqn_agent.py
@@ -56,7 +56,9 @@ def __init__(self,
                  train_every=1,
                  mlp_layers=None,
                  learning_rate=0.00005,
-                 device=None):
+                 device=None,
+                 save_path=None,
+                 save_every=float('inf'),):
 
         '''
         Q-Learning algorithm for off-policy TD control using Function Approximation.
@@ -81,6 +83,8 @@ def __init__(self,
             mlp_layers (list): The layer number and the dimension of each layer in MLP
             learning_rate (float): The learning rate of the DQN agent.
             device (torch.device): whether to use the cpu or gpu
+            save_path (str): The path to save the model checkpoints
+            save_every (int): Save the model every X training steps
         '''
         self.use_raw = False
         self.replay_memory_init_size = replay_memory_init_size
@@ -114,6 +118,10 @@ def __init__(self,
 
         # Create replay memory
         self.memory = Memory(replay_memory_size, batch_size)
+
+        # Checkpoint saving parameters
+        self.save_path = save_path
+        self.save_every = save_every
 
     def feed(self, ts):
         ''' Store data in to replay buffer and train the agent. There are two stages.
@@ -221,6 +229,13 @@ def train(self):
 
         self.train_t += 1
 
+        if self.save_path and self.train_t % self.save_every == 0:
+            # To preserve every checkpoint separately,
+            # add another argument to the function call parameterized by self.train_t
+            self.save_checkpoint(self.save_path)
+            print("\nINFO - Saved model checkpoint.")
+
+
     def feed_memory(self, state, action, reward, next_state, legal_actions, done):
         ''' Feed transition to memory
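Not part of this diff: as the comment in train() above suggests, every checkpoint could be kept separately by parameterizing the save call with self.train_t. One possible sketch, with an assumed filename pattern:

    # hypothetical variant of the hook above, writing one file per save point
    if self.save_path and self.train_t % self.save_every == 0:
        self.save_checkpoint(self.save_path, filename='checkpoint_dqn_{}.pt'.format(self.train_t))
        print("\nINFO - Saved model checkpoint at train step {}.".format(self.train_t))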
@@ -239,6 +254,73 @@ def set_device(self, device):
         self.q_estimator.device = device
         self.target_estimator.device = device
 
+    def checkpoint_attributes(self):
+        '''
+        Return the current checkpoint attributes (dict)
+        Checkpoint attributes are used to save and restore the model in the middle of training
+        Saves the model state dict, optimizer state dict, and all other instance variables
+        '''
+
+        return {
+            'agent_type': 'DQNAgent',
+            'q_estimator': self.q_estimator.checkpoint_attributes(),
+            'memory': self.memory.checkpoint_attributes(),
+            'total_t': self.total_t,
+            'train_t': self.train_t,
+            'epsilon_start': self.epsilons.min(),
+            'epsilon_end': self.epsilons.max(),
+            'epsilon_decay_steps': self.epsilon_decay_steps,
+            'discount_factor': self.discount_factor,
+            'update_target_estimator_every': self.update_target_estimator_every,
+            'batch_size': self.batch_size,
+            'num_actions': self.num_actions,
+            'train_every': self.train_every,
+            'device': self.device
+        }
+
+    @classmethod
+    def from_checkpoint(cls, checkpoint):
+        '''
+        Restore the model from a checkpoint
+
+        Args:
+            checkpoint (dict): the checkpoint attributes generated by checkpoint_attributes()
+        '''
+
+        print("\nINFO - Restoring model from checkpoint...")
+        agent_instance = cls(
+            replay_memory_size=checkpoint['memory']['memory_size'],
+            update_target_estimator_every=checkpoint['update_target_estimator_every'],
+            discount_factor=checkpoint['discount_factor'],
+            epsilon_start=checkpoint['epsilon_start'],
+            epsilon_end=checkpoint['epsilon_end'],
+            epsilon_decay_steps=checkpoint['epsilon_decay_steps'],
+            batch_size=checkpoint['batch_size'],
+            num_actions=checkpoint['num_actions'],
+            device=checkpoint['device'],
+            state_shape=checkpoint['q_estimator']['state_shape'],
+            mlp_layers=checkpoint['q_estimator']['mlp_layers'],
+            train_every=checkpoint['train_every']
+        )
+
+        agent_instance.total_t = checkpoint['total_t']
+        agent_instance.train_t = checkpoint['train_t']
+
+        agent_instance.q_estimator = Estimator.from_checkpoint(checkpoint['q_estimator'])
+        agent_instance.target_estimator = deepcopy(agent_instance.q_estimator)
+        agent_instance.memory = Memory.from_checkpoint(checkpoint['memory'])
+
+
+        return agent_instance
+
+    def save_checkpoint(self, path, filename='checkpoint_dqn.pt'):
+        ''' Save the model checkpoint (all attributes)
+
+        Args:
+            path (str): the path to save the model
+        '''
+        torch.save(self.checkpoint_attributes(), path + '/' + filename)
+
 class Estimator(object):
     '''
     Approximate clone of rlcard.agents.dqn_agent.Estimator that
@@ -334,6 +416,35 @@ def update(self, s, a, y):
         self.qnet.eval()
 
         return batch_loss
+
+    def checkpoint_attributes(self):
+        ''' Return the attributes needed to restore the model from a checkpoint
+        '''
+        return {
+            'qnet': self.qnet.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'num_actions': self.num_actions,
+            'learning_rate': self.learning_rate,
+            'state_shape': self.state_shape,
+            'mlp_layers': self.mlp_layers,
+            'device': self.device
+        }
+
+    @classmethod
+    def from_checkpoint(cls, checkpoint):
+        ''' Restore the model from a checkpoint
+        '''
+        estimator = cls(
+            num_actions=checkpoint['num_actions'],
+            learning_rate=checkpoint['learning_rate'],
+            state_shape=checkpoint['state_shape'],
+            mlp_layers=checkpoint['mlp_layers'],
+            device=checkpoint['device']
+        )
+
+        estimator.qnet.load_state_dict(checkpoint['qnet'])
+        estimator.optimizer.load_state_dict(checkpoint['optimizer'])
+        return estimator
 
 
 class EstimatorNetwork(nn.Module):
@@ -415,3 +526,29 @@ def sample(self):
         samples = random.sample(self.memory, self.batch_size)
         samples = tuple(zip(*samples))
         return tuple(map(np.array, samples[:-1])) + (samples[-1],)
+
+    def checkpoint_attributes(self):
+        ''' Returns the attributes that need to be checkpointed
+        '''
+
+        return {
+            'memory_size': self.memory_size,
+            'batch_size': self.batch_size,
+            'memory': self.memory
+        }
+
+    @classmethod
+    def from_checkpoint(cls, checkpoint):
+        '''
+        Restores the attributes from the checkpoint
+
+        Args:
+            checkpoint (dict): the checkpoint dictionary
+
+        Returns:
+            instance (Memory): the restored instance
+        '''
+
+        instance = cls(checkpoint['memory_size'], checkpoint['batch_size'])
+        instance.memory = checkpoint['memory']
+        return instance
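Since the checkpoint is a nested plain dictionary written with torch.save, it can also be inspected directly. A small sketch using only keys defined by the checkpoint_attributes() methods above (the file path is illustrative):

    import torch

    checkpoint = torch.load('experiments/leduc_holdem_dqn_result/checkpoint_dqn.pt')
    print(checkpoint['agent_type'])             # 'DQNAgent'
    print(checkpoint['total_t'], checkpoint['train_t'])  # counters restored by DQNAgent.from_checkpoint
    print(checkpoint['q_estimator'].keys())     # nested Estimator checkpoint: qnet/optimizer state dicts plus hyperparameters
    print(checkpoint['memory']['batch_size'])   # nested Memory checkpoint

Each nested dictionary can also be restored on its own through Estimator.from_checkpoint or Memory.from_checkpoint, which is exactly what DQNAgent.from_checkpoint does above.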
