diff --git a/2-cartpole/1-dqn.py b/2-cartpole/1-dqn.py index e7c0eaa..ddbb219 100644 --- a/2-cartpole/1-dqn.py +++ b/2-cartpole/1-dqn.py @@ -18,7 +18,6 @@ L(theta) = ( Q_theta(s)[a] - y )^2 """ import random -import sys from collections import deque import numpy as np @@ -135,8 +134,11 @@ def train_model(self): run_test_loop(env, agent.get_action) scores = [] + solved = False for e in range(EPISODES): + if solved: + break done = False score = 0 state, _ = env.reset() @@ -168,9 +170,8 @@ def train_model(self): # Early stop when consistently near max episode length. if np.mean(scores[-min(10, len(scores)):]) > 490: - torch.save(agent.model.state_dict(), SAVE_PATH) - print(f"Saved trained model to {SAVE_PATH}") - sys.exit() + solved = True + break torch.save(agent.model.state_dict(), SAVE_PATH) print(f"Saved trained model to {SAVE_PATH}") diff --git a/2-cartpole/2-a2c.py b/2-cartpole/2-a2c.py index 1a92b9b..a7d145a 100644 --- a/2-cartpole/2-a2c.py +++ b/2-cartpole/2-a2c.py @@ -19,7 +19,6 @@ Subtracting V_w(s) is the variance-reduction baseline; using a learned V (rather than the Monte-Carlo return) is what makes this *actor-critic*. """ -import sys import numpy as np import torch @@ -125,8 +124,11 @@ def train_model(self, state, action, reward, next_state, done): run_test_loop(env, agent.get_action) scores = [] + solved = False for e in range(EPISODES): + if solved: + break done = False score = 0 state, _ = env.reset() @@ -149,10 +151,8 @@ def train_model(self, state, action, reward, next_state, done): scores.append(score) print(f"episode: {e} score: {score}") if np.mean(scores[-min(10, len(scores)):]) > 490: - torch.save({"actor": agent.actor.state_dict(), - "critic": agent.critic.state_dict()}, SAVE_PATH) - print(f"Saved trained model to {SAVE_PATH}") - sys.exit() + solved = True + break torch.save({"actor": agent.actor.state_dict(), "critic": agent.critic.state_dict()}, SAVE_PATH) diff --git a/2-cartpole/3-ppo.py b/2-cartpole/3-ppo.py index 6850d72..174d266 100644 --- a/2-cartpole/3-ppo.py +++ b/2-cartpole/3-ppo.py @@ -27,7 +27,6 @@ L = L^CLIP - c_v * MSE(V, returns) + c_e * H[pi] """ -import sys import numpy as np import torch @@ -211,9 +210,7 @@ def pick(state): recent = ep_returns[-10:] print(f"update: {episode} recent_mean_return: {np.mean(recent):.1f} episodes: {len(ep_returns)}") if len(recent) >= 10 and np.mean(recent) > 490: - torch.save(model.state_dict(), SAVE_PATH) - print(f"Saved trained model to {SAVE_PATH}") - sys.exit() + break torch.save(model.state_dict(), SAVE_PATH) print(f"Saved trained model to {SAVE_PATH}")