remove sonnet dependency

Jie Tan 2017-05-18 16:12:38 -07:00
parent d40787f3fd
commit 8a6a46d180
11 changed files with 41 additions and 75 deletions

View File: agents/actor_net.py

@@ -1,21 +0,0 @@
"""An actor network."""
import tensorflow as tf
import sonnet as snt


class ActorNetwork(snt.AbstractModule):
  """An actor network as a sonnet Module."""

  def __init__(self, layer_sizes, action_size, name='target_actor'):
    super(ActorNetwork, self).__init__(name=name)
    self._layer_sizes = layer_sizes
    self._action_size = action_size

  def _build(self, inputs):
    state = inputs
    for output_size in self._layer_sizes:
      state = snt.Linear(output_size)(state)
      state = tf.nn.relu(state)
    action = tf.tanh(
        snt.Linear(self._action_size, name='action')(state))
    return action
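For reference, nothing in the deleted module actually requires sonnet; a minimal sketch of the same actor in plain TF 1.x (an illustration, not code from this commit; the tf.layers-based helper below is hypothetical):

import tensorflow as tf

def actor_network(inputs, layer_sizes, action_size, name='target_actor'):
  # Same MLP actor as ActorNetwork above, with tf.layers.dense standing
  # in for snt.Linear and no snt.AbstractModule wrapper.
  with tf.variable_scope(name):
    state = inputs
    for i, output_size in enumerate(layer_sizes):
      state = tf.layers.dense(state, output_size, activation=tf.nn.relu,
                              name='linear_%d' % i)
    # tanh bounds each action component to [-1, 1], as in the original.
    return tf.tanh(tf.layers.dense(state, action_size, name='action'))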

View File: agents/simpleAgent.py

@@ -1,46 +0,0 @@
"""Loads a DDPG agent without too many external dependencies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import collections

import numpy as np
import tensorflow as tf
import sonnet as snt

from agents import actor_net


class SimpleAgent():

  def __init__(
      self,
      session,
      ckpt_path,
      actor_layer_size,
      observation_size=(31,),
      action_size=8,
  ):
    self._ckpt_path = ckpt_path
    self._actor_layer_size = actor_layer_size
    self._observation_size = observation_size
    self._action_size = action_size
    self._session = session
    self._build()

  def _build(self):
    self._agent_net = actor_net.ActorNetwork(self._actor_layer_size,
                                             self._action_size)
    self._obs = tf.placeholder(tf.float32, (31,))
    with tf.name_scope('Act'):
      # The snt.nest round-trip adds a batch dimension to the single
      # observation tensor before running it through the actor.
      batch_obs = snt.nest.pack_iterable_as(
          self._obs,
          snt.nest.map(lambda x: tf.expand_dims(x, 0),
                       snt.nest.flatten_iterable(self._obs)))
      self._action = self._agent_net(batch_obs)
    saver = tf.train.Saver()
    saver.restore(sess=self._session, save_path=self._ckpt_path)

  def __call__(self, observation):
    out_action = self._session.run(self._action,
                                   feed_dict={self._obs: observation})
    return out_action[0]
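Besides the actor network itself, the snt.nest calls in _build were the last sonnet usage, and for a single flat observation tensor they reduce to adding a batch dimension. A sonnet-free equivalent (a sketch, not part of this commit):

batch_obs = tf.expand_dims(self._obs, 0)  # shape (31,) -> (1, 31)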

View File: agents/simplerAgent.py

@@ -0,0 +1,36 @@
"""Loads a DDPG agent without too many external dependencies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import collections

import numpy as np
import tensorflow as tf


class SimplerAgent():

  def __init__(
      self,
      session,
      ckpt_path,
      observation_dim=31,
  ):
    self._ckpt_path = ckpt_path
    self._session = session
    self._observation_dim = observation_dim
    self._build()

  def _build(self):
    # Rebuild the graph from the checkpoint's MetaGraph instead of
    # constructing it with sonnet modules, then look up the tensors
    # by collection name.
    saver = tf.train.import_meta_graph(self._ckpt_path + '.meta')
    saver.restore(sess=self._session, save_path=self._ckpt_path)
    self._action = tf.get_collection('action_op')[0]
    self._obs = tf.get_collection('observation_placeholder')[0]

  def __call__(self, observation):
    feed_dict = {self._obs: observation}
    out_action = self._session.run(self._action, feed_dict=feed_dict)
    return out_action[0]
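SimplerAgent only works if the checkpoint's MetaGraph already carries the action op and the observation placeholder in named collections. The converter that produced tf_graph_data_converted.ckpt-0 is not part of this diff; a hypothetical one-off script under that assumption, reusing the old sonnet-based agent one last time:

import tensorflow as tf
from agents import simpleAgent  # the old sonnet-based agent

OLD_CKPT = 'data/agent/tf_graph_data/tf_graph_data.ckpt'
NEW_CKPT = 'data/agent/tf_graph_data/tf_graph_data_converted.ckpt'

with tf.Session() as session:
  # Rebuild the graph with the old code, tag the tensors SimplerAgent
  # looks up by collection name, and re-save graph plus weights.
  agent = simpleAgent.SimpleAgent(session, OLD_CKPT, (100, 181))
  tf.add_to_collection('observation_placeholder', agent._obs)
  tf.add_to_collection('action_op', agent._action)
  saver = tf.train.Saver()
  saver.save(session, NEW_CKPT, global_step=0)  # writes ...ckpt-0 and .meta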

View File: data/agent/tf_graph_data/checkpoint

@@ -1,2 +1,2 @@
-model_checkpoint_path: "/cns/ij-d/home/jietan/persistent/minitaur/minitaur_vizier_3_153645653/Bullet/MinitaurSimEnv/28158/0003600000/agent/tf_graph_data/tf_graph_data.ckpt"
-all_model_checkpoint_paths: "/cns/ij-d/home/jietan/persistent/minitaur/minitaur_vizier_3_153645653/Bullet/MinitaurSimEnv/28158/0003600000/agent/tf_graph_data/tf_graph_data.ckpt"
+model_checkpoint_path: "tf_graph_data_converted.ckpt-0"
+all_model_checkpoint_paths: "tf_graph_data_converted.ckpt-0"

View File

@@ -10,7 +10,7 @@ import numpy as np
 import tensorflow as tf
 from envs.bullet.minitaurGymEnv import MinitaurGymEnv
-from agents import simpleAgent
+from agents import simplerAgent


 def testSinePolicy():
   """Tests sine policy
@@ -53,17 +53,14 @@ def testDDPGPolicy():
   environment = MinitaurGymEnv(render=True)
   sum_reward = 0
   steps = 1000
-  ckpt_path = 'data/agent/tf_graph_data/tf_graph_data.ckpt'
+  ckpt_path = 'data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0'
   observation_shape = (31,)
   action_size = 8
   actor_layer_sizes = (100, 181)
   n_steps = 0
   tf.reset_default_graph()
   with tf.Session() as session:
-    agent = simpleAgent.SimpleAgent(session, ckpt_path,
-                                    actor_layer_sizes,
-                                    observation_size=observation_shape,
-                                    action_size=action_size)
+    agent = simplerAgent.SimplerAgent(session, ckpt_path)
     state = environment.reset()
     action = agent(state)
     for _ in range(steps):