Mirror of https://github.com/bulletphysics/bullet3, synced 2024-12-13 21:30:09 +00:00
Merge pull request #2688 from araffin/upgrade-sb
Add CheckpointCallback and load best automatically
Commit 4b41685d0e
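In short, the change moves the PyBullet stable-baselines example scripts onto the callback API that ships with stable-baselines (EvalCallback and CheckpointCallback from stable_baselines.common.callbacks) and teaches the enjoy script to fall back to the best saved model. As a rough sketch of the resulting training pattern (the paths, frequencies and the choice of SAC below are illustrative, not values taken from this diff):

import gym
import pybullet_envs  # noqa: F401 -- imported for its side effect of registering the Bullet envs
from stable_baselines import SAC
from stable_baselines.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines.common.vec_env import DummyVecEnv

env = gym.make('HalfCheetahBulletEnv-v0')
# EvalCallback evaluates on a separate (vectorized) environment
eval_env = DummyVecEnv([lambda: gym.make('HalfCheetahBulletEnv-v0')])

save_path = 'sac_HalfCheetahBulletEnv-v0'
callbacks = [
    EvalCallback(eval_env, best_model_save_path=save_path),   # writes best_model.zip
    CheckpointCallback(save_freq=100000, save_path=save_path,
                       name_prefix='rl_model'),               # periodic snapshots
]

model = SAC('MlpPolicy', env, verbose=1)
model.learn(int(1e6), callback=callbacks)  # learn() accepts a list of callbacks
model.save(save_path)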

examples/pybullet/gym/pybullet_envs/stable_baselines/enjoy.py
@@ -4,17 +4,16 @@
 # You can run it using: python -m pybullet_envs.stable_baselines.enjoy --algo td3 --env HalfCheetahBulletEnv-v0
 # Author: Antonin RAFFIN
 # MIT License
+import os
 import time
 import argparse
 import multiprocessing
-import time

 import gym
 import numpy as np
 import pybullet_envs

 from stable_baselines import SAC, TD3
 from stable_baselines.common.evaluation import evaluate_policy

 from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper

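For context, pybullet_envs is imported only for its side effect of registering the Bullet environments with gym, and TimeFeatureWrapper (defined in the utils module shown further down) appends one extra dimension to the observation. A quick, illustrative check:

import gym
import pybullet_envs  # noqa: F401 -- registers HalfCheetahBulletEnv-v0 and friends

from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper

env = TimeFeatureWrapper(gym.make('HalfCheetahBulletEnv-v0'))
print(env.observation_space.shape)  # one dimension larger than the unwrapped env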
@@ -26,7 +25,9 @@ if __name__ == '__main__':
     parser.add_argument('-n', '--n-episodes', help='Number of episodes', default=5,
                         type=int)
     parser.add_argument('--no-render', action='store_true', default=False,
                         help='Do not render the environment')
+    parser.add_argument('--load-best', action='store_true', default=False,
+                        help='Load best model instead of last model if available')
     args = parser.parse_args()

     env_id = args.env
@@ -44,6 +45,13 @@ if __name__ == '__main__':

     # We assume that the saved model is in the same folder
     save_path = '{}_{}.zip'.format(args.algo, env_id)

+    if not os.path.isfile(save_path) or args.load_best:
+        print("Loading best model")
+        # Try to load best model
+        save_path = os.path.join('{}_{}'.format(args.algo, env_id), 'best_model.zip')
+
     # Load the saved model
     model = algo.load(save_path, env=env)

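With the new --load-best flag, the enjoy script can be pointed at the best_model.zip written by EvalCallback instead of the final snapshot, e.g.:

python -m pybullet_envs.stable_baselines.enjoy --algo td3 --env HalfCheetahBulletEnv-v0 --load-best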
@@ -63,7 +71,7 @@ if __name__ == '__main__':
             episode_length += 1
             if not args.no_render:
                 env.render(mode='human')
-                dt = 1./240.
+                dt = 1. / 240.
                 time.sleep(dt)
         episode_rewards.append(episode_reward)
         episode_lengths.append(episode_length)

examples/pybullet/gym/pybullet_envs/stable_baselines/train.py
@@ -12,8 +12,10 @@ import gym
 import numpy as np
 from stable_baselines import SAC, TD3
 from stable_baselines.common.noise import NormalActionNoise
+from stable_baselines.common.callbacks import EvalCallback, CheckpointCallback
+from stable_baselines.common.vec_env import DummyVecEnv

-from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper, EvalCallback
+from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper


 if __name__ == '__main__':
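NormalActionNoise stays imported because TD3 is typically trained with Gaussian exploration noise on the actions; a hedged sketch of how such noise is usually constructed (the 0.1 scale is an assumption, not a value from this diff):

import gym
import numpy as np
import pybullet_envs  # noqa: F401
from stable_baselines.common.noise import NormalActionNoise

env = gym.make('HalfCheetahBulletEnv-v0')
n_actions = env.action_space.shape[0]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
# action_noise would then be passed to TD3(..., action_noise=action_noise)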
@@ -23,6 +25,8 @@ if __name__ == '__main__':
     parser.add_argument('--env', type=str, default='HalfCheetahBulletEnv-v0', help='environment ID')
     parser.add_argument('-n', '--n-timesteps', help='Number of training timesteps', default=int(1e6),
                         type=int)
+    parser.add_argument('--save-freq', help='Save the model every n steps (if negative, no checkpoint)',
+                        default=-1, type=int)
     args = parser.parse_args()

     env_id = args.env
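Assuming the training script is launched the same way as the enjoy example, the new flag enables periodic checkpoints, e.g.:

python -m pybullet_envs.stable_baselines.train --algo td3 --env HalfCheetahBulletEnv-v0 --save-freq 100000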
@@ -32,9 +36,15 @@ if __name__ == '__main__':
     # Instantiate and wrap the environment
     env = TimeFeatureWrapper(gym.make(env_id))

-    # Create the evaluation environment and callback
-    eval_env = TimeFeatureWrapper(gym.make(env_id))
-    callback = EvalCallback(eval_env, best_model_save_path=save_path + '_best')
+    # Create the evaluation environment and callbacks
+    eval_env = DummyVecEnv([lambda: TimeFeatureWrapper(gym.make(env_id))])
+
+    callbacks = [EvalCallback(eval_env, best_model_save_path=save_path)]
+
+    # Save a checkpoint every n steps
+    if args.save_freq > 0:
+        callbacks.append(CheckpointCallback(save_freq=args.save_freq, save_path=save_path,
+                                            name_prefix='rl_model'))

     algo = {
         'sac': SAC,
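Both callbacks come from stable_baselines.common.callbacks (available since stable-baselines 2.10), and the script relies on their defaults. Roughly, the tunable parameters look like this, using the names from the hunk above (the values shown are the library defaults as I understand them, not settings from the diff):

from stable_baselines.common.callbacks import EvalCallback, CheckpointCallback

eval_callback = EvalCallback(eval_env,
                             n_eval_episodes=5,                # episodes per evaluation
                             eval_freq=10000,                  # evaluate every 10000 calls/steps
                             deterministic=True,
                             best_model_save_path=save_path)   # best_model.zip lands here
checkpoint_callback = CheckpointCallback(save_freq=args.save_freq,
                                         save_path=save_path,
                                         name_prefix='rl_model')  # rl_model_<steps>_steps.zip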
@@ -55,7 +65,10 @@ if __name__ == '__main__':
     }[args.algo]

     model = algo('MlpPolicy', env, verbose=1, **hyperparams)
-    model.learn(n_timesteps, callback=callback)
+    try:
+        model.learn(n_timesteps, callback=callbacks)
+    except KeyboardInterrupt:
+        pass

     print("Saving to {}.zip".format(save_path))
     model.save(save_path)
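Two small behaviours worth spelling out: learn() accepts either a single callback or a list (as far as I know, stable-baselines wraps a list in a CallbackList internally), and catching KeyboardInterrupt means an interrupted run still reaches model.save(). A compressed sketch of the same flow:

try:
    model.learn(n_timesteps, callback=callbacks)   # callbacks is a plain Python list
except KeyboardInterrupt:
    pass                                           # stop training, but keep what was learned
model.save(save_path)                              # final snapshot, independent of best_model.zip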

examples/pybullet/gym/pybullet_envs/stable_baselines/utils.py
@@ -6,8 +6,6 @@ import gym
 import numpy as np
 from gym.wrappers import TimeLimit

-from stable_baselines.common.evaluation import evaluate_policy
-

 class TimeFeatureWrapper(gym.Wrapper):
     """
@@ -25,7 +23,7 @@ class TimeFeatureWrapper(gym.Wrapper):
         assert isinstance(env.observation_space, gym.spaces.Box)
         # Add a time feature to the observation
         low, high = env.observation_space.low, env.observation_space.high
-        low, high= np.concatenate((low, [0])), np.concatenate((high, [1.]))
+        low, high = np.concatenate((low, [0])), np.concatenate((high, [1.]))
         env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

         super(TimeFeatureWrapper, self).__init__(env)
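The wrapper's observation space gains one entry bounded in [0, 1]; the value stored there is the fraction of episode time remaining. A simplified sketch of the idea (the real class tracks the step counter through reset() and step(), which is omitted here):

import numpy as np

def add_time_feature(obs, current_step, max_steps):
    # 1.0 right after reset, decreasing towards 0.0 at the TimeLimit horizon
    time_feature = 1.0 - (current_step / max_steps)
    return np.concatenate((obs, [time_feature]))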
@@ -59,56 +57,3 @@ class TimeFeatureWrapper(gym.Wrapper):
             time_feature = 1.0
         # Optionnaly: concatenate [time_feature, time_feature ** 2]
         return np.concatenate((obs, [time_feature]))
-
-
-class EvalCallback(object):
-    """
-    Callback for evaluating an agent.
-
-    :param eval_env: (gym.Env) The environment used for initialization
-    :param n_eval_episodes: (int) The number of episodes to test the agent
-    :param eval_freq: (int) Evaluate the agent every eval_freq call of the callback.
-    :param deterministic: (bool)
-    :param best_model_save_path: (str)
-    :param verbose: (int)
-    """
-    def __init__(self, eval_env, n_eval_episodes=5, eval_freq=10000,
-                 deterministic=True, best_model_save_path=None, verbose=1):
-        super(EvalCallback, self).__init__()
-        self.n_eval_episodes = n_eval_episodes
-        self.eval_freq = eval_freq
-        self.best_mean_reward = -np.inf
-        self.deterministic = deterministic
-        self.eval_env = eval_env
-        self.verbose = verbose
-        self.model, self.num_timesteps = None, 0
-        self.best_model_save_path = best_model_save_path
-        self.n_calls = 0
-
-    def __call__(self, locals_, globals_):
-        """
-        :param locals_: (dict)
-        :param globals_: (dict)
-        :return: (bool)
-        """
-        self.n_calls += 1
-        self.model = locals_['self']
-        self.num_timesteps = self.model.num_timesteps
-
-        if self.n_calls % self.eval_freq == 0:
-            episode_rewards, _ = evaluate_policy(self.model, self.eval_env, n_eval_episodes=self.n_eval_episodes,
-                                                 deterministic=self.deterministic, return_episode_rewards=True)
-
-            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
-            if self.verbose > 0:
-                print("Eval num_timesteps={}, "
-                      "episode_reward={:.2f} +/- {:.2f}".format(self.num_timesteps, mean_reward, std_reward))
-
-            if mean_reward > self.best_mean_reward:
-                if self.best_model_save_path is not None:
-                    print("Saving best model")
-                    self.model.save(self.best_model_save_path)
-                self.best_mean_reward = mean_reward
-
-        return True
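The class deleted here was an old-style callback: a plain callable taking (locals_, globals_) and returning a bool (False stops training). Its role is now covered by the built-in EvalCallback, and custom logic is expected to subclass BaseCallback instead. A minimal, hedged sketch of what a new-style equivalent could look like (simplified; no best-model saving):

import numpy as np
from stable_baselines.common.callbacks import BaseCallback
from stable_baselines.common.evaluation import evaluate_policy

class SimpleEvalCallback(BaseCallback):
    def __init__(self, eval_env, eval_freq=10000, n_eval_episodes=5, verbose=1):
        super(SimpleEvalCallback, self).__init__(verbose)
        self.eval_env = eval_env
        self.eval_freq = eval_freq
        self.n_eval_episodes = n_eval_episodes

    def _on_step(self):
        # self.n_calls and self.model are maintained by BaseCallback
        if self.n_calls % self.eval_freq == 0:
            episode_rewards, _ = evaluate_policy(self.model, self.eval_env,
                                                 n_eval_episodes=self.n_eval_episodes,
                                                 return_episode_rewards=True)
            if self.verbose > 0:
                print("mean reward: {:.2f}".format(np.mean(episode_rewards)))
        return True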