diff --git a/examples/pybullet/gym/agents/simpleAgent.py b/examples/pybullet/gym/agents/simpleAgent.py index 78a5dfeb1..08a4cf1fa 100644 --- a/examples/pybullet/gym/agents/simpleAgent.py +++ b/examples/pybullet/gym/agents/simpleAgent.py @@ -30,19 +30,17 @@ class SimpleAgent(): def _build(self): self._agent_net = actor_net.ActorNetwork(self._actor_layer_size, self._action_size) - self._o_t = tf.placeholder(tf.float32, (31,)) + self._obs = tf.placeholder(tf.float32, (31,)) with tf.name_scope('Act'): - batch_o_t = snt.nest.pack_iterable_as( - self._o_t, - snt.nest.map( - lambda x: tf.expand_dims(x, 0), - snt.nest.flatten_iterable(self._o_t))) - self._action = self._agent_net(batch_o_t) + batch_obs = snt.nest.pack_iterable_as(self._obs, + snt.nest.map(lambda x: tf.expand_dims(x, 0), + snt.nest.flatten_iterable(self._obs))) + self._action = self._agent_net(batch_obs) saver = tf.train.Saver() saver.restore( sess=self._session, save_path=self._ckpt_path) def __call__(self, observation): - out_action = self._session.run(self._action, feed_dict={self._o_t: observation}) + out_action = self._session.run(self._action, feed_dict={self._obs: observation}) return out_action[0]