Hi there, thanks for your work and for taking a look at this. I am trying to port my environment definition to JAX in the hope of running massively parallel environment rollouts for my RL algorithm.
The primary goal is to navigate a 3D maze. Ignoring the details for now, I based my implementation on Gym and Gymnax conventions here: https://github.com/m-krastev/gymnax
To make use of massive parallelism, I rely on the whole code being JIT-compilable within reasonable limits. In this case, the code does compile. However, it runs around 1-2x slower than similar code written in NumPy + Torch (with most of the work done on the CPU). When profiling, I see that a huge amount of time is spent in memcpyD2D (device-to-device copies), and I am struggling to find what prevents the compiler from performing these operations in place.
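One way to narrow down where device-to-device copies come from is to inspect the optimized HLO that XLA produces for a jitted function and count its `copy` instructions. The sketch below does this for a toy single-voxel update; `count_hlo_copies` and `update_one_voxel` are hypothetical helpers, and counting `copy(` lines is only a rough heuristic (some copies are layout changes inside fusions rather than full-buffer memcpys). `donate_argnums` is the standard `jax.jit` option that lets XLA reuse an input buffer for an output, which is one of the usual levers for removing such copies.

```python
import jax
import jax.numpy as jnp


def count_hlo_copies(fn, *args, **jit_kwargs) -> int:
    """Count `copy` instructions in the optimized HLO XLA emits for fn(*args).

    Rough heuristic: on GPU, full-buffer copies often show up in a profile
    as device-to-device memcpys.
    """
    compiled = jax.jit(fn, **jit_kwargs).lower(*args).compile()
    hlo_text = compiled.as_text()
    return sum(1 for line in hlo_text.splitlines() if " copy(" in line)


def update_one_voxel(volume):
    # A single scatter into a large volume. Whether XLA can do it in place
    # depends on whether it is allowed to reuse the input buffer.
    return volume.at[0, 0, 0].set(1.0)


volume = jnp.zeros((64, 64, 64), dtype=jnp.float32)
print("without donation:", count_hlo_copies(update_one_voxel, volume))
print("with donation:   ", count_hlo_copies(update_one_voxel, volume, donate_argnums=0))
```

The same helper can be pointed at a jitted `step_env`-like function to see whether the copy count grows with the number of full-volume updates per step.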
I provide my main code (gymnax/environments/medical/small_bowel.py) with some parts omitted to keep it to a reasonable length. SmallBowelParams is meant to be immutable once instantiated; only the state is passed around with new values.
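This split between read-only parameters and per-environment state maps directly onto `jax.vmap`: the state is batched along axis 0, while `in_axes=None` marks the parameters as unmapped and shared by every environment. A minimal sketch, using hypothetical `ToyParams`/`ToyState` NamedTuples and a toy `toy_step` rather than the flax dataclasses and `step_env` in the code below:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):
    # Hypothetical stand-in for SmallBowelParams: shared, read-only data.
    volume: jax.Array  # (D, H, W) float
    max_step_vox: float


class ToyState(NamedTuple):
    # Hypothetical stand-in for SmallBowelState: one copy per parallel env.
    pos: jax.Array  # (3,) int32
    cum_reward: jax.Array  # () float32


def toy_step(state: ToyState, action: jax.Array, params: ToyParams) -> ToyState:
    """Pure step: builds a new state and never mutates `params`."""
    delta = jnp.round((2 * action - 1) * params.max_step_vox).astype(jnp.int32)
    upper = jnp.array(params.volume.shape) - 1
    new_pos = jnp.clip(state.pos + delta, 0, upper)
    reward = params.volume[new_pos[0], new_pos[1], new_pos[2]]
    return state._replace(pos=new_pos, cum_reward=state.cum_reward + reward)


# Map over state and action (axis 0); broadcast the same params to every env.
batched_step = jax.jit(jax.vmap(toy_step, in_axes=(0, 0, None)))

num_envs = 8
params = ToyParams(
    volume=jnp.arange(4 * 4 * 4, dtype=jnp.float32).reshape(4, 4, 4),
    max_step_vox=1.0,
)
states = ToyState(
    pos=jnp.zeros((num_envs, 3), dtype=jnp.int32),
    cum_reward=jnp.zeros((num_envs,), dtype=jnp.float32),
)
actions = jnp.ones((num_envs, 3))  # every env moves +1 voxel along each axis
new_states = batched_step(states, actions, params)
```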
Environment code:
```python
from functools import partial
from math import sqrt

import jax
import jax.numpy as jnp
from flax import struct
from gymnax.environments import environment

# Omitted from this excerpt: _is_valid_pos, line_nd_jax, get_patch_jax and
# SmallBowel._get_final_coverage. draw_path_sphere is shown further below.


@struct.dataclass(kw_only=True)
class SmallBowelParams(environment.EnvParams):
    """Parameters for the Small Bowel environment."""

    image: jnp.ndarray  # (D, H, W) float
    seg: jnp.ndarray  # (D, H, W) bool
    wall_map: jnp.ndarray  # (D, H, W) float
    gt_path_vol: jnp.ndarray  # (D, H, W) bool
    gdt_start: jnp.ndarray  # (D, H, W) float
    gdt_end: jnp.ndarray  # (D, H, W) float
    local_peaks: jnp.ndarray  # (N, 3) int
    start_coord: jnp.ndarray  # (3,) int
    end_coord: jnp.ndarray  # (3,) int
    # seg_volume should be kept in the state, but since at each reset we have to
    # calculate it, it's better to keep it in the static parameters.
    seg_volume: float
    # I'd like to avoid passing the image shape as a parameter,
    # but the compiler will complain otherwise.
    image_shape: jnp.ndarray
    r_zero_mov: float = 10.0
    r_val1: float = 4.0
    r_val2: float = 6.0
    r_final: float = 100.0
    r_peaks: float = 4.0
    max_step_vox: float = 10.0  # Max movement in voxels
    gdt_max_increase_theta: float = sqrt(3) * 10.0
    patch_size_vox: tuple[int, int, int] = struct.field(
        pytree_node=False,
        default=(32, 32, 32),  # Default patch size in voxels
    )
    cumulative_path_radius_vox: int = 3


@struct.dataclass
class SmallBowelState(environment.EnvState):
    """State of the Small Bowel environment."""

    current_pos_vox: jnp.ndarray  # (3,) int
    goal: jnp.ndarray  # (3,) int
    gdt: jnp.ndarray  # (D, H, W) float, GDT map for current episode
    cumulative_path_mask: jnp.ndarray  # (D, H, W) bool
    max_gdt_achieved: float
    cum_reward: float
    reward_map: jnp.ndarray  # (D, H, W) float, for peaks
    wall_gradient: float
    length: float


class SmallBowel(environment.Environment[SmallBowelState, SmallBowelParams]):
    """
    Gymnax-compatible environment for RL-based small bowel path tracking.
    Implemented in JAX.
    """

    def __init__(self):
        super().__init__()

    @property
    def default_params(self) -> SmallBowelParams:
        """Default environment parameters."""
        # These are placeholders. Actual data will be loaded externally.
        return SmallBowelParams(
            image=jnp.zeros((1, 1, 1), dtype=jnp.float32),
            seg=jnp.zeros((1, 1, 1), dtype=jnp.bool_),
            wall_map=jnp.zeros((1, 1, 1), dtype=jnp.float32),
            gdt_start=jnp.zeros((1, 1, 1), dtype=jnp.float32),
            gdt_end=jnp.zeros((1, 1, 1), dtype=jnp.float32),
            gt_path_vol=jnp.zeros((1, 1, 1), dtype=jnp.bool_),
            local_peaks=jnp.zeros((1, 3), dtype=jnp.int32),
            start_coord=jnp.zeros((3,), dtype=jnp.int32),
            end_coord=jnp.zeros((3,), dtype=jnp.int32),
            image_shape=jnp.ones((3,), dtype=jnp.int32),  # Placeholder
            seg_volume=0.0,
        )

    @partial(jax.jit, static_argnames=("self",))
    def reset_env(
        self, key: jax.Array, params: SmallBowelParams
    ) -> tuple[jnp.ndarray, SmallBowelState]:
        """Resets the environment."""
        key, key_choice, key_shuffle = jax.random.split(key, 3)

        # Random start logic:
        # 0-3: start_coord -> end_coord (gdt_start)
        # 4-5: random local peak -> (end_coord or start_coord) (gdt_start or gdt_end)
        # 6-9: end_coord -> start_coord (gdt_end)
        rand_choice = jax.random.randint(key_choice, (), 0, 10)

        # Case 1: Start at beginning, go to end
        current_pos_vox_case1 = params.start_coord
        goal_case1 = params.end_coord
        gdt_map_case1 = params.gdt_start

        # Case 2: Start at random local peak
        peak_idx = jax.random.randint(key_shuffle, (), 0, params.local_peaks.shape[0])
        current_pos_vox_case2 = params.local_peaks[peak_idx]
        # Randomly go in either direction for local peak start
        goal_case2_opt1 = params.end_coord
        gdt_map_case2_opt1 = params.gdt_start
        goal_case2_opt2 = params.start_coord
        gdt_map_case2_opt2 = params.gdt_end
        # Use key_shuffle for this binary choice
        peak_direction_choice = jax.random.bernoulli(key_shuffle)
        goal_case2 = jax.lax.select(
            peak_direction_choice, goal_case2_opt1, goal_case2_opt2
        )
        gdt_map_case2 = jax.lax.select(
            peak_direction_choice, gdt_map_case2_opt1, gdt_map_case2_opt2
        )

        # Case 3: Start at end, go to start
        current_pos_vox_case3 = params.end_coord
        goal_case3 = params.start_coord
        gdt_map_case3 = params.gdt_end

        # Select based on rand_choice
        current_pos_vox = jax.lax.switch(
            rand_choice,
            [
                lambda: current_pos_vox_case1,  # 0
                lambda: current_pos_vox_case1,  # 1
                lambda: current_pos_vox_case1,  # 2
                lambda: current_pos_vox_case1,  # 3
                lambda: current_pos_vox_case2,  # 4
                lambda: current_pos_vox_case2,  # 5
                lambda: current_pos_vox_case3,  # 6
                lambda: current_pos_vox_case3,  # 7
                lambda: current_pos_vox_case3,  # 8
                lambda: current_pos_vox_case3,  # 9
            ],
        )
        goal = jax.lax.switch(
            rand_choice,
            [
                lambda: goal_case1,
                lambda: goal_case1,
                lambda: goal_case1,
                lambda: goal_case1,
                lambda: goal_case2,
                lambda: goal_case2,
                lambda: goal_case3,
                lambda: goal_case3,
                lambda: goal_case3,
                lambda: goal_case3,
            ],
        )
        gdt_map = jax.lax.switch(
            rand_choice,
            [
                lambda: gdt_map_case1,
                lambda: gdt_map_case1,
                lambda: gdt_map_case1,
                lambda: gdt_map_case1,
                lambda: gdt_map_case2,
                lambda: gdt_map_case2,
                lambda: gdt_map_case3,
                lambda: gdt_map_case3,
                lambda: gdt_map_case3,
                lambda: gdt_map_case3,
            ],
        )

        # Validate start position (simplified check for now)
        is_valid_start = (
            _is_valid_pos(current_pos_vox, params.image_shape)
            & params.seg[tuple(current_pos_vox.T)]
        )
        # If start is invalid, default to start_coord and gdt_start
        current_pos_vox = jax.lax.select(
            is_valid_start, current_pos_vox, params.start_coord
        )
        goal = jax.lax.select(is_valid_start, goal, params.end_coord)
        gdt_map = jax.lax.select(is_valid_start, gdt_map, params.gdt_start)

        # Initialize path tracking
        cumulative_path_mask = jnp.zeros_like(params.image, dtype=jnp.bool_)
        # Mark initial position on cumulative path mask
        line = line_nd_jax(current_pos_vox, current_pos_vox, 256)
        cumulative_path_mask, _ = draw_path_sphere(
            cumulative_path_mask, line, params.cumulative_path_radius_vox, True
        )

        # Initialize various tracking variables
        cum_reward = 0.0
        max_gdt_achieved = gdt_map[tuple(current_pos_vox.T)]
        # Initialize reward_map for peaks
        reward_map = jnp.zeros_like(params.image, dtype=jnp.float32)
        # Scatter 1s at local peaks
        reward_map = reward_map.at[tuple(params.local_peaks.T)].set(1.0)
        wall_gradient = 0.0

        state = SmallBowelState(
            time=0,  # EnvState's time
            current_pos_vox=current_pos_vox,
            goal=goal,
            gdt=gdt_map,
            cumulative_path_mask=cumulative_path_mask,
            max_gdt_achieved=max_gdt_achieved,
            cum_reward=cum_reward,
            reward_map=reward_map,
            wall_gradient=wall_gradient,
            length=0.0,
        )
        obs = self.get_obs(state, params)
        return obs, state

    @partial(jax.jit, static_argnames=("self",))
    def step_env(
        self,
        key: jax.Array,
        state: SmallBowelState,
        action: jnp.ndarray,
        params: SmallBowelParams,
    ) -> tuple[jnp.ndarray, SmallBowelState, jnp.ndarray, jnp.ndarray, dict]:
        """Performs a step transition in the environment."""
        key, key_reward = jax.random.split(key)

        # Extract action: action is normalized to [0, 1]
        action_mapped = (2 * action - 1) * params.max_step_vox
        action_vox_delta = jnp.round(action_mapped).astype(jnp.int32)

        # Calculate next position
        next_pos_vox = state.current_pos_vox + action_vox_delta

        # Calculate reward and update masks
        reward, new_cumulative_path_mask, new_reward_map, wall_stuff = (
            self._calculate_reward(
                state.current_pos_vox,
                next_pos_vox,
                action_vox_delta,
                state.cumulative_path_mask,
                state.reward_map,
                # GDT map selected at reset (gdt_start or gdt_end)
                state.gdt,
                params.seg,
                params.wall_map,
                params.gt_path_vol,
                params.cumulative_path_radius_vox,
                params.r_zero_mov,
                params.r_val1,
                params.r_val2,
                params.r_peaks,
                params.gdt_max_increase_theta,
                state.max_gdt_achieved,
                params.image_shape,
                params.seg_volume,
            )
        )

        # Update state variables
        new_cum_reward = state.cum_reward + reward
        new_wall_gradient = state.wall_gradient + wall_stuff
        # Update max_gdt_achieved based on the GDT map used for reward calculation
        new_max_gdt_achieved = jnp.maximum(
            state.max_gdt_achieved, state.gdt[tuple(next_pos_vox.T)]
        )

        # Check termination conditions
        done = self.is_terminal(
            state, params, next_pos_vox, action_vox_delta, wall_stuff
        )

        # Final reward adjustment if done
        final_coverage = jax.lax.select(
            done,
            self._get_final_coverage(
                new_cumulative_path_mask, params.seg, params.seg_volume
            ),
            0.0,
        )

        # Determine termination reason for final reward adjustment
        reached_goal = (
            jnp.linalg.norm(next_pos_vox - state.goal)
            < params.cumulative_path_radius_vox
        )
        reward = jax.lax.select(
            done,
            reward
            + jax.lax.select(
                reached_goal,
                final_coverage * params.r_final,
                (final_coverage - 1) * params.r_final,
            ),
            reward,
        )

        # Update state
        state = SmallBowelState(
            time=state.time + 1,
            current_pos_vox=next_pos_vox,
            goal=state.goal,
            gdt=state.gdt,  # Keep the episode's GDT map for the next step
            cumulative_path_mask=new_cumulative_path_mask,
            max_gdt_achieved=new_max_gdt_achieved,
            cum_reward=new_cum_reward,
            reward_map=new_reward_map,
            wall_gradient=new_wall_gradient,
            length=state.length + jnp.linalg.norm(next_pos_vox - state.current_pos_vox),
        )
        obs = self.get_obs(state, params)

        # Info dictionary (we already store a lot in the state, so no real need for it)
        info = {
            "final_coverage": final_coverage,
        }
        return obs, state, reward, done, info

    @partial(jax.jit, static_argnames=("self",))
    def _calculate_reward(
        self,
        current_pos_vox: jnp.ndarray,
        next_pos_vox: jnp.ndarray,
        action_vox_delta: jnp.ndarray,
        cumulative_path_mask: jnp.ndarray,
        reward_map: jnp.ndarray,
        gdt_map: jnp.ndarray,
        seg: jnp.ndarray,
        wall_map: jnp.ndarray,
        gt_path_vol: jnp.ndarray,
        cumulative_path_radius_vox: int,
        r_zero_mov: float,
        r_val1: float,
        r_val2: float,
        r_peaks: float,
        gdt_max_increase_theta: float,
        max_gdt_achieved: float,
        image_shape: tuple[int, int, int],
        seg_volume: float,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Calculate the reward for the current step (JAX-compatible)."""
        rt = jnp.array(0.0, dtype=jnp.float32)
        wall_stuff = jnp.array(0.0, dtype=jnp.float32)

        is_next_pos_valid = _is_valid_pos(next_pos_vox, image_shape)
        # is_in_seg = jnp.where(is_next_pos_valid, seg[tuple(next_pos_vox)], False)

        # --- 1. Zero movement or goes out of the segmentation penalty ---
        cond_invalid_move = (
            (jnp.all(action_vox_delta == 0)) | (~is_next_pos_valid)  # | (~is_in_seg)
        )
        rt = jax.lax.select(cond_invalid_move, rt - r_zero_mov, rt)

        # Set of voxels S on the line segment
        line = line_nd_jax(current_pos_vox, next_pos_vox, 256)

        # Only proceed with rewards if move is valid
        rt, new_cumulative_path_mask, new_reward_map, wall_stuff = jax.lax.cond(
            cond_invalid_move,
            lambda: (
                rt,
                cumulative_path_mask,
                reward_map,
                wall_stuff,
            ),  # If invalid, no further reward/mask updates
            lambda: self._calculate_valid_move_rewards(
                rt,
                line,
                cumulative_path_mask,
                reward_map,
                gdt_map,
                seg,
                wall_map,
                gt_path_vol,
                cumulative_path_radius_vox,
                r_val1,
                r_val2,
                r_peaks,
                gdt_max_increase_theta,
                max_gdt_achieved,
                next_pos_vox,
                seg_volume,
            ),
        )
        return rt, new_cumulative_path_mask, new_reward_map, wall_stuff

    @partial(jax.jit, static_argnames=("self",))
    def _calculate_valid_move_rewards(
        self,
        rt: jnp.ndarray,
        line: tuple[jnp.ndarray],
        cumulative_path_mask: jnp.ndarray,
        reward_map: jnp.ndarray,
        gdt_map: jnp.ndarray,
        seg: jnp.ndarray,
        wall_map: jnp.ndarray,
        gt_path_vol: jnp.ndarray,
        cumulative_path_radius_vox: int,
        r_val1: float,
        r_val2: float,
        r_peaks: float,
        gdt_max_increase_theta: float,
        max_gdt_achieved: float,
        next_pos_vox: jnp.ndarray,
        seg_volume: float,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Helper for reward calculation when the move is valid."""
        # --- 2. GDT-based reward ---
        next_gdt_val = gdt_map[tuple(next_pos_vox.T)]
        delta = next_gdt_val - max_gdt_achieved
        rt = rt + jnp.where(
            delta > 0,
            jnp.where(
                delta > gdt_max_increase_theta,
                -r_val2,
                r_val2 * (delta / gdt_max_increase_theta),
            ),
            0.0,
        )

        # --- 2.5 Peaks-based reward ---
        peaks_reward_sum = reward_map.at[line[0], line[1], line[2]].get(
            indices_are_sorted=True
        )
        # Since the (sorted) indices include many repeated voxels
        # (due to the way we index with a fixed number of samples),
        # we can find the number of gradient points to get the proper number of peaks.
        peaks_reward_sum = jnp.sum(
            (peaks_reward_sum[1:] - peaks_reward_sum[:-1]) > 0
        )  # Count non-zero peaks
        rt = rt + peaks_reward_sum * r_peaks
        # Discard the reward for visited nodes (set to 0 in reward_map)
        new_reward_map = reward_map.at[line[0], line[1], line[2]].set(0.0)

        # Update cumulative path mask
        new_cumulative_path_mask, coverage = draw_path_sphere(
            cumulative_path_mask, line, cumulative_path_radius_vox, True
        )

        # Reward for coverage (based on Dice within the segmentation on the path)
        patch = cumulative_path_mask.at[line[0], line[1], line[2]].get(
            indices_are_sorted=True
        )
        coverage = coverage.at[line[0], line[1], line[2]].get(indices_are_sorted=True)
        intersection = jnp.sum(patch * coverage)
        union = jnp.sum(patch) + jnp.sum(coverage)
        coverage = 2 * intersection / union
        rt = rt + coverage * r_val2

        # --- 3. Wall-based penalty ---
        wall_map_max = wall_map.at[line[0], line[1], line[2]].get().max()
        wall_stuff = wall_map_max
        rt = rt - r_val2 * wall_map_max * 30  # 0.03 is basically equivalent to a wall

        # --- 4. Revisiting penalty ---
        revisit_penalty = (
            new_cumulative_path_mask.at[line[0], line[1], line[2]].get().any()
        )  # Check against the *new* mask
        rt = rt - r_val1 * jax.lax.select(
            revisit_penalty > 0, 1.0, 0.0
        )  # Apply penalty if any overlap

        # Penalty if next position is outside segmentation
        # This was `if not self.seg_np[next_pos_vox]: rt -= self.config.r_val1`
        # This is now covered by `cond_invalid_move` at the top level.
        return rt, new_cumulative_path_mask, new_reward_map, wall_stuff

    @partial(jax.jit, static_argnames=("self",))
    def get_obs(
        self, state: SmallBowelState, params: SmallBowelParams
    ) -> jnp.ndarray:
        """Get state patches centered at current position."""
        img_patch = get_patch_jax(
            params.image, state.current_pos_vox, params.patch_size_vox
        )
        wall_patch = get_patch_jax(
            params.wall_map, state.current_pos_vox, params.patch_size_vox
        )
        cum_path_patch = get_patch_jax(
            state.cumulative_path_mask,
            state.current_pos_vox,
            params.patch_size_vox,
        )
        # Stack patches for actor and critic (can allow for a 4th)
        obs = jnp.stack([img_patch, wall_patch, cum_path_patch], axis=0)
        # For Gymnax, we return the raw observation. The agent will handle batching.
        return obs

    @partial(jax.jit, static_argnames=("self",))
    def is_terminal(
        self,
        state: SmallBowelState,
        params: SmallBowelParams,
        next_pos_vox: jnp.ndarray,
        action_vox_delta: jnp.ndarray,
        wall_stuff: jnp.ndarray,
    ) -> jnp.ndarray:
        """Check whether state transition is terminal."""
        # Max steps reached
        max_steps_reached = state.time >= params.max_steps_in_episode
        # Out of bounds or invalid move (already checked in _calculate_reward,
        # but re-check for termination)
        is_next_pos_valid = _is_valid_pos(next_pos_vox, params.image_shape)
        # is_in_seg = jnp.where(
        #     is_next_pos_valid, params.seg[tuple(next_pos_vox)] > 0, False
        # )
        invalid_move = (
            (jnp.all(action_vox_delta == 0))
            | (~is_next_pos_valid)
            # | (wall_stuff > 0.03)
            # | (~is_in_seg)
        )
        # Reached goal
        reached_goal = (
            jnp.linalg.norm(next_pos_vox - state.goal)
            < params.cumulative_path_radius_vox
        )
        return max_steps_reached | invalid_move | reached_goal
```
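The three 10-way `jax.lax.switch` calls in `reset_env` encode a 40/20/40 split between the start cases. A more compact formulation, sketched below as an assumption rather than a drop-in replacement, draws the case directly with `jax.random.choice`, routes only the small coordinate arrays plus a direction flag through a 3-branch switch, and draws the peak index and the direction from separate keys (in the code above, `key_shuffle` is reused for both). `sample_start` and `CASE_PROBS` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Start-case probabilities implied by rand_choice in [0, 10):
# 0-3 -> case 0 (40%), 4-5 -> case 1 (20%), 6-9 -> case 2 (40%).
CASE_PROBS = jnp.array([0.4, 0.2, 0.4])


def sample_start(key, start_coord, end_coord, local_peaks):
    """Hypothetical helper: draw the start case, start voxel, goal voxel, and
    whether the episode should use gdt_start (True) or gdt_end (False)."""
    key_case, key_peak, key_dir = jax.random.split(key, 3)
    case = jax.random.choice(key_case, 3, p=CASE_PROBS)
    peak = local_peaks[jax.random.randint(key_peak, (), 0, local_peaks.shape[0])]
    forward = jax.random.bernoulli(key_dir)  # True: peak -> end_coord
    peak_goal = jnp.where(forward, end_coord, start_coord)
    # Three small branches; none of them touches a (D, H, W) volume.
    start, goal, use_gdt_start = jax.lax.switch(
        case,
        [
            lambda: (start_coord, end_coord, jnp.array(True)),
            lambda: (peak, peak_goal, forward),
            lambda: (end_coord, start_coord, jnp.array(False)),
        ],
    )
    return case, start, goal, use_gdt_start


# Usage: sample many starts at once with toy coordinates.
start_coord = jnp.array([1, 2, 3], dtype=jnp.int32)
end_coord = jnp.array([7, 8, 9], dtype=jnp.int32)
local_peaks = jnp.array([[4, 4, 4], [5, 5, 5]], dtype=jnp.int32)
keys = jax.random.split(jax.random.PRNGKey(0), 2000)
cases, starts, goals, use_start = jax.vmap(
    sample_start, in_axes=(0, None, None, None)
)(keys, start_coord, end_coord, local_peaks)
```

With this shape, the volume-level choice could collapse into one select such as `jax.lax.select(use_gdt_start, params.gdt_start, params.gdt_end)`, with the invalid-start fallback folded into `use_gdt_start` first.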
One particularly expensive function, both in Torch and in JAX, is a line-drawing algorithm, which generates the coordinates of a line in 3D space and then dilates it to a given radius. I tried two approaches: one using convolutions to approximate dilation, and another that draws spheres. Here is the second approach (which is slightly faster); the first one is also available in my code.
```python
@partial(jax.jit, inline=True)
def draw_sphere_point(
    array_3d: jnp.ndarray, center_point: jnp.ndarray, radius: int, fill_value=1
):
    """
    Fills a 3D NumPy array with a specified value inside a sphere.

    Args:
        array_3d (np.ndarray): The 3D NumPy array (e.g., of zeros) to modify.
            Its shape defines the coordinate space.
        center_point (tuple or list or np.ndarray): The (x, y, z) coordinates
            of the sphere's center.
        radius (float or int): The radius of the sphere.
        fill_value (int or float, optional): The value to fill inside the sphere.
            Defaults to 1.

    Returns:
        np.ndarray: The modified 3D array with the sphere filled.
    """
    # Generate 1D coordinate arrays for each axis using np.ogrid
    # These will broadcast to the full 3D shape for the distance calculation
    # We take the actual shape of the input array for coordinates
    x, y, z = jnp.ogrid[
        0 : array_3d.shape[0], 0 : array_3d.shape[1], 0 : array_3d.shape[2]
    ]
    # Calculate the squared distance from the sphere's center for every point
    distance_squared = (
        (x - center_point[0]) ** 2
        + (y - center_point[1]) ** 2
        + (z - center_point[2]) ** 2
    )
    # Fill the array at the appropriate places using the boolean mask
    return jnp.where(distance_squared <= radius**2, fill_value, array_3d)


@partial(jax.jit)
def draw_path_sphere(array_3d: jnp.ndarray, pts: tuple, radius: int, fill_value=True):
    """
    Draws a sphere around each point in the path defined by `pts` in the 3D array.

    Args:
        array_3d: The 3D array to update.
        pts: A tuple containing the coordinates of the path points
            (e.g., from `line_nd_jax`).
        radius: The radius for the spheres to be drawn around each point.
        fill_value: The value to fill inside the spheres.
    """
    # Draw spheres around each point in the path
    new_array_3d = jax.lax.scan(
        lambda a, p: (draw_sphere_point(a, p, radius, fill_value), p),
        jnp.zeros_like(array_3d),
        jnp.asarray(pts).T,
    )[0]
    # Update the original array with the new spheres
    updated_array_3d = jnp.maximum(array_3d, new_array_3d)
    return updated_array_3d, new_array_3d
```
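For comparison with the sphere-based helpers above, the convolution-based variant mentioned earlier might look roughly like the sketch below. It is not the code from the repository: `line_coords_fixed`, `ball_kernel` and `draw_path_dilated` are hypothetical names, and `line_coords_fixed` is only an assumed stand-in for `line_nd_jax` (a fixed number of samples rounded to voxel indices, so consecutive samples often repeat). Structurally, it touches the volume with one scatter and one 3D convolution per call instead of one full-volume `jnp.where` per path sample inside `lax.scan`; whether that is faster in practice depends on the volume size and radius.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve


def line_coords_fixed(start, stop, num_samples: int = 256):
    """Hypothetical stand-in for line_nd_jax: sample a fixed number of points
    on the segment start -> stop and round them to voxel indices.

    A fixed sample count keeps shapes static under jit; consecutive samples
    often land on the same voxel, so the result contains repeats.
    """
    t = jnp.linspace(0.0, 1.0, num_samples)[:, None]  # (num_samples, 1)
    pts = start[None, :] + t * (stop - start)[None, :]  # (num_samples, 3)
    coords = jnp.round(pts).astype(jnp.int32)
    return coords[:, 0], coords[:, 1], coords[:, 2]


def ball_kernel(radius: int) -> jax.Array:
    """Float kernel equal to 1 inside a ball of the given (static) radius."""
    r = jnp.arange(-radius, radius + 1)
    dist2 = r[:, None, None] ** 2 + r[None, :, None] ** 2 + r[None, None, :] ** 2
    return (dist2 <= radius**2).astype(jnp.float32)


def draw_path_dilated(mask, start, stop, radius: int, num_samples: int = 256):
    """Rasterize the segment into a fresh volume, then dilate it by convolution.

    Returns (mask | dilated_path, dilated_path), mirroring draw_path_sphere.
    """
    xs, ys, zs = line_coords_fixed(start, stop, num_samples)
    line_vol = jnp.zeros(mask.shape, jnp.float32).at[xs, ys, zs].set(1.0)
    # A voxel is inside the dilated path if any line voxel lies within the ball.
    dilated = convolve(line_vol, ball_kernel(radius), mode="same") > 0.5
    return mask | dilated, dilated


draw_path_dilated_jit = jax.jit(
    draw_path_dilated, static_argnames=("radius", "num_samples")
)

# Usage: a single point with radius 1 keeps the center and its 6 face neighbors.
empty = jnp.zeros((9, 9, 9), dtype=jnp.bool_)
center = jnp.array([4, 4, 4], dtype=jnp.int32)
_, single = draw_path_dilated_jit(empty, center, center, radius=1)
```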