-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo_vpg.py
41 lines (36 loc) · 1003 Bytes
/
demo_vpg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from algos.vpg import vpg, PolicyNet
import gym
import matplotlib.pyplot as plt
# Train a vanilla policy gradient agent on CartPole-v0, plot the learning
# curve, then render one greedy rollout with the trained policy.
env = gym.make('CartPole-v0')

# Run VPG: 100 iterations, 5 trajectories per iteration, episodes capped at
# 100 steps, undiscounted returns (gamma=1.0).
agent, mean_return_list = vpg(env, num_iter=100, max_num_steps=100, gamma=1.0,
                              num_traj=5)

# Save the learning curve (mean return per iteration) to disk.
plt.plot(mean_return_list)
plt.xlabel('Iteration')
plt.ylabel('Mean Return')
plt.savefig('vpg_returns.png', format='png', dpi=300)

# Visualize a single episode with the trained agent (at most 1000 steps).
obs = env.reset()
for _step in range(1000):
    chosen_action = agent.act(obs)
    env.render()
    obs, reward, done, _ = env.step(chosen_action)
    if done:
        break
env.close()
# # Load saved model from file instead
# import torch
# input_size = env.observation_space.shape[0]
# output_size = env.action_space.n
# agent = PolicyNet(input_size, output_size)
# agent.load_state_dict(torch.load('vpg_policy.pt'))
# agent.eval()
#
#
# state = env.reset()
# for t in range(1000):
# action = agent.act(state)
# print(action)
# env.render()
# state, reward, done, _ = env.step(action)
# if done:
# break
# env.close()