mirror of
https://github.com/bulletphysics/bullet3
synced 2024-12-14 13:50:04 +00:00
8e8955571f
try: import tensorflow.compat.v1 as tf except Exception: import tensorflow as tf
186 lines
6.2 KiB
Python
186 lines
6.2 KiB
Python
# Copyright 2017 The TensorFlow Agents Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utilities for using reinforcement learning algorithms."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
|
|
import ruamel.yaml as yaml
|
|
try:
|
|
import tensorflow.compat.v1 as tf
|
|
except Exception:
|
|
import tensorflow as tf
|
|
|
|
from . import tools
|
|
|
|
|
|
def define_simulation_graph(batch_env, algo_cls, config):
  """Define the algorithm and environment interaction.

  Creates a non-trainable int32 `global_step` variable and four boolean
  placeholders (`is_training`, `should_log`, `do_report`, `force_reset`),
  instantiates the algorithm from them, and connects it to the batch
  environment through `tools.simulate`.

  Args:
    batch_env: In-graph environments object.
    algo_cls: Constructor of a batch algorithm; called as
      `algo_cls(batch_env, step, is_training, should_log, config)`.
    config: Configuration object for the algorithm.

  Returns:
    `tools.AttrDict` with the keys batch_env, algo_cls, config, step,
    is_training, should_log, do_report, force_reset, algo, done, score,
    summary (the last three are the outputs of `tools.simulate`) and message.
  """
  global_step = tf.Variable(0, trainable=False, dtype=tf.int32, name='global_step')
  is_training = tf.placeholder(tf.bool, name='is_training')
  should_log = tf.placeholder(tf.bool, name='should_log')
  do_report = tf.placeholder(tf.bool, name='do_report')
  force_reset = tf.placeholder(tf.bool, name='force_reset')

  algo = algo_cls(batch_env, global_step, is_training, should_log, config)
  done, score, summary = tools.simulate(batch_env, algo, should_log, force_reset)

  message = 'Graph contains {} trainable variables.'
  tf.logging.info(message.format(tools.count_weights()))

  # Same key set the original produced via `locals()`, spelled out explicitly.
  graph = {
      'batch_env': batch_env,
      'algo_cls': algo_cls,
      'config': config,
      'step': global_step,
      'is_training': is_training,
      'should_log': should_log,
      'do_report': do_report,
      'force_reset': force_reset,
      'algo': algo,
      'done': done,
      'score': score,
      'summary': summary,
      'message': message,
  }
  return tools.AttrDict(graph)
|
|
|
|
|
|
def define_batch_env(constructor, num_agents, env_processes):
  """Instantiate `num_agents` environments and wrap them as an in-graph batch.

  Args:
    constructor: Callable returning a new OpenAI Gym environment.
    num_agents: Number of environments to put in the batch.
    env_processes: If true, every environment is wrapped in
      `tools.wrappers.ExternalProcess` and the batch is built with
      `blocking=False`. Otherwise the environments are constructed directly in
      this process and the batch is built with `blocking=True`.

  Returns:
    A `tools.InGraphBatchEnv` wrapping a `tools.BatchEnv` of the environments.
  """
  if env_processes:
    def make_env():
      return tools.wrappers.ExternalProcess(constructor)
  else:
    make_env = constructor

  with tf.variable_scope('environments'):
    envs = [make_env() for _ in range(num_agents)]
    python_batch = tools.BatchEnv(envs, blocking=not env_processes)
    batch_env = tools.InGraphBatchEnv(python_batch)
  return batch_env
|
|
|
|
|
|
def define_saver(exclude=None):
  """Build a checkpoint saver over every global variable that is not excluded.

  Args:
    exclude: Optional list of regular expressions. A global variable is left
      out of the saver when any pattern matches the beginning of its name
      (`re.match` semantics).

  Returns:
    A `tf.train.Saver` over the remaining variables, created with
    `keep_checkpoint_every_n_hours=5`.
  """
  patterns = [re.compile(pattern) for pattern in (exclude or [])]

  def _is_excluded(var):
    # True when any exclusion pattern matches the start of the variable name.
    return any(pattern.match(var.name) for pattern in patterns)

  kept = [var for var in tf.global_variables() if not _is_excluded(var)]
  return tf.train.Saver(kept, keep_checkpoint_every_n_hours=5)
|
|
|
|
|
|
def initialize_variables(sess, saver, logdir, checkpoint=None, resume=None):
  """Initialize or restore variables from a checkpoint if available.

  All local and global variables are first initialized. A checkpoint to
  restore is then chosen as follows:

  - With `logdir` and `checkpoint`: `os.path.join(logdir, checkpoint)`.
  - With `logdir` only: the latest checkpoint recorded in `logdir`, if any.
  - With `checkpoint` only: `checkpoint` is used as given, as a full path.
    Previously it was silently ignored even though the validation below
    accepts it.

  Args:
    sess: Session to initialize variables in.
    saver: Saver to restore variables.
    logdir: Directory to search for checkpoints.
    checkpoint: Specify what checkpoint name to use; defaults to most recent.
    resume: Whether to expect recovering a checkpoint or starting a new run.
      `None` means either is acceptable.

  Raises:
    ValueError: If resume expected but neither a log directory nor a
      checkpoint was specified.
    RuntimeError: If `resume is False` but a checkpoint was found.
  """
  sess.run(tf.group(tf.local_variables_initializer(), tf.global_variables_initializer()))
  if resume and not (logdir or checkpoint):
    raise ValueError('Need to specify logdir to resume a checkpoint.')
  if logdir:
    state = tf.train.get_checkpoint_state(logdir)
    if checkpoint:
      checkpoint = os.path.join(logdir, checkpoint)
    if not checkpoint and state and state.model_checkpoint_path:
      checkpoint = state.model_checkpoint_path
  # These checks run whether the checkpoint came from `logdir` or was given
  # directly, so an explicit checkpoint is never silently dropped.
  if checkpoint and resume is False:
    message = 'Found unexpected checkpoint when starting a new run.'
    raise RuntimeError(message)
  if checkpoint:
    saver.restore(sess, checkpoint)
|
|
|
|
|
|
def save_config(config, logdir=None):
  """Record the configuration of a new run.

  When `logdir` is given, it is stored as `config.logdir`, the directory is
  created, and the configuration is written as YAML to
  `<logdir>/config.yaml`. Without a `logdir`, only an info message is logged.

  Args:
    config: Configuration object. When `logdir` is given it must provide an
      `unlocked` context manager that permits setting `config.logdir`.
    logdir: Location for writing summaries and checkpoints if specified.

  Returns:
    The same configuration object.
  """
  if not logdir:
    tf.logging.info('Start a new run without storing summaries and checkpoints since no '
                    'logging directory was specified.')
    return config

  with config.unlocked:
    config.logdir = logdir
  tf.logging.info('Start a new run and write summaries and checkpoints to {}.'.format(config.logdir))
  tf.gfile.MakeDirs(config.logdir)
  target = os.path.join(config.logdir, 'config.yaml')
  with tf.gfile.GFile(target, 'w') as stream:
    yaml.dump(config, stream, default_flow_style=False)
  return config
|
|
|
|
|
|
def load_config(logdir):
  """Load a configuration from the log directory.

  Reads `<logdir>/config.yaml` (the file `save_config` writes) and logs that
  the run is being resumed.

  Args:
    logdir: The logging directory containing the configuration file.

  Raises:
    IOError: `logdir` is empty/None or does not contain a configuration file.

  Returns:
    Configuration object.
  """
  config_path = logdir and os.path.join(logdir, 'config.yaml')
  if not config_path or not tf.gfile.Exists(config_path):
    message = ('Cannot resume an existing run since the logging directory does not '
               'contain a configuration file.')
    raise IOError(message)
  # GFile matches the writer in save_config; FastGFile is a deprecated alias.
  with tf.gfile.GFile(config_path, 'r') as file_:
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. This is
    # presumably needed because save_config dumps the config object itself, but
    # it means config files must come only from trusted log directories.
    # Newer ruamel.yaml releases deprecate the module-level load()/dump()
    # functions; verify against the pinned version.
    config = yaml.load(file_, Loader=yaml.Loader)
  message = 'Resume run and write summaries and checkpoints to {}.'
  tf.logging.info(message.format(config.logdir))
  return config
|
|
|
|
|
|
def set_up_logging():
  """Set TensorFlow logging to INFO and detach its logger from ancestors.

  The stdlib 'tensorflow' logger gets `propagate = False`, so its records are
  not passed on to handlers of ancestor loggers such as the root logger.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf_logger = logging.getLogger('tensorflow')
  tf_logger.propagate = False
|