Skip to content

Commit 51a4e61

Browse files
committed
velocity noise
1 parent f7c32fa commit 51a4e61

4 files changed

Lines changed: 57 additions & 47 deletions

File tree

baselines/MAPPO/config/mappo_homogenous_transf_utracking.yaml

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
LR: 0.0005
2-
NUM_ENVS: 1024
1+
LR: 0.0001
2+
NUM_ENVS: 512
33
NUM_STEPS: 256
44
TOTAL_TIMESTEPS: 2e9
55
HIDDEN_DIM: 64
@@ -16,7 +16,7 @@ ENT_COEF: 0.01
1616
VF_COEF: 0.5
1717
MAX_GRAD_NORM: 0.5
1818
ACTIVATION: relu
19-
ANNEAL_LR: true
19+
ANNEAL_LR: false
2020

2121
# Environment configuration
2222
"ENV_NAME": "utracking"
@@ -25,29 +25,35 @@ ANNEAL_LR: true
2525
"num_landmarks": 1,
2626
"max_steps": 1024,
2727
"dt": 30,
28-
"difficulty": "hard",
28+
"difficulty": "manual",
29+
"landmark_rel_speed": [0.5, 0.8],
30+
"dirchange_time_range_landmark": [100, 200],
31+
"rudder_range_landmark": [0.1, 0.2],
2932
"rew_follow_coeff": 1.0,
3033
"rew_tracking_coeff": 1.0,
3134
"steps_for_new_range": 4,
35+
"traj_noise_std": 0.2,
36+
"velocity_noise_std": 0.2,
3237
"max_range_dist": 1500.0,
3338
"min_init_distance": 30.0,
34-
"max_init_distance": 1000.0,
39+
"max_init_distance": 800.0,
3540
"matrix_obs": true,
3641
"matrix_state": true,
42+
"pre_init_pos_len": 10000000,
3743
}
3844

3945

4046
# Experiment settings
41-
SEED: 0
47+
SEED: 3
4248
NUM_SEEDS: 1
4349
TUNE: false
44-
SAVE_PATH: models/utracking_mbari
45-
LOAD_PATH: #models/utracking_mbari/utracking_2_vs_1/mappo_transformer_tracking_rew/mappo_transformer_tracking_rew_utracking_2_vs_1_seed0_vmap0.safetensors # models/utracking_post/utracking_5_vs_5/mappo_transformer_hard_128steps/mappo_transformer_hard_128steps_utracking_5_vs_5_seed2.safetensors
46-
LOAD_CRITIC: # True
47-
ALG_NAME: mappo_transformer_newobs_hard
48-
CHECKPOINT_INTERVAL: 0.05 # percentage of total update steps
50+
SAVE_PATH: #models/17november
51+
LOAD_PATH: #models/17november/utracking_2_vs_1/mappo_transformer_noisy_more_linear_4steps_3rdrun/mappo_transformer_noisy_more_linear_4steps_3rdrun_utracking_2_vs_1_step760_rng1946498123.safetensors # "models/17november/utracking_1_vs_1/mappo_transformer_noisy/mappo_transformer_noisy_utracking_1_vs_1_step3230_rng1948878966.safetensors" #models/utracking_mbari/utracking_1_vs_1/mappo_transformer_newobs_medium_correct_norm/mappo_transformer_newobs_medium_correct_norm_utracking_1_vs_1_step2660_rng1948878966.safetensors #models/utracking_mbari/utracking_2_vs_1/mappo_transformer_newobs_medium_2ndrun/mappo_transformer_newobs_medium_2ndrun_utracking_2_vs_1_step3810_rng928981903.safetensors #models/utracking_mbari/utracking_1_vs_1_newobs/mappo_transformer_newobs_medium_utracking_1_vs_1_step4191_rng928981903.safetensors # models/utracking_post/utracking_5_vs_5/mappo_transformer_hard_128steps/mappo_transformer_hard_128steps_utracking_5_vs_5_seed2.safetensors
52+
LOAD_CRITIC: #True #False
53+
ALG_NAME: mappo_transformer_noisy_more_linear_4steps
54+
CHECKPOINT_INTERVAL: 0.02 # percentage of total update steps
4955
ANIMATION_LOG_INTERVAL: 0.1 # percentage of total update steps
50-
ANIMATION_MAX_STEPS: 128 # should match the env
56+
ANIMATION_MAX_STEPS: 1024 # should match the env
5157

5258
# Weights & Biases logging
5359
WANDB_MODE: online

baselines/MAPPO/mappo_transformer_utracking.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,8 +767,6 @@ def callback(
767767
env_name = f'utracking_{config["ENV_KWARGS"]["num_agents"]}_vs_{config["ENV_KWARGS"]["num_landmarks"]}'
768768
alg_name = config.get("ALG_NAME", "mappo_rnn_utracking")
769769

770-
print("Saving Checkpoint")
771-
772770
model_state = {
773771
"actor": model_state[0].params,
774772
"critic": model_state[1].params,
@@ -781,6 +779,7 @@ def callback(
781779
f"{alg_name}_{env_name}_step{int(metrics['update_steps'])}_rng{int(original_seed)}.safetensors",
782780
)
783781
save_params(model_state, save_path)
782+
print("Checkpoint saved at", save_path)
784783

785784
if config.get("ANIMATION_LOG_INTERVAL", None) is not None:
786785

jaxmarl/environments/utracking/particle_filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(
2727
self,
2828
num_particles=5000,
2929
std_range=10, # m (standard deviation error of the range measurements)
30-
mu_init_vel=1.0, # m/s
30+
mu_init_vel=2.0, # m/s
3131
std_init_vel=0.6, # m/s
3232
turn_noise=0.5, # rad
3333
vel_noise=0.10, # m/s

jaxmarl/environments/utracking/utracking.py

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ def __init__(
124124
tracking_method: str = "pf", # method for tracking the landmarks positions (ls, pf)
125125
tracking_buffer_len: int = 32, # maximum number of range observations kept for predicting the landmark positions
126126
range_noise_std: float = 10.0, # standard deviation of the gaussian noise added to range measurements (meters)
127-
traj_noise_std: float = 0.02, # standard deviation of the gaussian noise added to the traj models (radians)
127+
traj_noise_std: float = 0.1, # standard deviation of the gaussian noise added to the traj models (radians)
128+
velocity_noise_std: float = 0.1, # standard deviation of the gaussian noise added to the velocity (meters/second)
128129
lost_comm_prob=0.1, # probability of loosing communications (range measurements and intra-agent communication)
129130
min_steps_ls: int = 2, # minimum steps for collecting data and start predicting landmarks positions with least squares
130131
rew_dist_thr: float = 150.0, # distance threshold for the follow reward
@@ -218,7 +219,7 @@ def __init__(
218219
self.tracking_buffer_len = tracking_buffer_len
219220
self.range_noise_std = range_noise_std
220221
self.traj_noise_std = traj_noise_std
221-
self.traj_noise_std = traj_noise_std
222+
self.velocity_noise_std = velocity_noise_std
222223
self.lost_comm_prob = lost_comm_prob
223224
self.min_steps_ls = min_steps_ls
224225
self.rew_dist_thr = rew_dist_thr
@@ -487,27 +488,20 @@ def world_step(
487488
else:
488489
# update the angle
489490
angle_change = actions * traj_coeffs + traj_intercepts
490-
# add noise
491-
angle_change += (
491+
# add noise to agents only (not landmarks)
492+
angle_noise = (
492493
jax.random.normal(rng, shape=angle_change.shape) * self.traj_noise_std
493494
)
494-
new_angles = (pos[:, -1] + angle_change + jnp.pi) % (2 * jnp.pi) - jnp.pi
495-
# update the x-y position (depth remains constant)
496-
pos = pos.at[:, -1].set(new_angles)
497-
498-
if self.actions_as_angles:
499-
pos = pos.at[:, -1].set(actions)
500-
else:
501-
# update the angle
502-
angle_change = actions * traj_coeffs + traj_intercepts
503-
# add noise
504-
angle_change += (
505-
jax.random.normal(rng, shape=angle_change.shape) * self.traj_noise_std
495+
angle_change = angle_change.at[: self.num_agents].add(
496+
angle_noise[: self.num_agents]
506497
)
507498
new_angles = (pos[:, -1] + angle_change + jnp.pi) % (2 * jnp.pi) - jnp.pi
508499
# update the x-y position (depth remains constant)
509500
pos = pos.at[:, -1].set(new_angles)
510501

502+
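# add noise to velocity of agents only (not landmarks); NOTE(review): this draw reuses the same `rng` key as the angle-noise draw above, so the two noise samples are not independent - consider jax.random.split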
503+
vel_noise = jax.random.normal(rng, shape=vel.shape) * self.velocity_noise_std
504+
vel = vel.at[: self.num_agents].add(vel_noise[: self.num_agents])
511505
pos = pos.at[:, 0].add(jnp.cos(pos[:, -1]) * vel * self.dt)
512506
pos = pos.at[:, 1].add(jnp.sin(pos[:, -1]) * vel * self.dt)
513507
return pos
@@ -666,16 +660,17 @@ def get_obs(self, delta_xyz, ranges, comm_drop, pos, old_pos, land_pred_pos):
666660
delta_self_pos = delta_self_pos.at[:, 3].set(
667661
(delta_self_pos[:, 3] + jnp.pi) / (2 * jnp.pi)
668662
)
669-
delta_xyz = self.normalize_distances(delta_xyz)
670663

664+
# other entities relative distance
665+
delta_xyz = self.normalize_distances(delta_xyz)
671666
other_agents_dist = jnp.where(
672667
comm_drop[:, :, None], 0, delta_xyz[:, : self.num_agents]
673668
) # 0 for communication drop
674669
self_mask = (
675670
jnp.arange(self.num_agents) == np.arange(self.num_agents)[:, np.newaxis]
676671
)
677672
self_pos_feats = delta_self_pos[: self.num_agents, [0, 1, 3]]
678-
self_pos_feats = self_pos_feats.at[:, 2].set(0)
673+
self_pos_feats = self_pos_feats.at[:, 2].set(0) # set angle to 0 for now
679674
agents_rel_pos = jnp.where(
680675
self_mask[:, :, None],
681676
self_pos_feats,
@@ -693,18 +688,20 @@ def get_obs(self, delta_xyz, ranges, comm_drop, pos, old_pos, land_pred_pos):
693688
is_self_feat = (
694689
jnp.arange(self.num_entities) == jnp.arange(self.num_agents)[:, np.newaxis]
695690
)
691+
ranges = self.normalize_distances(ranges)
696692
ranges *= 1.0 if self.ranges_in_obs else 0.0 # mask the ranges if not in obs
697693
# the distance based feats are rescaled to hundreds of meters (better for NNs)
698694

699695
feats = jnp.concatenate(
700696
(
701697
pos_feats,
702-
self.normalize_distances(ranges[:, :, None]),
698+
ranges[:, :, None],
703699
is_agent_feat[:, :, None],
704700
is_self_feat[:, :, None],
705701
),
706702
axis=2,
707703
)
704+
feats = jnp.where(jnp.isnan(feats), 0.0, feats) # replace nan with 0
708705

709706
# then each agent is assigned its own obs
710707
return {
@@ -830,6 +827,8 @@ def get_global_state(
830827
else:
831828
state = self.get_vertex_state(pos, vel, ranges, land_pred_pos)
832829

830+
state = jnp.where(jnp.isnan(state), 0.0, state) # replace nan with 0
831+
833832
if self.matrix_state:
834833
return state
835834
else:
@@ -912,6 +911,8 @@ def exponential_decay(x, x1=self.rew_pred_ideal, x2=self.rew_pred_thr):
912911
).any()
913912
rew = jnp.where(any_agent_lost, -1.0, rew)
914913

914+
rew = jnp.where(jnp.isnan(rew), 0.0, rew) # replace nan with 0
915+
915916
# DONE
916917
done = t == self.max_steps
917918

@@ -1013,28 +1014,32 @@ def get_ranges(
10131014
jax.random.normal(key_noise, shape=ranges_real.shape) * self.range_noise_std
10141015
)
10151016

1017+
# Add noise to 3D range measurement (physically correct)
10161018
ranges = ranges_real + noise
10171019
lost_range = (
10181020
jax.random.uniform(key_lost, shape=ranges.shape) <= self.lost_comm_prob
10191021
) | (
10201022
ranges_real > self.max_range_dist
10211023
) # lost communication or landmark too far
10221024
ranges = jnp.where(lost_range, 0.0, ranges)
1023-
lost_range = (
1024-
jax.random.uniform(key_lost, shape=ranges.shape) <= self.lost_comm_prob
1025-
) | (
1026-
ranges_real > self.max_range_dist
1027-
) # lost communication or landmark too far
1028-
ranges = jnp.where(lost_range, 0.0, ranges)
10291025
ranges = fill_diagonal_zeros(ranges) # reset to 0s the self-ranges
10301026

1031-
ranges_2d = ranges_real_2d + noise
1032-
ranges_2d = jnp.where(lost_range, 0.0, ranges_2d)
1033-
ranges_2d = fill_diagonal_zeros(ranges_2d)
1034-
1035-
ranges_2d = ranges_real_2d + noise
1036-
ranges_2d = jnp.where(lost_range, 0.0, ranges_2d)
1037-
ranges_2d = fill_diagonal_zeros(ranges_2d)
1027+
# Convert noisy 3D measurement to 2D (if depth is known)
1028+
# This is the physically correct way: measure 3D with noise, then convert to 2D
1029+
if self.landmark_depth_known:
1030+
# Calculate depth differences for landmarks (agents to landmarks)
1031+
delta_z = pos[: self.num_agents, np.newaxis, 2] - pos[:, 2]
1032+
# Convert noisy 3D range to 2D: r_2d = sqrt(r_3d² - dz²)
1033+
# Use jnp.maximum to avoid negative values under sqrt
1034+
ranges_2d_squared = jnp.maximum(ranges**2 - delta_z**2, 0.0)
1035+
ranges_2d = jnp.sqrt(ranges_2d_squared)
1036+
# Set to 0 where communication was lost
1037+
ranges_2d = jnp.where(lost_range, 0.0, ranges_2d)
1038+
ranges_2d = fill_diagonal_zeros(ranges_2d)
1039+
else:
1040+
# If depth is not known, can't convert - use 3D ranges directly
1041+
# (This is less realistic but keeps backward compatibility)
1042+
ranges_2d = ranges
10381043

10391044
return delta_xyz, ranges_real_2d, ranges_real, ranges_2d, ranges
10401045

0 commit comments

Comments
 (0)