Merge pull request #2688 from araffin/upgrade-sb

Add CheckpointCallback and load best automatically
This commit is contained in:
erwincoumans 2020-03-22 14:04:00 -07:00 committed by GitHub
commit 4b41685d0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 67 deletions

View File

@ -4,17 +4,16 @@
# You can run it using: python -m pybullet_envs.stable_baselines.enjoy --algo td3 --env HalfCheetahBulletEnv-v0
# Author: Antonin RAFFIN
# MIT License
import os
import time
import argparse
import multiprocessing
import time
import gym
import numpy as np
import pybullet_envs
from stable_baselines import SAC, TD3
from stable_baselines.common.evaluation import evaluate_policy
from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper
@ -26,7 +25,9 @@ if __name__ == '__main__':
parser.add_argument('-n', '--n-episodes', help='Number of episodes', default=5,
type=int)
parser.add_argument('--no-render', action='store_true', default=False,
help='Do not render the environment')
help='Do not render the environment')
parser.add_argument('--load-best', action='store_true', default=False,
help='Load best model instead of last model if available')
args = parser.parse_args()
env_id = args.env
@ -44,6 +45,13 @@ if __name__ == '__main__':
# We assume that the saved model is in the same folder
save_path = '{}_{}.zip'.format(args.algo, env_id)
if not os.path.isfile(save_path) or args.load_best:
print("Loading best model")
# Try to load best model
save_path = os.path.join('{}_{}'.format(args.algo, env_id), 'best_model.zip')
# Load the saved model
model = algo.load(save_path, env=env)
@ -63,7 +71,7 @@ if __name__ == '__main__':
episode_length += 1
if not args.no_render:
env.render(mode='human')
dt = 1./240.
dt = 1. / 240.
time.sleep(dt)
episode_rewards.append(episode_reward)
episode_lengths.append(episode_length)

View File

@ -12,8 +12,10 @@ import gym
import numpy as np
from stable_baselines import SAC, TD3
from stable_baselines.common.noise import NormalActionNoise
from stable_baselines.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines.common.vec_env import DummyVecEnv
from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper, EvalCallback
from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper
if __name__ == '__main__':
@ -23,6 +25,8 @@ if __name__ == '__main__':
parser.add_argument('--env', type=str, default='HalfCheetahBulletEnv-v0', help='environment ID')
parser.add_argument('-n', '--n-timesteps', help='Number of training timesteps', default=int(1e6),
type=int)
parser.add_argument('--save-freq', help='Save the model every n steps (if negative, no checkpoint)',
default=-1, type=int)
args = parser.parse_args()
env_id = args.env
@ -32,9 +36,15 @@ if __name__ == '__main__':
# Instantiate and wrap the environment
env = TimeFeatureWrapper(gym.make(env_id))
# Create the evaluation environment and callback
eval_env = TimeFeatureWrapper(gym.make(env_id))
callback = EvalCallback(eval_env, best_model_save_path=save_path + '_best')
# Create the evaluation environment and callbacks
eval_env = DummyVecEnv([lambda: TimeFeatureWrapper(gym.make(env_id))])
callbacks = [EvalCallback(eval_env, best_model_save_path=save_path)]
# Save a checkpoint every n steps
if args.save_freq > 0:
callbacks.append(CheckpointCallback(save_freq=args.save_freq, save_path=save_path,
name_prefix='rl_model'))
algo = {
'sac': SAC,
@ -55,7 +65,10 @@ if __name__ == '__main__':
}[args.algo]
model = algo('MlpPolicy', env, verbose=1, **hyperparams)
model.learn(n_timesteps, callback=callback)
try:
model.learn(n_timesteps, callback=callbacks)
except KeyboardInterrupt:
pass
print("Saving to {}.zip".format(save_path))
model.save(save_path)

View File

@ -6,8 +6,6 @@ import gym
import numpy as np
from gym.wrappers import TimeLimit
from stable_baselines.common.evaluation import evaluate_policy
class TimeFeatureWrapper(gym.Wrapper):
"""
@ -25,7 +23,7 @@ class TimeFeatureWrapper(gym.Wrapper):
assert isinstance(env.observation_space, gym.spaces.Box)
# Add a time feature to the observation
low, high = env.observation_space.low, env.observation_space.high
low, high= np.concatenate((low, [0])), np.concatenate((high, [1.]))
low, high = np.concatenate((low, [0])), np.concatenate((high, [1.]))
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)
super(TimeFeatureWrapper, self).__init__(env)
@ -59,56 +57,3 @@ class TimeFeatureWrapper(gym.Wrapper):
time_feature = 1.0
# Optionnaly: concatenate [time_feature, time_feature ** 2]
return np.concatenate((obs, [time_feature]))
class EvalCallback(object):
"""
Callback for evaluating an agent.
:param eval_env: (gym.Env) The environment used for initialization
:param n_eval_episodes: (int) The number of episodes to test the agent
:param eval_freq: (int) Evaluate the agent every eval_freq call of the callback.
:param deterministic: (bool)
:param best_model_save_path: (str)
:param verbose: (int)
"""
def __init__(self, eval_env, n_eval_episodes=5, eval_freq=10000,
deterministic=True, best_model_save_path=None, verbose=1):
super(EvalCallback, self).__init__()
self.n_eval_episodes = n_eval_episodes
self.eval_freq = eval_freq
self.best_mean_reward = -np.inf
self.deterministic = deterministic
self.eval_env = eval_env
self.verbose = verbose
self.model, self.num_timesteps = None, 0
self.best_model_save_path = best_model_save_path
self.n_calls = 0
def __call__(self, locals_, globals_):
"""
:param locals_: (dict)
:param globals_: (dict)
:return: (bool)
"""
self.n_calls += 1
self.model = locals_['self']
self.num_timesteps = self.model.num_timesteps
if self.n_calls % self.eval_freq == 0:
episode_rewards, _ = evaluate_policy(self.model, self.eval_env, n_eval_episodes=self.n_eval_episodes,
deterministic=self.deterministic, return_episode_rewards=True)
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
if self.verbose > 0:
print("Eval num_timesteps={}, "
"episode_reward={:.2f} +/- {:.2f}".format(self.num_timesteps, mean_reward, std_reward))
if mean_reward > self.best_mean_reward:
if self.best_model_save_path is not None:
print("Saving best model")
self.model.save(self.best_model_save_path)
self.best_mean_reward = mean_reward
return True