bullet3/examples/pybullet/gym/pybullet_envs/stable_baselines/train.py
2020-02-15 21:06:10 +01:00

62 lines
2.4 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Code adapted from https://github.com/araffin/rl-baselines-zoo
# it requires stable-baselines to be installed
# Colab Notebook: https://colab.research.google.com/drive/1nZkHO4QTYfAksm9ZTaZ5vXyC7szZxC3F
# You can run it using: python -m pybullet_envs.stable_baselines.train --algo td3 --env HalfCheetahBulletEnv-v0
# Author: Antonin RAFFIN
# MIT License
import argparse
from contextlib import closing

import pybullet_envs
import gym
import numpy as np
from stable_baselines import SAC, TD3
from stable_baselines.common.noise import NormalActionNoise
from pybullet_envs.stable_baselines.utils import TimeFeatureWrapper, EvalCallback
def build_arg_parser():
    """Create the command-line parser for this training script.

    Options:
        --algo: 'sac' (default) or 'td3'.
        --env: Gym environment ID (default 'HalfCheetahBulletEnv-v0').
        -n / --n-timesteps: number of training timesteps (default 1e6).

    :return: (argparse.ArgumentParser) the configured parser
    """
    # The text is passed as ``description``: given positionally it would be
    # taken as ``prog`` and replace the program name in the usage line.
    parser = argparse.ArgumentParser(description="Train an RL agent using Stable Baselines")
    parser.add_argument('--algo', help='RL Algorithm (Soft Actor-Critic by default)', default='sac',
                        type=str, required=False, choices=['sac', 'td3'])
    parser.add_argument('--env', type=str, default='HalfCheetahBulletEnv-v0', help='environment ID')
    parser.add_argument('-n', '--n-timesteps', help='Number of training timesteps', default=int(1e6),
                        type=int)
    return parser


def get_hyperparams(algo_name, n_actions):
    """Return the keyword arguments used to build the model for one algorithm.

    Tuned hyperparameters from https://github.com/araffin/rl-baselines-zoo

    Only the requested algorithm's settings are built, so the TD3 action
    noise object is not created when training with SAC.

    :param algo_name: (str) 'sac' or 'td3'
    :param n_actions: (int) size of the action vector; only used to shape the
        TD3 action noise
    :return: (dict) keyword arguments for the algorithm constructor
    :raises KeyError: if ``algo_name`` is not one of the supported algorithms
    """
    if algo_name == 'sac':
        return dict(batch_size=256, gamma=0.98, policy_kwargs=dict(layers=[256, 256]),
                    learning_starts=10000, buffer_size=int(2e5), tau=0.01)
    if algo_name == 'td3':
        return dict(batch_size=100, policy_kwargs=dict(layers=[400, 300]),
                    learning_rate=1e-3, learning_starts=10000, buffer_size=int(1e6),
                    train_freq=1000, gradient_steps=1000,
                    action_noise=NormalActionNoise(mean=np.zeros(n_actions),
                                                   sigma=0.1 * np.ones(n_actions)))
    raise KeyError(algo_name)


def main(argv=None):
    """Train an agent on the requested environment and save it.

    The model is saved to '<algo>_<env id>'; the evaluation callback is given
    '<algo>_<env id>_best' as its best-model path.

    :param argv: (list of str or None) command-line arguments; None makes
        argparse read ``sys.argv``
    """
    args = build_arg_parser().parse_args(argv)
    env_id = args.env
    n_timesteps = args.n_timesteps
    save_path = '{}_{}'.format(args.algo, env_id)

    algo = {
        'sac': SAC,
        'td3': TD3
    }[args.algo]

    # ``closing`` calls ``close()`` on both environments when the block exits,
    # including when training raises or is interrupted.
    # NOTE(review): assumes the wrapped environments expose the standard Gym
    # ``close()`` method - confirm for custom wrappers.
    with closing(TimeFeatureWrapper(gym.make(env_id))) as env, \
            closing(TimeFeatureWrapper(gym.make(env_id))) as eval_env:
        # Separate environment instance handed to the evaluation callback
        callback = EvalCallback(eval_env, best_model_save_path=save_path + '_best')

        n_actions = env.action_space.shape[0]
        hyperparams = get_hyperparams(args.algo, n_actions)

        model = algo('MlpPolicy', env, verbose=1, **hyperparams)
        model.learn(n_timesteps, callback=callback)

        print("Saving to {}.zip".format(save_path))
        model.save(save_path)


if __name__ == '__main__':
    main()