mirror of
https://github.com/bulletphysics/bullet3
synced 2024-12-15 06:00:12 +00:00
21f9d1b816
work-in-progress (need to add missing data files, fix paths, etc.). Example: `pip install pybullet`, `pip install gym`, then in Python: `import gym`, `import pybullet`, `import pybullet_envs`, `env = gym.make("HumanoidBulletEnv-v0")`
29 lines
1.1 KiB
Python
29 lines
1.1 KiB
Python
|
|
import re
|
|
from gym import error
|
|
import glob
|
|
# Pattern for weight-save file paths, e.g.
#   checkpoints/KerasDDPG-v0-InvertedPendulum-v0-20170701190920_actor.h5
# NOTE(review): ``\w`` does not match '-', so both the agent and the env
# name must end in ``-v<digits>``; a name such as
# ``checkpoints/KerasDDPG-InvertedPendulum-v0-20170701190920_actor.h5``
# (no version on the agent) does NOT match.
weight_save_re = re.compile(
    r'^(?:\w+\/)+?'   # one or more leading "<word>/" directory components
    r'(\w+-v\d+)'     # group 1: agent name ending in -v<digits>
    r'-(\w+-v\d+)'    # group 2: env id ending in -v<digits>
    r'-(\d+)'         # group 3: digit run (a YYYYMMDDhhmmss stamp in the example)
    r'(?:_\w+)?'      # optional, uncaptured suffix such as "_actor"
    r'\.(\w+)$'       # group 4: file extension
)
|
|
|
|
def get_fields(weight_save_name):
    """Split a weight-save path into agent name, env id and numeric stamp.

    :param weight_save_name: path of a saved weight file. It must match the
        module-level ``weight_save_re``: at least one leading ``<word>/``
        directory component, then
        ``<agent>-v<N>-<env>-v<M>-<digits>[_<suffix>].<ext>``.
    :return: tuple ``(agent_name, env_name, number)``, where ``number`` is the
        digit run converted to ``int``. The extension is matched but not
        returned.
    :raises error.Error: if ``weight_save_name`` does not match the pattern.
    """
    match = weight_save_re.search(weight_save_name)
    if not match:
        # Name the offending file so the user can locate it.
        raise error.Error(
            'Attempted to read a malformed weight save: {}. (Currently all '
            'weight saves must be of the form {}.)'.format(
                weight_save_name, weight_save_re.pattern))
    return match.group(1), match.group(2), int(match.group(3))
|
|
|
|
def get_latest_save(file_folder, agent_name, env_name, version_number):
|
|
"""
|
|
Returns the properties of the latest weight save. The information can be used to generate the loading path
|
|
:return:
|
|
"""
|
|
path = "%s%s"% (file_folder, "*.h5")
|
|
file_list = glob.glob(path)
|
|
latest_file_properties = []
|
|
file_properties = []
|
|
for f in file_list:
|
|
file_properties = get_fields(f)
|
|
if file_properties[0] == agent_name and file_properties[1] == env_name and (latest_file_properties == [] or file_properties[2] > latest_file_properties[2]):
|
|
latest_file_properties = file_properties
|
|
|
|
return latest_file_properties
|