mirror of
https://github.com/bulletphysics/bullet3
synced 2025-01-18 21:10:05 +00:00
Merge pull request #2567 from erwincoumans/master
one more fix in previous commit related to bullet_client.py
This commit is contained in:
commit
b57557c6cf
@ -25,6 +25,14 @@ register(
|
||||
reward_threshold=190.0,
|
||||
)
|
||||
|
||||
register(
|
||||
id='CartPoleContinuousBulletEnv-v0',
|
||||
entry_point='pybullet_envs.bullet:CartPoleContinuousBulletEnv',
|
||||
max_episode_steps=200,
|
||||
reward_threshold=190.0,
|
||||
)
|
||||
|
||||
|
||||
register(
|
||||
id='MinitaurBulletEnv-v0',
|
||||
entry_point='pybullet_envs.bullet:MinitaurBulletEnv',
|
||||
|
@ -1,4 +1,5 @@
|
||||
from pybullet_envs.bullet.cartpole_bullet import CartPoleBulletEnv
|
||||
from pybullet_envs.bullet.cartpole_bullet import CartPoleContinuousBulletEnv
|
||||
from pybullet_envs.bullet.minitaur_gym_env import MinitaurBulletEnv
|
||||
from pybullet_envs.bullet.minitaur_duck_gym_env import MinitaurBulletDuckEnv
|
||||
from pybullet_envs.bullet.racecarGymEnv import RacecarGymEnv
|
||||
|
@ -15,8 +15,9 @@ from gym.utils import seeding
|
||||
import numpy as np
|
||||
import time
|
||||
import subprocess
|
||||
import pybullet as p
|
||||
import pybullet as p2
|
||||
import pybullet_data
|
||||
import pybullet_utils.bullet_client as bc
|
||||
from pkg_resources import parse_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -25,13 +26,13 @@ logger = logging.getLogger(__name__)
|
||||
class CartPoleBulletEnv(gym.Env):
|
||||
metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 50}
|
||||
|
||||
def __init__(self, renders=True):
|
||||
def __init__(self, renders=False, discrete_actions=True):
|
||||
# start the bullet physics server
|
||||
self._renders = renders
|
||||
if (renders):
|
||||
p.connect(p.GUI)
|
||||
else:
|
||||
p.connect(p.DIRECT)
|
||||
self._discrete_actions = discrete_actions
|
||||
self._render_height = 200
|
||||
self._render_width = 320
|
||||
self._physics_client_id = -1
|
||||
self.theta_threshold_radians = 12 * 2 * math.pi / 360
|
||||
self.x_threshold = 0.4 #2.4
|
||||
high = np.array([
|
||||
@ -42,7 +43,13 @@ class CartPoleBulletEnv(gym.Env):
|
||||
|
||||
self.force_mag = 10
|
||||
|
||||
self.action_space = spaces.Discrete(2)
|
||||
if self._discrete_actions:
|
||||
self.action_space = spaces.Discrete(2)
|
||||
else:
|
||||
action_dim = 1
|
||||
action_high = np.array([self.force_mag] * action_dim)
|
||||
self.action_space = spaces.Box(-action_high, action_high)
|
||||
|
||||
self.observation_space = spaces.Box(-high, high, dtype=np.float32)
|
||||
|
||||
self.seed()
|
||||
@ -58,7 +65,11 @@ class CartPoleBulletEnv(gym.Env):
|
||||
return [seed]
|
||||
|
||||
def step(self, action):
|
||||
force = self.force_mag if action == 1 else -self.force_mag
|
||||
p = self._p
|
||||
if self._discrete_actions:
|
||||
force = self.force_mag if action == 1 else -self.force_mag
|
||||
else:
|
||||
force = action[0]
|
||||
|
||||
p.setJointMotorControl2(self.cartpole, 0, p.TORQUE_CONTROL, force=force)
|
||||
p.stepSimulation()
|
||||
@ -77,19 +88,27 @@ class CartPoleBulletEnv(gym.Env):
|
||||
|
||||
def reset(self):
|
||||
# print("-----------reset simulation---------------")
|
||||
p.resetSimulation()
|
||||
self.cartpole = p.loadURDF(os.path.join(pybullet_data.getDataPath(), "cartpole.urdf"),
|
||||
[0, 0, 0])
|
||||
p.changeDynamics(self.cartpole, -1, linearDamping=0, angularDamping=0)
|
||||
p.changeDynamics(self.cartpole, 0, linearDamping=0, angularDamping=0)
|
||||
p.changeDynamics(self.cartpole, 1, linearDamping=0, angularDamping=0)
|
||||
self.timeStep = 0.02
|
||||
p.setJointMotorControl2(self.cartpole, 1, p.VELOCITY_CONTROL, force=0)
|
||||
p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, force=0)
|
||||
p.setGravity(0, 0, -9.8)
|
||||
p.setTimeStep(self.timeStep)
|
||||
p.setRealTimeSimulation(0)
|
||||
|
||||
if self._physics_client_id < 0:
|
||||
if self._renders:
|
||||
self._p = bc.BulletClient(connection_mode=p2.GUI)
|
||||
else:
|
||||
self._p = bc.BulletClient()
|
||||
self._physics_client_id = self._p._client
|
||||
|
||||
p = self._p
|
||||
p.resetSimulation()
|
||||
self.cartpole = p.loadURDF(os.path.join(pybullet_data.getDataPath(), "cartpole.urdf"),
|
||||
[0, 0, 0])
|
||||
p.changeDynamics(self.cartpole, -1, linearDamping=0, angularDamping=0)
|
||||
p.changeDynamics(self.cartpole, 0, linearDamping=0, angularDamping=0)
|
||||
p.changeDynamics(self.cartpole, 1, linearDamping=0, angularDamping=0)
|
||||
self.timeStep = 0.02
|
||||
p.setJointMotorControl2(self.cartpole, 1, p.VELOCITY_CONTROL, force=0)
|
||||
p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, force=0)
|
||||
p.setGravity(0, 0, -9.8)
|
||||
p.setTimeStep(self.timeStep)
|
||||
p.setRealTimeSimulation(0)
|
||||
p = self._p
|
||||
randstate = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
|
||||
p.resetJointState(self.cartpole, 1, randstate[0], randstate[1])
|
||||
p.resetJointState(self.cartpole, 0, randstate[2], randstate[3])
|
||||
@ -99,4 +118,51 @@ class CartPoleBulletEnv(gym.Env):
|
||||
return np.array(self.state)
|
||||
|
||||
def render(self, mode='human', close=False):
|
||||
return
|
||||
if mode == "human":
|
||||
self._renders = True
|
||||
if mode != "rgb_array":
|
||||
return np.array([])
|
||||
base_pos=[0,0,0]
|
||||
self._cam_dist = 2
|
||||
self._cam_pitch = 0.3
|
||||
self._cam_yaw = 0
|
||||
if (self._physics_client_id>=0):
|
||||
view_matrix = self._p.computeViewMatrixFromYawPitchRoll(
|
||||
cameraTargetPosition=base_pos,
|
||||
distance=self._cam_dist,
|
||||
yaw=self._cam_yaw,
|
||||
pitch=self._cam_pitch,
|
||||
roll=0,
|
||||
upAxisIndex=2)
|
||||
proj_matrix = self._p.computeProjectionMatrixFOV(fov=60,
|
||||
aspect=float(self._render_width) /
|
||||
self._render_height,
|
||||
nearVal=0.1,
|
||||
farVal=100.0)
|
||||
(_, _, px, _, _) = self._p.getCameraImage(
|
||||
width=self._render_width,
|
||||
height=self._render_height,
|
||||
renderer=self._p.ER_BULLET_HARDWARE_OPENGL,
|
||||
viewMatrix=view_matrix,
|
||||
projectionMatrix=proj_matrix)
|
||||
else:
|
||||
px = np.array([[[255,255,255,255]]*self._render_width]*self._render_height, dtype=np.uint8)
|
||||
rgb_array = np.array(px, dtype=np.uint8)
|
||||
rgb_array = np.reshape(np.array(px), (self._render_height, self._render_width, -1))
|
||||
rgb_array = rgb_array[:, :, :3]
|
||||
return rgb_array
|
||||
|
||||
def configure(self, args):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
if self._physics_client_id >= 0:
|
||||
self._p.disconnect()
|
||||
self._physics_client_id = -1
|
||||
|
||||
class CartPoleContinuousBulletEnv(CartPoleBulletEnv):
|
||||
metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 50}
|
||||
|
||||
def __init__(self, renders=False):
|
||||
# start the bullet physics server
|
||||
CartPoleBulletEnv.__init__(self, renders, discrete_actions=False)
|
||||
|
@ -41,7 +41,8 @@ class BulletClient(object):
|
||||
def __getattr__(self, name):
|
||||
"""Inject the client id into Bullet functions."""
|
||||
attribute = getattr(pybullet, name)
|
||||
attribute = functools.partial(attribute, physicsClientId=self._client)
|
||||
if inspect.isbuiltin(attribute):
|
||||
attribute = functools.partial(attribute, physicsClientId=self._client)
|
||||
if name=="disconnect":
|
||||
self._client = -1
|
||||
return attribute
|
||||
|
Loading…
Reference in New Issue
Block a user