Add enjoy script for Stable Baselines

Antonin RAFFIN 2020-02-15 21:06:10 +01:00
parent cbede4eb6c
commit 21efd84c18
2 changed files with 89 additions and 2 deletions

pybullet_envs/stable_baselines/enjoy.py (new file)

@@ -0,0 +1,86 @@
# Code adapted from https://github.com/araffin/rl-baselines-zoo
# It requires stable-baselines to be installed
# Colab Notebook: https://colab.research.google.com/drive/1nZkHO4QTYfAksm9ZTaZ5vXyC7szZxC3F
# You can run it using: python -m pybullet_envs.stable_baselines.enjoy --algo td3 --env HalfCheetahBulletEnv-v0
# Author: Antonin RAFFIN
# MIT License
import argparse
import multiprocessing

import gym
import numpy as np
import pybullet_envs

from stable_baselines import SAC, TD3
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common.evaluation import evaluate_policy

from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper

if __name__ == '__main__':
  parser = argparse.ArgumentParser(description="Enjoy an RL agent trained using Stable Baselines")
  parser.add_argument('--algo', help='RL Algorithm (Soft Actor-Critic by default)', default='sac',
                      type=str, required=False, choices=['sac', 'td3'])
  parser.add_argument('--env', type=str, default='HalfCheetahBulletEnv-v0', help='environment ID')
  parser.add_argument('-n', '--n-episodes', help='Number of episodes', default=5, type=int)
  parser.add_argument('--no-render', action='store_true', default=False,
                      help='Do not render the environment')
  args = parser.parse_args()

  env_id = args.env

  # Create an env similar to the training env
  env = TimeFeatureWrapper(gym.make(env_id))

  # Use SubprocVecEnv for rendering
  if not args.no_render:
    # Note: fork is not thread-safe but is usually faster
    fork_available = 'fork' in multiprocessing.get_all_start_methods()
    start_method = 'fork' if fork_available else 'spawn'
    env = SubprocVecEnv([lambda: env], start_method=start_method)

  algo = {
      'sac': SAC,
      'td3': TD3,
  }[args.algo]

  # We assume that the saved model is in the same folder
  save_path = '{}_{}.zip'.format(args.algo, env_id)

  # Load the saved model
  model = algo.load(save_path, env=env)

  try:
    # Use deterministic actions for evaluation
    episode_rewards, episode_lengths = [], []
    for _ in range(args.n_episodes):
      obs = env.reset()
      done = False
      episode_reward = 0.0
      episode_length = 0
      while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _info = env.step(action)
        episode_reward += reward
        episode_length += 1
        if not args.no_render:
          env.render(mode='human')
      episode_rewards.append(episode_reward)
      episode_lengths.append(episode_length)
      print("Episode {} reward={}, length={}".format(len(episode_rewards), episode_reward, episode_length))

    mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
    mean_len, std_len = np.mean(episode_lengths), np.std(episode_lengths)
    print("==== Results ====")
    print("Episode_reward={:.2f} +/- {:.2f}".format(mean_reward, std_reward))
    print("Episode_length={:.2f} +/- {:.2f}".format(mean_len, std_len))
  except KeyboardInterrupt:
    pass

  # Close process
  env.close()
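
Note: both scripts import TimeFeatureWrapper from pybullet_envs.stable_baselines.utils, which is not part of this commit. As a rough sketch of what the wrapper does, assuming it matches the rl-baselines-zoo implementation the scripts were adapted from, it appends the normalized remaining episode time to each observation so the time-limited task stays (approximately) Markovian:

# Sketch only: the real implementation lives in pybullet_envs/stable_baselines/utils.py.
class TimeFeatureWrapper(gym.Wrapper):

  def __init__(self, env, max_steps=1000):
    # Extend the observation space with one extra dimension in [0, 1]
    low = np.concatenate((env.observation_space.low, [0.0]))
    high = np.concatenate((env.observation_space.high, [1.0]))
    env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)
    super(TimeFeatureWrapper, self).__init__(env)
    self._max_steps = max_steps
    self._current_step = 0

  def reset(self):
    self._current_step = 0
    return self._get_obs(self.env.reset())

  def step(self, action):
    self._current_step += 1
    obs, reward, done, info = self.env.step(action)
    return self._get_obs(obs), reward, done, info

  def _get_obs(self, obs):
    # Remaining time, scaled from 1 (episode start) down to 0 (time limit)
    time_feature = 1 - (self._current_step / self._max_steps)
    return np.concatenate((obs, [time_feature]))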

pybullet_envs/stable_baselines/train.py

@@ -1,6 +1,7 @@
 # Code adapted from https://github.com/araffin/rl-baselines-zoo
 # It requires stable-baselines to be installed
+# Colab Notebook: https://colab.research.google.com/drive/1nZkHO4QTYfAksm9ZTaZ5vXyC7szZxC3F
 # You can run it using: python -m pybullet_envs.stable_baselines.train --algo td3 --env HalfCheetahBulletEnv-v0
 # Author: Antonin RAFFIN
 # MIT License
 import argparse
@@ -12,11 +13,11 @@ import numpy as np
 from stable_baselines import SAC, TD3
 from stable_baselines.common.noise import NormalActionNoise
 
-from utils import TimeFeatureWrapper, EvalCallback
+from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper, EvalCallback
 
 
 if __name__ == '__main__':
-  parser = argparse.ArgumentParser()
+  parser = argparse.ArgumentParser(description="Train an RL agent using Stable Baselines")
   parser.add_argument('--algo', help='RL Algorithm (Soft Actor-Critic by default)', default='sac',
                       type=str, required=False, choices=['sac', 'td3'])
   parser.add_argument('--env', type=str, default='HalfCheetahBulletEnv-v0', help='environment ID')
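
The load path in enjoy.py implies a matching save convention in train.py; the save call itself is outside this hunk, so the following is a hedged sketch rather than the committed code. It assumes Stable Baselines' default zip save format and an illustrative timestep budget:

# Hypothetical end of train.py (not shown in this diff): the save path must
# match the '{algo}_{env_id}.zip' file that enjoy.py loads from the current folder.
model.learn(total_timesteps=int(1e6))  # budget is illustrative
model.save('{}_{}'.format(args.algo, args.env))  # stable-baselines appends the .zip extension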