@@ -124,7 +124,8 @@ def __init__(
124124 tracking_method : str = "pf" , # method for tracking the landmarks positions (ls, pf)
125125 tracking_buffer_len : int = 32 , # maximum number of range observations kept for predicting the landmark positions
126126 range_noise_std : float = 10.0 , # standard deviation of the gaussian noise added to range measurements (meters)
127- traj_noise_std : float = 0.02 , # standard deviation of the gaussian noise added to the traj models (radians)
127+ traj_noise_std : float = 0.1 , # standard deviation of the gaussian noise added to the traj models (radians)
128+ velocity_noise_std : float = 0.1 , # standard deviation of the gaussian noise added to the velocity (meters/second)
128129 lost_comm_prob = 0.1 , # probability of losing communications (range measurements and intra-agent communication)
129130 min_steps_ls : int = 2 , # minimum steps for collecting data and start predicting landmarks positions with least squares
130131 rew_dist_thr : float = 150.0 , # distance threshold for the follow reward
@@ -218,7 +219,7 @@ def __init__(
218219 self .tracking_buffer_len = tracking_buffer_len
219220 self .range_noise_std = range_noise_std
220221 self .traj_noise_std = traj_noise_std
221- self .traj_noise_std = traj_noise_std
222+ self .velocity_noise_std = velocity_noise_std
222223 self .lost_comm_prob = lost_comm_prob
223224 self .min_steps_ls = min_steps_ls
224225 self .rew_dist_thr = rew_dist_thr
@@ -487,27 +488,20 @@ def world_step(
487488 else :
488489 # update the angle
489490 angle_change = actions * traj_coeffs + traj_intercepts
490- # add noise
491- angle_change + = (
491+ # add noise to agents only (not landmarks)
492+ angle_noise = (
492493 jax .random .normal (rng , shape = angle_change .shape ) * self .traj_noise_std
493494 )
494- new_angles = (pos [:, - 1 ] + angle_change + jnp .pi ) % (2 * jnp .pi ) - jnp .pi
495- # update the x-y position (depth remains constant)
496- pos = pos .at [:, - 1 ].set (new_angles )
497-
498- if self .actions_as_angles :
499- pos = pos .at [:, - 1 ].set (actions )
500- else :
501- # update the angle
502- angle_change = actions * traj_coeffs + traj_intercepts
503- # add noise
504- angle_change += (
505- jax .random .normal (rng , shape = angle_change .shape ) * self .traj_noise_std
495+ angle_change = angle_change .at [: self .num_agents ].add (
496+ angle_noise [: self .num_agents ]
506497 )
507498 new_angles = (pos [:, - 1 ] + angle_change + jnp .pi ) % (2 * jnp .pi ) - jnp .pi
508499 # update the x-y position (depth remains constant)
509500 pos = pos .at [:, - 1 ].set (new_angles )
510501
502+ # add noise to velocity of agents only (not landmarks)
503+ vel_noise = jax .random .normal (rng , shape = vel .shape ) * self .velocity_noise_std
504+ vel = vel .at [: self .num_agents ].add (vel_noise [: self .num_agents ])
511505 pos = pos .at [:, 0 ].add (jnp .cos (pos [:, - 1 ]) * vel * self .dt )
512506 pos = pos .at [:, 1 ].add (jnp .sin (pos [:, - 1 ]) * vel * self .dt )
513507 return pos
@@ -666,16 +660,17 @@ def get_obs(self, delta_xyz, ranges, comm_drop, pos, old_pos, land_pred_pos):
666660 delta_self_pos = delta_self_pos .at [:, 3 ].set (
667661 (delta_self_pos [:, 3 ] + jnp .pi ) / (2 * jnp .pi )
668662 )
669- delta_xyz = self .normalize_distances (delta_xyz )
670663
664+ # other entities relative distance
665+ delta_xyz = self .normalize_distances (delta_xyz )
671666 other_agents_dist = jnp .where (
672667 comm_drop [:, :, None ], 0 , delta_xyz [:, : self .num_agents ]
673668 ) # 0 for communication drop
674669 self_mask = (
675670 jnp .arange (self .num_agents ) == np .arange (self .num_agents )[:, np .newaxis ]
676671 )
677672 self_pos_feats = delta_self_pos [: self .num_agents , [0 , 1 , 3 ]]
678- self_pos_feats = self_pos_feats .at [:, 2 ].set (0 )
673+ self_pos_feats = self_pos_feats .at [:, 2 ].set (0 ) # set angle to 0 for now
679674 agents_rel_pos = jnp .where (
680675 self_mask [:, :, None ],
681676 self_pos_feats ,
@@ -693,18 +688,20 @@ def get_obs(self, delta_xyz, ranges, comm_drop, pos, old_pos, land_pred_pos):
693688 is_self_feat = (
694689 jnp .arange (self .num_entities ) == jnp .arange (self .num_agents )[:, np .newaxis ]
695690 )
691+ ranges = self .normalize_distances (ranges )
696692 ranges *= 1.0 if self .ranges_in_obs else 0.0 # mask the ranges if not in obs
697693 # the distance based feats are rescaled to hundreds of meters (better for NNs)
698694
699695 feats = jnp .concatenate (
700696 (
701697 pos_feats ,
702- self . normalize_distances ( ranges [:, :, None ]) ,
698+ ranges [:, :, None ],
703699 is_agent_feat [:, :, None ],
704700 is_self_feat [:, :, None ],
705701 ),
706702 axis = 2 ,
707703 )
704+ feats = jnp .where (jnp .isnan (feats ), 0.0 , feats ) # replace nan with 0
708705
709706 # then each agent is assigned its own obs
710707 return {
@@ -830,6 +827,8 @@ def get_global_state(
830827 else :
831828 state = self .get_vertex_state (pos , vel , ranges , land_pred_pos )
832829
830+ state = jnp .where (jnp .isnan (state ), 0.0 , state ) # replace nan with 0
831+
833832 if self .matrix_state :
834833 return state
835834 else :
@@ -912,6 +911,8 @@ def exponential_decay(x, x1=self.rew_pred_ideal, x2=self.rew_pred_thr):
912911 ).any ()
913912 rew = jnp .where (any_agent_lost , - 1.0 , rew )
914913
914+ rew = jnp .where (jnp .isnan (rew ), 0.0 , rew ) # replace nan with 0
915+
915916 # DONE
916917 done = t == self .max_steps
917918
@@ -1013,28 +1014,32 @@ def get_ranges(
10131014 jax .random .normal (key_noise , shape = ranges_real .shape ) * self .range_noise_std
10141015 )
10151016
1017+ # Add noise to 3D range measurement (physically correct)
10161018 ranges = ranges_real + noise
10171019 lost_range = (
10181020 jax .random .uniform (key_lost , shape = ranges .shape ) <= self .lost_comm_prob
10191021 ) | (
10201022 ranges_real > self .max_range_dist
10211023 ) # lost communication or landmark too far
10221024 ranges = jnp .where (lost_range , 0.0 , ranges )
1023- lost_range = (
1024- jax .random .uniform (key_lost , shape = ranges .shape ) <= self .lost_comm_prob
1025- ) | (
1026- ranges_real > self .max_range_dist
1027- ) # lost communication or landmark too far
1028- ranges = jnp .where (lost_range , 0.0 , ranges )
10291025 ranges = fill_diagonal_zeros (ranges ) # reset to 0s the self-ranges
10301026
1031- ranges_2d = ranges_real_2d + noise
1032- ranges_2d = jnp .where (lost_range , 0.0 , ranges_2d )
1033- ranges_2d = fill_diagonal_zeros (ranges_2d )
1034-
1035- ranges_2d = ranges_real_2d + noise
1036- ranges_2d = jnp .where (lost_range , 0.0 , ranges_2d )
1037- ranges_2d = fill_diagonal_zeros (ranges_2d )
1027+ # Convert noisy 3D measurement to 2D (if depth is known)
1028+ # This is the physically correct way: measure 3D with noise, then convert to 2D
1029+ if self .landmark_depth_known :
1030+ # Calculate depth differences for landmarks (agents to landmarks)
1031+ delta_z = pos [: self .num_agents , np .newaxis , 2 ] - pos [:, 2 ]
1032+ # Convert noisy 3D range to 2D: r_2d = sqrt(r_3d² - dz²)
1033+ # Use jnp.maximum to avoid negative values under sqrt
1034+ ranges_2d_squared = jnp .maximum (ranges ** 2 - delta_z ** 2 , 0.0 )
1035+ ranges_2d = jnp .sqrt (ranges_2d_squared )
1036+ # Set to 0 where communication was lost
1037+ ranges_2d = jnp .where (lost_range , 0.0 , ranges_2d )
1038+ ranges_2d = fill_diagonal_zeros (ranges_2d )
1039+ else :
1040+ # If depth is not known, can't convert - use 3D ranges directly
1041+ # (This is less realistic but keeps backward compatibility)
1042+ ranges_2d = ranges
10381043
10391044 return delta_xyz , ranges_real_2d , ranges_real , ranges_2d , ranges
10401045
0 commit comments