Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enlarge #4

Merged
merged 28 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1a0ae78
Add: apo, trpo
YIFANSUN98 Oct 27, 2023
f5c9243
Add: ppo
YIFANSUN98 Oct 27, 2023
c4b35c8
feihan push test
LiangZhisama Oct 27, 2023
1194919
Add: A2C
LiangZhisama Oct 27, 2023
7ebed01
Add: CPO
LiangZhisama Oct 28, 2023
1be3493
Add: PCPO
LiangZhisama Oct 28, 2023
940512b
Add: PDO
LiangZhisama Oct 28, 2023
c74845d
Add: LPG
LiangZhisama Oct 28, 2023
ab922c0
Add: SafeLayer
LiangZhisama Oct 28, 2023
ac891ac
Add: TRPOFAC
LiangZhisama Oct 28, 2023
6142c96
Add: TRPOIPO
LiangZhisama Oct 28, 2023
12cc78e
Add: TRPOLAG
LiangZhisama Oct 28, 2023
28d5cea
Add: USL
LiangZhisama Oct 28, 2023
b10dcae
Update: Buffer sign to BufferX
LiangZhisama Oct 28, 2023
012050f
Add SCPO and Bugfix about adc_buf usage
LiangZhisama Oct 28, 2023
ddbde3a
Update: update safe_rl_lib to the latest version
LiangZhisama Oct 28, 2023
1795abb
Update utils in safe_rl_lib
LiangZhisama Oct 28, 2023
d3a456f
Update: change the framework
Oct 29, 2023
8570993
Clean cache
LiangZhisama Oct 29, 2023
471fc39
Merge branch 'enlarge' of github.com:intelligent-control-lab/guardX i…
LiangZhisama Oct 29, 2023
f103bee
Update: safe_rl_envs
Oct 29, 2023
53f633d
Merge branch 'enlarge' of github.com:intelligent-control-lab/guardX i…
Oct 29, 2023
c9071a0
Merge pull request #2 from intelligent-control-lab/main
LiangZhisama Oct 29, 2023
c44a48f
Bugfix: about jaxlab GPU problem
LiangZhisama Oct 29, 2023
633790d
Bugfix: make every algorithm executable
LiangZhisama Oct 30, 2023
b95f525
Update: default settings of trpoipolar
LiangZhisama Oct 30, 2023
661b6fa
Update: engine to support obs getting
LiangZhisama Oct 30, 2023
ef40620
Remove: useless files
Oct 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 0 additions & 56 deletions myTest.py

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
14 changes: 8 additions & 6 deletions safe_rl_envs/safe_rl_envs/envs/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def batched_step(data: mjx.Data, last_data: mjx.Data, last_last_data: mjx.Data,
action = jax.tree_map(np.array, ctrl_range)
self.action_space = gym.spaces.Box(action[:, 0], action[:, 1], dtype=np.float32)
# self.action_space = utils.batch_space(action_space, self.num_envs)
self.done = None
self.build_observation_space()
self.body_name2id = {}
self.body_name2id['floor'] = 0
Expand Down Expand Up @@ -329,7 +330,6 @@ def mjx_reset(self, rng):
xpos = self.mjx_data.xpos
xpos = xpos.at[3:].set(layout["hazards_pos"])
xpos = xpos.at[2].set(layout["goal_pos"][0])
# data = data.replace(xpos=jp.zeros(self.mjx_model.nbody), qpos=jp.zeros(self.mjx_model.nq), qvel=jp.zeros(self.mjx_model.nv), ctrl=jp.zeros(self.mjx_model.nu))
data = mjx.forward(self.mjx_model, data)
last_data = data
last_last_data = data
Expand All @@ -349,7 +349,6 @@ def mjx_step(self, data: mjx.Data,
last_last_data: mjx.Data,
action: jp.ndarray,
rng):
#! TODO: be able to send the last data and last last data and last dist goal
"""Runs one timestep of the environment's dynamics."""
def f(data, _):
data = data.replace(ctrl=action)
Expand All @@ -371,11 +370,14 @@ def f(data, _):
ctrl = jp.where(done > 0.0, x = ctrl_reset, y = data.ctrl)
qvel = jp.where(done > 0.0, x = qvel_reset, y = data.qvel)
data = data.replace(qpos=qpos, qvel=qvel, ctrl=ctrl)
data_reset, _ = jax.lax.scan(f, data, (), self.physics_steps_per_control_step)
obs_reset, obs_dict_reset = self.obs(data_reset, None, None)
info['obs_reset'] = jp.where(done > 0.0, x = obs_reset, y = obs)
return obs, reward, done, info, data

def reset(self):
''' Reset the physics simulation and return observation '''
obs, reward, done, info, self.data = self._reset(self.key)
obs, reward, self.done, info, self.data = self._reset(self.key)
self.info = info
return jax_to_torch(obs)

Expand All @@ -386,9 +388,9 @@ def step(self, action):

self.key,_ = jax.random.split(self.key, 2)
self.update_data()
obs, reward, done, info, self.data= self._step(self.data, self.last_data, self.last_last_data, action, self.key)
obs, reward, self.done, info, self.data= self._step(self.data, self.last_data, self.last_last_data, action, self.key)
self.info = info
return jax_to_torch(obs), jax_to_torch(reward), jax_to_torch(done), jax_to_torch(info)
return jax_to_torch(obs), jax_to_torch(reward), jax_to_torch(self.done), jax_to_torch(info)

def build_layout(self, rng):
''' Rejection sample a placement of objects to find a layout. '''
Expand Down Expand Up @@ -569,7 +571,7 @@ def cost(self, data: mjx.Data) -> jp.ndarray:
robot_pos = data.xpos[self.body_name2id['robot'],:].reshape(-1,3)
hazards_pos = data.xpos[self.body_name2id['hazards'],:].reshape(-1,3)
dist_robot2hazard = jp.linalg.norm(hazards_pos[:,:2] - robot_pos[:,:2], axis=1)
dist_robot2hazard_below_threshold = jp.maximum(dist_robot2hazard, self.hazards_size)
dist_robot2hazard_below_threshold = jp.minimum(dist_robot2hazard, self.hazards_size)
cost = jp.sum(self.hazards_size*jp.ones(dist_robot2hazard_below_threshold.shape) - dist_robot2hazard_below_threshold)
return cost

Expand Down
Binary file not shown.
Loading