mirror of
https://github.com/bulletphysics/bullet3
synced 2024-12-15 06:00:12 +00:00
a0ded43a69
add a red sphere to make training a bit easier for now.
42 lines
912 B
Python
42 lines
912 B
Python
import gym
|
|
from envs.bullet.racecarZEDGymEnv import RacecarZEDGymEnv
|
|
|
|
from baselines import deepq
|
|
|
|
import datetime
|
|
|
|
|
|
|
|
def callback(lcl, glb):
    """Decide whether training should stop early.

    Training counts as solved once more than 2000 steps have elapsed and
    the mean of the recent episode rewards is at least -50.

    Args:
        lcl: Mapping that provides ``'episode_rewards'`` (a sequence of
            numeric rewards) and ``'t'`` (the step counter). Presumably the
            local variables of the training loop that invokes this
            callback -- confirm against the caller.
        glb: Unused; accepted to match the callback signature.

    Returns:
        True when the stop criterion is met, False otherwise.
    """
    # Up to 100 rewards immediately before the last entry; the last entry
    # itself is excluded (presumably the episode still in progress -- confirm).
    recent = lcl['episode_rewards'][-101:-1]
    steps = lcl['t']

    # With nothing in the window there is no evidence of success.
    window_size = len(recent)
    if window_size == 0:
        return False

    # Average over the rewards actually present. Dividing by a fixed 100
    # would shrink the mean toward zero while fewer than 100 rewards exist
    # and, with negative rewards, could end training prematurely.
    mean_reward = sum(recent) / window_size
    return steps > 2000 and mean_reward >= -50
|
|
|
|
def main():
    """Train a deep Q-learning agent on the ZED racecar environment.

    Builds the environment with rendering off, creates a convolutional
    Q-function, runs ``deepq.learn`` with ``callback`` as the early-stop
    hook, and saves the result to ``racecar_zed_model.pkl``.
    """
    racecar_env = RacecarZEDGymEnv(renders=False)

    # Three convolution specs followed by one hidden layer of 256 units;
    # the tuple layout is presumably (filters, kernel, stride) -- confirm
    # against the cnn_to_mlp documentation.
    q_network = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=False,
    )

    # Training settings gathered in one place and forwarded as keywords.
    learn_settings = dict(
        q_func=q_network,
        lr=1e-3,
        max_timesteps=10000,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        print_freq=10,
        callback=callback,
    )
    trained_act = deepq.learn(racecar_env, **learn_settings)

    print("Saving model to racecar_zed_model.pkl")
    trained_act.save("racecar_zed_model.pkl")
|
|
|
|
|
|
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|