diff --git a/isaaclab_arena/assets/object_base.py b/isaaclab_arena/assets/object_base.py index 85585c6ca..b8b3c4e4b 100644 --- a/isaaclab_arena/assets/object_base.py +++ b/isaaclab_arena/assets/object_base.py @@ -15,8 +15,9 @@ from isaaclab_arena.assets.asset import Asset from isaaclab_arena.relations.relations import AtPosition, Relation, RelationBase -from isaaclab_arena.terms.events import set_object_pose -from isaaclab_arena.utils.pose import Pose, PoseRange +from isaaclab_arena.terms.events import set_object_pose, set_object_pose_per_env +from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox +from isaaclab_arena.utils.pose import Pose, PosePerEnv, PoseRange from isaaclab_arena.utils.velocity import Velocity @@ -42,13 +43,13 @@ def __init__( prim_path = "{ENV_REGEX_NS}/" + self.name self.prim_path = prim_path self.object_type = object_type - self.initial_pose: Pose | PoseRange | None = None + self.initial_pose: Pose | PoseRange | PosePerEnv | None = None self.initial_velocity: Velocity | None = None self.object_cfg = None self.event_cfg = None self.relations: list[RelationBase] = [] - def get_initial_pose(self) -> Pose | PoseRange | None: + def get_initial_pose(self) -> Pose | PoseRange | PosePerEnv | None: """Return the current initial pose of this object. Subclasses may override to derive the pose from other sources @@ -56,24 +57,38 @@ def get_initial_pose(self) -> Pose | PoseRange | None: """ return self.initial_pose + @abstractmethod + def get_bounding_box(self) -> AxisAlignedBoundingBox: + """Get local bounding box (relative to object origin).""" + ... + + @abstractmethod + def get_world_bounding_box(self) -> AxisAlignedBoundingBox: + """Get bounding box in world coordinates (local bbox rotated and translated).""" + ... + def _get_initial_pose_as_pose(self) -> Pose | None: """Return a single ``Pose`` suitable for *init_state* and bounding-box calculations. If the initial pose is a ``PoseRange``, its midpoint is returned. 
+ If the initial pose is a ``PosePerEnv``, the first environment's pose is returned. If the initial pose is ``None``, ``None`` is returned. """ initial_pose = self.get_initial_pose() if initial_pose is None: return None + if isinstance(initial_pose, PosePerEnv): + return initial_pose.poses[0] if isinstance(initial_pose, PoseRange): return initial_pose.get_midpoint() return initial_pose - def set_initial_pose(self, pose: Pose | PoseRange) -> None: + def set_initial_pose(self, pose: Pose | PoseRange | PosePerEnv) -> None: """Set / override the initial pose and rebuild derived configs. Args: - pose: A fixed ``Pose`` or a ``PoseRange`` (randomised on reset). + pose: A fixed ``Pose``, a ``PoseRange`` (randomised on reset), + or a ``PosePerEnv`` (distinct pose per environment). """ self.initial_pose = pose initial_pose = self._get_initial_pose_as_pose() @@ -116,7 +131,16 @@ def _init_event_cfg(self) -> EventTermCfg | None: return None initial_pose = self.get_initial_pose() - if isinstance(initial_pose, PoseRange): + if isinstance(initial_pose, PosePerEnv): + return EventTermCfg( + func=set_object_pose_per_env, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg(self.name), + "pose_list": initial_pose.poses, + }, + ) + elif isinstance(initial_pose, PoseRange): return EventTermCfg( func=randomize_object_pose, mode="reset", diff --git a/isaaclab_arena/environments/arena_env_builder.py b/isaaclab_arena/environments/arena_env_builder.py index 3d1543142..e5a01e4e5 100644 --- a/isaaclab_arena/environments/arena_env_builder.py +++ b/isaaclab_arena/environments/arena_env_builder.py @@ -27,6 +27,7 @@ from isaaclab_arena.metrics.recorder_manager_utils import metrics_to_recorder_manager_cfg from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams +from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult from isaaclab_arena.relations.relations import IsAnchor, NoCollision 
from isaaclab_arena.tasks.no_task import NoTask from isaaclab_arena.utils.configclass import combine_configclass_instances @@ -99,11 +100,24 @@ def _solve_relations(self) -> None: self._add_pairwise_no_collision(objects_with_relations) # Run the ObjectPlacer (default on_relation_z_tolerance_m accommodates solver residual). + # Positions are applied to objects via set_initial_pose (single-env: Pose/PoseRange, + # multi-env: PosePerEnv), so each object's event_cfg handles its own reset. placement_seed = getattr(self.args, "placement_seed", None) placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=placement_seed)) - result = placer.place(objects=objects_with_relations) - - if result.success: + num_envs = self.args.num_envs + result = placer.place(objects_with_relations, num_envs=num_envs) + + # Log outcome + if isinstance(result, MultiEnvPlacementResult): + n_succeeded = sum(1 for r in result.results if r.success) + if n_succeeded == num_envs: + print(f"Relation solving succeeded for all {num_envs} env(s) after {result.attempts} attempt(s)") + else: + print( + f"Relation solving: {n_succeeded}/{num_envs} env(s) passed validation after" + f" {result.attempts} attempt(s)." 
+ ) + elif result.success: print(f"Relation solving succeeded after {result.attempts} attempt(s)") else: print(f"Relation solving not completed after {result.attempts} attempt(s)") diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 8f9d62928..bfba41096 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -5,19 +5,34 @@ from __future__ import annotations +import math import torch +from dataclasses import dataclass from typing import TYPE_CHECKING from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams -from isaaclab_arena.relations.placement_result import PlacementResult +from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relations import On, RandomAroundSolution, RotateAroundSolution, get_anchor_objects from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, get_random_pose_within_bounding_box -from isaaclab_arena.utils.pose import Pose +from isaaclab_arena.utils.pose import Pose, PosePerEnv if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object - from isaaclab_arena.assets.object_reference import ObjectReference + from isaaclab_arena.assets.object_base import ObjectBase + + +@dataclass +class PlacementCandidate: + """A single solver result, ranked and selected in ObjectPlacer.place().""" + + loss: float + """Loss value returned by the solver.""" + + positions: dict[ObjectBase, tuple[float, float, float]] + """Solved positions for each object.""" + + is_valid: bool + """Whether the placement passed validation checks.""" class ObjectPlacer: @@ -29,6 +44,8 @@ class ObjectPlacer: 3. Validating the result 4. Retrying if necessary 5. Applying solved positions to objects + + Supports single-env (num_envs=1) and batched (num_envs>1) placement. 
""" def __init__(self, params: ObjectPlacerParams | None = None): @@ -42,16 +59,24 @@ def __init__(self, params: ObjectPlacerParams | None = None): def place( self, - objects: list[Object | ObjectReference], - ) -> PlacementResult: + objects: list[ObjectBase], + num_envs: int = 1, + result_per_env: bool = True, + ) -> PlacementResult | MultiEnvPlacementResult: """Place objects according to their spatial relations. Args: objects: List of objects to place. Must include at least one object marked with IsAnchor() which serves as a fixed reference. + num_envs: Number of environments. 1 for single-env; > 1 for batched + placement (one layout per env). + result_per_env: When True (default), each environment gets a distinct + layout. When False, a single best layout is solved and applied + identically to all environments (useful for deterministic evaluation). Returns: - PlacementResult with success status, positions, loss, and attempt count. + PlacementResult when a single layout is produced (num_envs=1 or + result_per_env=False); MultiEnvPlacementResult otherwise. """ # Validate all objects have at least one relation for obj in objects: @@ -76,63 +101,62 @@ def place( anchor_objects_set = set(anchor_objects) - # Save RNG state and set seed if provided (for reproducibility without affecting Isaac Sim) - rng_state = None - if self.params.placement_seed is not None: - rng_state = torch.get_rng_state() - torch.manual_seed(self.params.placement_seed) - # Determine bounds for random position initialization from the first anchor object # TODO(cvolk): The user should not need to know about the bounds to set. # Implement an initialization strategy that infers from the Relations(s). 
init_bounds = self._get_init_bounds(anchor_objects[0]) - # Placement loop with retries - best_positions: dict[Object | ObjectReference, tuple[float, float, float]] = {} - best_loss = float("inf") - success = False - - for attempt in range(self.params.max_placement_attempts): - # Generate starting positions (anchors from their poses, others random) - initial_positions = self._generate_initial_positions(objects, anchor_objects_set, init_bounds) - - # Solve - positions = self._solver.solve(objects, initial_positions) - loss = self._solver.last_loss_history[-1] if self._solver.last_loss_history else float("inf") - - if self.params.verbose: - print(f"Attempt {attempt + 1}/{self.params.max_placement_attempts}: loss = {loss:.6f}") - - # Check if placement is valid - if self._validate_placement(positions): - best_loss = loss - best_positions = positions - success = True - if self.params.verbose: - print(f"Success on attempt {attempt + 1}") - break - - # Track best invalid result as fallback - if loss < best_loss: - best_loss = loss - best_positions = positions - - # Apply solved positions to objects - if self.params.apply_positions_to_objects: - self._apply_positions(best_positions, anchor_objects_set) + # Pool-based placement: generate all candidates in one batched call, + # then pick the best num_results (environments are homogeneous so any + # valid solution can serve any environment). 
+ num_results = num_envs if result_per_env else 1 + num_candidates = self.params.max_placement_attempts * num_results + + initial_positions = self._generate_initial_positions(objects, anchor_objects_set, init_bounds, num_candidates) + + all_positions = self._solver.solve(objects, initial_positions) + assert self._solver.last_loss_per_env is not None + all_losses: list[float] = self._solver.last_loss_per_env.cpu().tolist() + + all_candidates: list[PlacementCandidate] = [] + for idx in range(num_candidates): + loss = all_losses[idx] + is_valid = self._validate_placement(all_positions[idx]) + all_candidates.append(PlacementCandidate(loss, all_positions[idx], is_valid)) + + # Sort: valid solutions first (by loss), then invalid (by loss) + all_candidates.sort(key=lambda candidate: (not candidate.is_valid, candidate.loss)) + selected = all_candidates[:num_results] + + n_valid = sum(1 for candidate in selected if candidate.is_valid) + if self.params.verbose: + total_valid = sum(1 for candidate in all_candidates if candidate.is_valid) + finite_losses = [candidate.loss for candidate in all_candidates if math.isfinite(candidate.loss)] + mean_loss = sum(finite_losses) / len(finite_losses) if finite_losses else float("inf") + print( + f"Solved {num_candidates} candidates in one batch: mean loss = {mean_loss:.6f}," + f" {total_valid} valid, selected best {num_results} ({n_valid} valid)" + ) - # Restore RNG state if we changed it - if rng_state is not None: - torch.set_rng_state(rng_state) + final_per_env: list[dict] = [candidate.positions for candidate in selected] + results_per_env = [ + PlacementResult( + success=candidate.is_valid, + positions=candidate.positions, + final_loss=candidate.loss, + attempts=self.params.max_placement_attempts, + ) + for candidate in selected + ] - return PlacementResult( - success=success, - positions=best_positions, - final_loss=best_loss, - attempts=attempt + 1, - ) + if self.params.apply_positions_to_objects: + 
self._apply_positions(final_per_env, anchor_objects_set) + + if num_results == 1: + return results_per_env[0] + return MultiEnvPlacementResult(results=results_per_env) - def _get_init_bounds(self, anchor_object: Object | ObjectReference) -> AxisAlignedBoundingBox: + def _get_init_bounds(self, anchor_object: ObjectBase) -> AxisAlignedBoundingBox: """Get bounds for random position initialization. If init_bounds is provided in params, use it. @@ -153,29 +177,41 @@ def _get_init_bounds(self, anchor_object: Object | ObjectReference) -> AxisAlign def _generate_initial_positions( self, - objects: list[Object | ObjectReference], - anchor_objects: Object | ObjectReference, + objects: list[ObjectBase], + anchor_objects: set[ObjectBase], init_bounds: AxisAlignedBoundingBox, - ) -> dict[Object | ObjectReference, tuple[float, float, float]]: - """Generate initial positions for all objects. - - Anchors keep their current initial_pose, others get random positions. - - Returns: - Dictionary mapping all objects to their starting positions. + num_candidates: int, + ) -> list[dict[ObjectBase, tuple[float, float, float]]]: + """Generate initial positions for ``num_candidates`` placement candidates. + + Each candidate maps every object to a starting position: anchors keep + their current ``initial_pose``; others receive a random position within + ``init_bounds``. When ``placement_seed`` is set, each candidate gets a + deterministic seed (``placement_seed + candidate_idx``). 
""" - positions: dict[Object | ObjectReference, tuple[float, float, float]] = {} - for obj in objects: - if obj in anchor_objects: - positions[obj] = obj.get_initial_pose().position_xyz - else: - random_pose = get_random_pose_within_bounding_box(init_bounds) - positions[obj] = random_pose.position_xyz - return positions + results: list[dict[ObjectBase, tuple[float, float, float]]] = [] + for candidate_idx in range(num_candidates): + rng_state = None + if self.params.placement_seed is not None: + rng_state = torch.get_rng_state() + torch.manual_seed(self.params.placement_seed + candidate_idx) + + positions: dict[ObjectBase, tuple[float, float, float]] = {} + for obj in objects: + if obj in anchor_objects: + positions[obj] = obj.get_initial_pose().position_xyz + else: + random_pose = get_random_pose_within_bounding_box(init_bounds) + positions[obj] = random_pose.position_xyz + results.append(positions) + + if rng_state is not None: + torch.set_rng_state(rng_state) + return results def _validate_on_relations( self, - positions: dict[Object | ObjectReference, tuple[float, float, float]], + positions: dict[ObjectBase, tuple[float, float, float]], ) -> bool: """Validate each On relation; logic matches OnLossStrategy (relation_loss_strategies.py). @@ -217,13 +253,28 @@ def _validate_on_relations( def _validate_no_overlap( self, - positions: dict[Object | ObjectReference, tuple[float, float, float]], + positions: dict[ObjectBase, tuple[float, float, float]], ) -> bool: - """Check that no two objects overlap in 3D (axis-aligned bbox with margin).""" + """Validate that no two objects overlap in 3D (axis-aligned bbox with margin). + + Pairs linked by an On relation are skipped (validated separately by + _validate_on_relations). + """ + # Build set of On-related pairs to skip (child, parent) and (parent, child). 
+ on_pairs: set[tuple] = set() + for obj in positions: + for rel in obj.get_relations(): + if isinstance(rel, On) and rel.parent in positions: + on_pairs.add((id(obj), id(rel.parent))) + on_pairs.add((id(rel.parent), id(obj))) + objects = list(positions.keys()) for i in range(len(objects)): for j in range(i + 1, len(objects)): a, b = objects[i], objects[j] + # Pairs related by an On relation are excluded from the overlap check. + if (id(a), id(b)) in on_pairs: + continue a_world = a.get_bounding_box().translated(positions[a]) b_world = b.get_bounding_box().translated(positions[b]) @@ -236,7 +287,7 @@ def _validate_no_overlap( def _validate_placement( self, - positions: dict[Object | ObjectReference, tuple[float, float, float]], + positions: dict[ObjectBase, tuple[float, float, float]], ) -> bool: """Validate that no two objects overlap in 3D and On relations are satisfied. @@ -250,31 +301,43 @@ def _validate_placement( def _apply_positions( self, - positions: dict[Object | ObjectReference, tuple[float, float, float]], - anchor_objects: Object | ObjectReference, + positions_per_env: list[dict[ObjectBase, tuple[float, float, float]]], + anchor_objects: set[ObjectBase], ) -> None: """Apply solved positions to objects (skipping anchors). - If RandomAroundSolution marker is present, sets a PoseRange (for reset-time randomization). - Rotation is taken from RotateAroundSolution marker if present, otherwise keep the identity rotation. + Handles both single-env and multi-env placement: + - Single-env: sets a fixed Pose or PoseRange (with RandomAroundSolution). + - Multi-env: sets a PosePerEnv with one Pose per environment. + + Rotation is taken from RotateAroundSolution marker if present, otherwise identity. """ - for obj, pos in positions.items(): + num_envs = len(positions_per_env) + # Objects are the same for every environment. Extract them. + objects = list(positions_per_env[0]) + # Apply pose for each object. 
+ for obj in objects: if obj in anchor_objects: continue - random_marker = self._get_random_around_solution(obj) rotate_marker = self._get_rotate_around_solution(obj) rotation_wxyz = rotate_marker.get_rotation_wxyz() if rotate_marker else (1.0, 0.0, 0.0, 0.0) - if random_marker is not None: - # We need to set a PoseRange for the randomization to be picked up on reset. - # Set a PoseRange with the explicit rotation from RotateAroundSolution if present - obj.set_initial_pose(random_marker.to_pose_range_centered_at(pos, rotation_wxyz=rotation_wxyz)) + if num_envs == 1: + pos = positions_per_env[0][obj] + random_marker = self._get_random_around_solution(obj) + if random_marker is not None: + obj.set_initial_pose(random_marker.to_pose_range_centered_at(pos, rotation_wxyz=rotation_wxyz)) + else: + obj.set_initial_pose(Pose(position_xyz=pos, rotation_wxyz=rotation_wxyz)) else: - # Without randomization, we can set a fixed Pose. - obj.set_initial_pose(Pose(position_xyz=pos, rotation_wxyz=rotation_wxyz)) + poses = [ + Pose(position_xyz=positions_per_env[env_idx][obj], rotation_wxyz=rotation_wxyz) + for env_idx in range(num_envs) + ] + obj.set_initial_pose(PosePerEnv(poses=poses)) - def _get_random_around_solution(self, obj: Object | ObjectReference) -> RandomAroundSolution | None: + def _get_random_around_solution(self, obj: ObjectBase) -> RandomAroundSolution | None: """Get RandomAroundSolution marker from object if present. Args: @@ -288,7 +351,7 @@ def _get_random_around_solution(self, obj: Object | ObjectReference) -> RandomAr return rel return None - def _get_rotate_around_solution(self, obj: Object | ObjectReference) -> RotateAroundSolution | None: + def _get_rotate_around_solution(self, obj: ObjectBase) -> RotateAroundSolution | None: """Get RotateAroundSolution marker from object if present. 
Args: diff --git a/isaaclab_arena/relations/placement_result.py b/isaaclab_arena/relations/placement_result.py index e74e266f7..22f76a14d 100644 --- a/isaaclab_arena/relations/placement_result.py +++ b/isaaclab_arena/relations/placement_result.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object + from isaaclab_arena.assets.object_base import ObjectBase @dataclass @@ -19,7 +19,7 @@ class PlacementResult: success: bool """Whether placement passed validation checks.""" - positions: dict[Object, tuple[float, float, float]] + positions: dict[ObjectBase, tuple[float, float, float]] """Final positions for each object.""" final_loss: float @@ -27,3 +27,21 @@ class PlacementResult: attempts: int """Number of attempts made.""" + + +@dataclass +class MultiEnvPlacementResult: + """Result of an ObjectPlacer.place() call for multiple environments.""" + + results: list[PlacementResult] + """One PlacementResult per environment (same length as num_envs).""" + + @property + def success(self) -> bool: + """True if every environment's placement succeeded.""" + return all(r.success for r in self.results) + + @property + def attempts(self) -> int: + """Number of attempts (same for all envs in the batched run).""" + return self.results[0].attempts if self.results else 0 diff --git a/isaaclab_arena/relations/relation_loss_strategies.py b/isaaclab_arena/relations/relation_loss_strategies.py index 079162c8f..1be500d39 100644 --- a/isaaclab_arena/relations/relation_loss_strategies.py +++ b/isaaclab_arena/relations/relation_loss_strategies.py @@ -79,11 +79,12 @@ def compute_loss( Args: relation: The relation object containing constraint metadata. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box (extents relative to origin). + child_pos: Child object position tensor. Accepts (3,) for single-env + backward compat or (N, 3) for batched. 
+ child_bbox: Child object local bounding box (N=1). Returns: - Scalar loss tensor representing the constraint violation. + Scalar loss tensor when child_pos is (3,), or (N,) tensor when (N, 3). """ pass @@ -103,12 +104,13 @@ def compute_loss( Args: relation: The relation object containing relationship metadata. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box (extents relative to origin). + child_pos: Child object position tensor. Accepts (3,) for single-env + backward compat or (N, 3) for batched. + child_bbox: Child object local bounding box (N=1). parent_world_bbox: Parent bounding box in world coordinates. Returns: - Scalar loss tensor representing the constraint violation. + Scalar loss tensor when child_pos is (3,), or (N,) tensor when (N, 3). """ pass @@ -145,45 +147,49 @@ def compute_loss( Args: relation: NextTo relation with side and distance attributes. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box. + child_pos: Child object position (N, 3) in world coords. + child_bbox: Child object local bounding box (N=1). parent_world_bbox: Parent bounding box in world coordinates. Returns: - Weighted loss tensor. + Weighted loss tensor of shape (N,). 
""" + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + cfg = SIDE_CONFIGS[relation.side] distance = relation.distance_m assert distance >= 0.0, f"NextTo distance must be non-negative, got {distance}" # Parent world extents from the world bounding box if cfg.direction == Direction.POSITIVE: - parent_edge = parent_world_bbox.max_point[0, cfg.primary_axis] - child_offset = child_bbox.min_point[0, cfg.primary_axis] + parent_edge = parent_world_bbox.max_point[:, cfg.primary_axis] + child_offset = child_bbox.min_point[:, cfg.primary_axis] penalty_side = "less" else: - parent_edge = parent_world_bbox.min_point[0, cfg.primary_axis] - child_offset = child_bbox.max_point[0, cfg.primary_axis] + parent_edge = parent_world_bbox.min_point[:, cfg.primary_axis] + child_offset = child_bbox.max_point[:, cfg.primary_axis] penalty_side = "greater" # 1. Half-plane loss: child must be on correct side of parent edge half_plane_loss = single_boundary_linear_loss( - child_pos[cfg.primary_axis], + child_pos[:, cfg.primary_axis], parent_edge, slope=self.slope, penalty_side=penalty_side, ) # 2. 
Band position loss: child placed at target position within parent's perpendicular extent - parent_band_min = parent_world_bbox.min_point[0, cfg.band_axis] - parent_band_max = parent_world_bbox.max_point[0, cfg.band_axis] - valid_band_min = parent_band_min - child_bbox.min_point[0, cfg.band_axis] - valid_band_max = parent_band_max - child_bbox.max_point[0, cfg.band_axis] + parent_band_min = parent_world_bbox.min_point[:, cfg.band_axis] + parent_band_max = parent_world_bbox.max_point[:, cfg.band_axis] + valid_band_min = parent_band_min - child_bbox.min_point[:, cfg.band_axis] + valid_band_max = parent_band_max - child_bbox.max_point[:, cfg.band_axis] # Convert cross_position_ratio [-1, 1] to interpolation factor [0, 1]: -1 = min, 0 = center, 1 = max t = (relation.cross_position_ratio + 1.0) / 2.0 target_band_pos = valid_band_min + t * (valid_band_max - valid_band_min) band_loss = single_point_linear_loss( - child_pos[cfg.band_axis], + child_pos[:, cfg.band_axis], target_band_pos, slope=self.slope, ) @@ -192,31 +198,32 @@ def compute_loss( # For direction +1: target = parent_max + distance - child_min # For direction -1: target = parent_min - distance - child_max target_pos = parent_edge + cfg.direction * distance - child_offset - distance_loss = single_point_linear_loss(child_pos[cfg.primary_axis], target_pos, slope=self.slope) + distance_loss = single_point_linear_loss(child_pos[:, cfg.primary_axis], target_pos, slope=self.slope) - if self.debug: + if self.debug and child_pos.shape[0] == 1: axis_name = cfg.primary_axis.name band_axis_name = cfg.band_axis.name print( f" [NextTo] {relation.side.value}: child_{axis_name.lower()}=" - f"{child_pos[cfg.primary_axis].item():.4f}, parent_edge={parent_edge.item():.4f}," - f" loss={half_plane_loss.item():.6f}" + f"{child_pos[0, cfg.primary_axis].item():.4f}, parent_edge={parent_edge[0].item():.4f}," + f" loss={half_plane_loss[0].item():.6f}" ) print( f" [NextTo] {band_axis_name} band: child_{band_axis_name.lower()}=" - 
f"{child_pos[cfg.band_axis].item():.4f}, target={target_band_pos.item():.4f}" + f"{child_pos[0, cfg.band_axis].item():.4f}, target={target_band_pos[0].item():.4f}" f" (cross_position_ratio={relation.cross_position_ratio:.2f}," - f" range=[{valid_band_min.item():.4f}, {valid_band_max.item():.4f}])," - f" loss={band_loss.item():.6f}" + f" range=[{valid_band_min[0].item():.4f}, {valid_band_max[0].item():.4f}])," + f" loss={band_loss[0].item():.6f}" ) print( f" [NextTo] Distance: child_{axis_name.lower()}=" - f"{child_pos[cfg.primary_axis].item():.4f}, target={target_pos.item():.4f}," - f" loss={distance_loss.item():.6f}" + f"{child_pos[0, cfg.primary_axis].item():.4f}, target={target_pos[0].item():.4f}," + f" loss={distance_loss[0].item():.6f}" ) total_loss = half_plane_loss + band_loss + distance_loss - return relation.relation_loss_weight * total_loss + result = relation.relation_loss_weight * total_loss + return result.squeeze(0) if single_input else result class OnLossStrategy(RelationLossStrategy): @@ -249,31 +256,33 @@ def compute_loss( Args: relation: On relation with clearance_m attribute. - child_pos: Child object position tensor (x, y, z) in world coords. - child_bbox: Child object local bounding box. + child_pos: Child object position (N, 3) in world coords. + child_bbox: Child object local bounding box (N=1). parent_world_bbox: Parent bounding box in world coordinates. Returns: - Weighted loss tensor. + Weighted loss tensor of shape (N,). 
""" + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + # Parent world-space extents from the world bounding box - parent_x_min = parent_world_bbox.min_point[0, 0] - parent_x_max = parent_world_bbox.max_point[0, 0] - parent_y_min = parent_world_bbox.min_point[0, 1] - parent_y_max = parent_world_bbox.max_point[0, 1] - parent_z_max = parent_world_bbox.max_point[0, 2] # Top surface + parent_x_min = parent_world_bbox.min_point[:, 0] + parent_x_max = parent_world_bbox.max_point[:, 0] + parent_y_min = parent_world_bbox.min_point[:, 1] + parent_y_max = parent_world_bbox.max_point[:, 1] + parent_z_max = parent_world_bbox.max_point[:, 2] # Top surface # Compute valid position ranges such that child's entire footprint is within parent - # Child left edge = child_pos[0] + child_bbox.min_point[0, 0], must be >= parent_x_min - # Child right edge = child_pos[0] + child_bbox.max_point[0, 0], must be <= parent_x_max - valid_x_min = parent_x_min - child_bbox.min_point[0, 0] # child's left at parent's left - valid_x_max = parent_x_max - child_bbox.max_point[0, 0] # child's right at parent's right - valid_y_min = parent_y_min - child_bbox.min_point[0, 1] - valid_y_max = parent_y_max - child_bbox.max_point[0, 1] + valid_x_min = parent_x_min - child_bbox.min_point[:, 0] # child's left at parent's left + valid_x_max = parent_x_max - child_bbox.max_point[:, 0] # child's right at parent's right + valid_y_min = parent_y_min - child_bbox.min_point[:, 1] + valid_y_max = parent_y_max - child_bbox.max_point[:, 1] # 1. X band loss: child's footprint entirely within parent's X extent x_band_loss = linear_band_loss( - child_pos[0], + child_pos[:, 0], lower_bound=valid_x_min, upper_bound=valid_x_max, slope=self.slope, @@ -281,32 +290,33 @@ def compute_loss( # 2. 
Y band loss: child's footprint entirely within parent's Y extent y_band_loss = linear_band_loss( - child_pos[1], + child_pos[:, 1], lower_bound=valid_y_min, upper_bound=valid_y_max, slope=self.slope, ) # 3. Z point loss: child bottom = parent top + clearance - target_z = parent_z_max + relation.clearance_m - child_bbox.min_point[0, 2] - z_loss = single_point_linear_loss(child_pos[2], target_z, slope=self.slope) + target_z = parent_z_max + relation.clearance_m - child_bbox.min_point[:, 2] + z_loss = single_point_linear_loss(child_pos[:, 2], target_z, slope=self.slope) - if self.debug: + if self.debug and child_pos.shape[0] == 1: print( - f" [On] X: child_pos={child_pos[0].item():.4f}, valid_range=[{valid_x_min.item():.4f}," - f" {valid_x_max.item():.4f}], loss={x_band_loss.item():.6f}" + f" [On] X: child_pos={child_pos[0, 0].item():.4f}, valid_range=[{valid_x_min[0].item():.4f}," + f" {valid_x_max[0].item():.4f}], loss={x_band_loss[0].item():.6f}" ) print( - f" [On] Y: child_pos={child_pos[1].item():.4f}, valid_range=[{valid_y_min.item():.4f}," - f" {valid_y_max.item():.4f}], loss={y_band_loss.item():.6f}" + f" [On] Y: child_pos={child_pos[0, 1].item():.4f}, valid_range=[{valid_y_min[0].item():.4f}," + f" {valid_y_max[0].item():.4f}], loss={y_band_loss[0].item():.6f}" ) print( - f" [On] Z: child_pos={child_pos[2].item():.4f}, target={target_z.item():.4f}," - f" loss={z_loss.item():.6f}" + f" [On] Z: child_pos={child_pos[0, 2].item():.4f}, target={target_z[0].item():.4f}," + f" loss={z_loss[0].item():.6f}" ) total_loss = x_band_loss + y_band_loss + z_loss - return relation.relation_loss_weight * total_loss + result = relation.relation_loss_weight * total_loss + return result.squeeze(0) if single_input else result class NoCollisionLossStrategy(RelationLossStrategy): @@ -340,54 +350,59 @@ def compute_loss( Args: relation: NoCollision relation with relation_loss_weight. - child_pos: Child object position tensor (x, y, z) in world coords. 
- child_bbox: Child object local bounding box. + child_pos: Child object position (N, 3) in world coords. + child_bbox: Child object local bounding box (N=1). parent_world_bbox: Parent bounding box in world coordinates. Returns: - Weighted loss tensor. + Weighted loss tensor of shape (N,). """ + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + # Parent world extents from the world bounding box, expanded by clearance_m c = relation.clearance_m - parent_x_min = parent_world_bbox.min_point[0, 0] - c - parent_x_max = parent_world_bbox.max_point[0, 0] + c - parent_y_min = parent_world_bbox.min_point[0, 1] - c - parent_y_max = parent_world_bbox.max_point[0, 1] + c - parent_z_min = parent_world_bbox.min_point[0, 2] - c - parent_z_max = parent_world_bbox.max_point[0, 2] + c + parent_x_min = parent_world_bbox.min_point[:, 0] - c + parent_x_max = parent_world_bbox.max_point[:, 0] + c + parent_y_min = parent_world_bbox.min_point[:, 1] - c + parent_y_max = parent_world_bbox.max_point[:, 1] + c + parent_z_min = parent_world_bbox.min_point[:, 2] - c + parent_z_max = parent_world_bbox.max_point[:, 2] + c # Child world extents - child_world_min = child_pos + child_bbox.min_point[0] - child_world_max = child_pos + child_bbox.max_point[0] + child_world_min = child_pos + child_bbox.min_point + child_world_max = child_pos + child_bbox.max_point # 1. 
Per-axis overlap: zero when separated; else overlap length (default slope 1.0 gives length in m) - overlap_x = interval_overlap_axis_loss(child_world_min[0], child_world_max[0], parent_x_min, parent_x_max) - overlap_y = interval_overlap_axis_loss(child_world_min[1], child_world_max[1], parent_y_min, parent_y_max) - overlap_z = interval_overlap_axis_loss(child_world_min[2], child_world_max[2], parent_z_min, parent_z_max) + overlap_x = interval_overlap_axis_loss(child_world_min[:, 0], child_world_max[:, 0], parent_x_min, parent_x_max) + overlap_y = interval_overlap_axis_loss(child_world_min[:, 1], child_world_max[:, 1], parent_y_min, parent_y_max) + overlap_z = interval_overlap_axis_loss(child_world_min[:, 2], child_world_max[:, 2], parent_z_min, parent_z_max) # 2. Volume loss: slope * product of per-axis overlap lengths (overlap volume when slope 1.0) overlap_volume = overlap_x * overlap_y * overlap_z total_loss = self.slope * overlap_volume - if self.debug: + if self.debug and child_pos.shape[0] == 1: print( - f" [NoCollision] X: overlap={overlap_x.item():.6f} (child_x=[{child_world_min[0].item():.4f}," - f" {child_world_max[0].item():.4f}], parent_x=[{parent_x_min.item():.4f}," - f" {parent_x_max.item():.4f}])" + f" [NoCollision] X: overlap={overlap_x[0].item():.6f} (child_x=[{child_world_min[0, 0].item():.4f}," + f" {child_world_max[0, 0].item():.4f}], parent_x=[{parent_x_min[0].item():.4f}," + f" {parent_x_max[0].item():.4f}])" ) print( - f" [NoCollision] Y: overlap={overlap_y.item():.6f} (child_y=[{child_world_min[1].item():.4f}," - f" {child_world_max[1].item():.4f}], parent_y=[{parent_y_min.item():.4f}," - f" {parent_y_max.item():.4f}])" + f" [NoCollision] Y: overlap={overlap_y[0].item():.6f} (child_y=[{child_world_min[0, 1].item():.4f}," + f" {child_world_max[0, 1].item():.4f}], parent_y=[{parent_y_min[0].item():.4f}," + f" {parent_y_max[0].item():.4f}])" ) print( - f" [NoCollision] Z: overlap={overlap_z.item():.6f} 
(child_z=[{child_world_min[2].item():.4f}," - f" {child_world_max[2].item():.4f}], parent_z=[{parent_z_min.item():.4f}," - f" {parent_z_max.item():.4f}])" + f" [NoCollision] Z: overlap={overlap_z[0].item():.6f} (child_z=[{child_world_min[0, 2].item():.4f}," + f" {child_world_max[0, 2].item():.4f}], parent_z=[{parent_z_min[0].item():.4f}," + f" {parent_z_max[0].item():.4f}])" ) - print(f" [NoCollision] volume={overlap_volume.item():.6f}, loss={total_loss.item():.6f}") + print(f" [NoCollision] volume={overlap_volume[0].item():.6f}, loss={total_loss[0].item():.6f}") - return relation.relation_loss_weight * total_loss + result = relation.relation_loss_weight * total_loss + return result.squeeze(0) if single_input else result class AtPositionLossStrategy(UnaryRelationLossStrategy): @@ -415,27 +430,32 @@ def compute_loss( Args: relation: AtPosition relation with x, y, z target coordinates. - child_pos: Child object position tensor (x, y, z) in world coords. + child_pos: Child object position (N, 3) in world coords. child_bbox: Child object local bounding box (unused, for signature consistency). Returns: - Weighted loss tensor. + Weighted loss tensor of shape (N,). 
""" - total_loss = torch.tensor(0.0, dtype=child_pos.dtype, device=child_pos.device) + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + + total_loss = torch.zeros(child_pos.shape[0], dtype=child_pos.dtype, device=child_pos.device) # X position constraint if relation.x is not None: - x_loss = single_point_linear_loss(child_pos[0], relation.x, slope=self.slope) + x_loss = single_point_linear_loss(child_pos[:, 0], relation.x, slope=self.slope) total_loss = total_loss + x_loss # Y position constraint if relation.y is not None: - y_loss = single_point_linear_loss(child_pos[1], relation.y, slope=self.slope) + y_loss = single_point_linear_loss(child_pos[:, 1], relation.y, slope=self.slope) total_loss = total_loss + y_loss # Z position constraint if relation.z is not None: - z_loss = single_point_linear_loss(child_pos[2], relation.z, slope=self.slope) + z_loss = single_point_linear_loss(child_pos[:, 2], relation.z, slope=self.slope) total_loss = total_loss + z_loss - return relation.relation_loss_weight * total_loss + result = relation.relation_loss_weight * total_loss + return result.squeeze(0) if single_input else result diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index be96be7ac..9df5155b5 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -14,8 +14,7 @@ from isaaclab_arena.relations.relations import AtPosition, Relation, RelationBase if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object - from isaaclab_arena.assets.object_reference import ObjectReference + from isaaclab_arena.assets.object_base import ObjectBase class RelationSolver: @@ -39,6 +38,7 @@ def __init__( self.params = params or RelationSolverParams() self._last_loss_history: list[float] = [] self._last_position_history: list = [] + self._last_loss_per_env: torch.Tensor | None = None def _get_strategy(self, relation: RelationBase) -> 
RelationLossStrategy | UnaryRelationLossStrategy: """Look up the appropriate strategy for a relation type. @@ -68,9 +68,11 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - debug: If True, print detailed loss breakdown. Returns: - Total loss tensor. + Scalar loss tensor (mean over envs). Per-env loss stored in _last_loss_per_env. """ - total_loss = torch.tensor(0.0) + N = state.num_envs + device = state.optimizable_positions.device if state.optimizable_positions is not None else None + total_loss = torch.zeros(N, device=device, dtype=torch.float32) # Compute loss from all spatial relations using strategies for obj in state.optimizable_objects: @@ -86,7 +88,7 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - child_bbox=obj.get_bounding_box(), ) if debug: - _print_unary_relation_debug(obj, relation, child_pos, loss) + _print_unary_relation_debug(obj, relation, child_pos[0], loss.mean()) # Handle binary relations (with parent) like On, NextTo elif isinstance(relation, Relation): # Build parent world bbox: anchors have a known fixed pose, @@ -96,9 +98,7 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - parent_world_bbox = parent.get_world_bounding_box() else: parent_pos = state.get_position(parent) - parent_world_bbox = parent.get_bounding_box().translated( - (float(parent_pos[0]), float(parent_pos[1]), float(parent_pos[2])) - ) + parent_world_bbox = parent.get_bounding_box().translated(parent_pos) loss = strategy.compute_loss( relation=relation, child_pos=child_pos, @@ -107,28 +107,30 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) - ) if debug: parent_pos = state.get_position(parent) - _print_relation_debug(obj, relation, child_pos, parent_pos, loss) + _print_relation_debug(obj, relation, child_pos[0], parent_pos[0], loss.mean()) else: raise ValueError(f"Unknown relation type: {type(relation).__name__}") total_loss = total_loss + 
loss - return total_loss + self._last_loss_per_env = total_loss.detach().clone() + return total_loss.mean() def solve( self, - objects: list[Object | ObjectReference], - initial_positions: dict[Object | ObjectReference, tuple[float, float, float]], - ) -> dict[Object | ObjectReference, tuple[float, float, float]]: + objects: list[ObjectBase], + initial_positions: list[dict[ObjectBase, tuple[float, float, float]]], + ) -> list[dict[ObjectBase, tuple[float, float, float]]]: """Solve for optimal positions of all objects. Args: - objects: List of Object or ObjectReference instances. Must include at least one object + objects: List of ObjectBase instances. Must include at least one object marked with IsAnchor() which serves as a fixed reference. - initial_positions: Starting positions for all objects (including anchors). + initial_positions: List of dicts (one per env). Use a single-element list + for single-env placement. Returns: - Dictionary mapping object instances to final (x, y, z) positions. + List of dicts (one per env) mapping objects to their solved (x, y, z) positions. 
""" state = RelationSolverState(objects, initial_positions) @@ -145,7 +147,7 @@ def solve( print("No optimizable objects, skipping solver.") self._last_loss_history = [0.0] self._last_position_history = [state.get_all_positions_snapshot()] - return state.get_final_positions_dict() + return state.get_final_positions() # Setup optimizer (only for optimizable positions) optimizer = torch.optim.Adam([state.optimizable_positions], lr=self.params.lr) @@ -188,19 +190,24 @@ def solve( self._last_loss_history = loss_history self._last_position_history = position_history - return state.get_final_positions_dict() + return state.get_final_positions() @property def last_loss_history(self) -> list[float]: """Loss values from the most recent solve() call.""" return self._last_loss_history + @property + def last_loss_per_env(self) -> torch.Tensor | None: + """Per-env loss (N,) from the last solve() call.""" + return self._last_loss_per_env + @property def last_position_history(self) -> list: """Position snapshots from the most recent solve() call.""" return self._last_position_history - def debug_losses(self, objects: list[Object | ObjectReference]) -> None: + def debug_losses(self, objects: list[ObjectBase]) -> None: """Print detailed loss breakdown for all relations using final positions. Call this after solve() to inspect why objects may not be correctly positioned. 
@@ -220,13 +227,13 @@ def debug_losses(self, objects: list[Object | ObjectReference]) -> None: # Build positions dict from final position history final_positions = {obj: (pos[0], pos[1], pos[2]) for obj, pos in zip(objects, final_positions_list)} - state = RelationSolverState(objects, final_positions) + state = RelationSolverState(objects, [final_positions]) self._compute_total_loss(state, debug=True) print("\n" + "=" * 60) def _print_relation_debug( - obj: Object | ObjectReference, + obj: ObjectBase, relation: Relation, child_pos: torch.Tensor, parent_pos: torch.Tensor, @@ -272,7 +279,7 @@ def _print_relation_debug( def _print_unary_relation_debug( - obj: Object, + obj: ObjectBase, relation: AtPosition, child_pos: torch.Tensor, loss: torch.Tensor, diff --git a/isaaclab_arena/relations/relation_solver_state.py b/isaaclab_arena/relations/relation_solver_state.py index 75c00169d..d6d5d1aa1 100644 --- a/isaaclab_arena/relations/relation_solver_state.py +++ b/isaaclab_arena/relations/relation_solver_state.py @@ -11,8 +11,7 @@ from isaaclab_arena.relations.relations import get_anchor_objects if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object - from isaaclab_arena.assets.object_reference import ObjectReference + from isaaclab_arena.assets.object_base import ObjectBase class RelationSolverState: @@ -21,74 +20,101 @@ class RelationSolverState: This class manages the mapping between objects and their positions, keeping anchor (fixed) and optimizable positions separate internally while providing an interface for position lookups. + + Positions are always stored as (N, num_objects, 3) where N = num_envs + (N=1 for single-env). """ def __init__( self, - objects: list[Object | ObjectReference], - initial_positions: dict[Object | ObjectReference, tuple[float, float, float]], + objects: list[ObjectBase], + initial_positions: list[dict[ObjectBase, tuple[float, float, float]]], ): """Initialize optimization state. 
Args: - objects: List of all Object or ObjectReference instances to track. Must include at least one + objects: List of all ObjectBase instances to track. Must include at least one object marked with IsAnchor() which serves as a fixed reference. - initial_positions: Starting positions for all objects (including anchors). + initial_positions: List of dicts (one per env). Length 1 = single-env, + length > 1 = batched. """ + assert len(initial_positions) >= 1, "initial_positions must contain at least one dict." anchor_objects = get_anchor_objects(objects) assert len(anchor_objects) > 0, "No anchor object found in objects list." self._all_objects = objects - self._anchor_objects: set[Object] = set(anchor_objects) + self._anchor_objects: set[ObjectBase] = set(anchor_objects) self._optimizable_objects = [obj for obj in objects if obj not in self._anchor_objects] # Build object-to-index mapping - self._obj_to_idx: dict[Object | ObjectReference, int] = {obj: i for i, obj in enumerate(objects)} - - # Extract positions from the provided dict - positions = [] - for obj in objects: - assert obj in initial_positions, f"Missing initial position for {obj.name}" - positions.append(torch.tensor(initial_positions[obj], dtype=torch.float32)) + self._obj_to_idx: dict[ObjectBase, int] = {obj: i for i, obj in enumerate(objects)} + + # Extract positions from each env's dict + self._num_envs = len(initial_positions) + positions_per_env = [] + for d in initial_positions: + positions = [] + for obj in objects: + assert obj in d, f"Missing initial position for {obj.name}" + positions.append(torch.tensor(d[obj], dtype=torch.float32)) + positions_per_env.append(positions) # Separate anchor positions from optimizable positions self._anchor_indices: set[int] = {self._obj_to_idx[obj] for obj in self._anchor_objects} - self._anchor_positions: dict[int, torch.Tensor] = {idx: positions[idx].clone() for idx in self._anchor_indices} + # Anchors must be identical across envs (they are fixed reference 
points). + for idx in self._anchor_indices: + for e in range(1, self._num_envs): + assert torch.allclose(positions_per_env[0][idx], positions_per_env[e][idx]), ( + f"Anchor '{objects[idx].name}' has different positions across envs " + f"(env 0: {positions_per_env[0][idx].tolist()}, env {e}: {positions_per_env[e][idx].tolist()})" + ) + self._anchor_positions: dict[int, torch.Tensor] = { + idx: positions_per_env[0][idx].clone() for idx in self._anchor_indices + } # Build optimizable positions tensor (excludes all anchors) + # Always stored as (N, num_opt, 3) where N = num_envs self._optimizable_indices = [i for i in range(len(objects)) if i not in self._anchor_indices] if self._optimizable_indices: - self._optimizable_positions = torch.stack([positions[i] for i in self._optimizable_indices]) + opt_tensors = [ + torch.stack([positions_per_env[e][i] for e in range(self._num_envs)]) for i in self._optimizable_indices + ] + self._optimizable_positions = torch.stack(opt_tensors, dim=1) # (N, num_opt, 3) self._optimizable_positions.requires_grad = True else: self._optimizable_positions = None + @property + def num_envs(self) -> int: + """Number of environments (leading dimension N).""" + return self._num_envs + @property def optimizable_positions(self) -> torch.Tensor | None: - """Tensor of optimizable positions (shape: [N-num_anchors, 3]), or None if all objects are anchors. + """Tensor of optimizable positions (N, num_opt, 3), or None if all objects are anchors. This is the tensor that should be passed to the optimizer. 
""" return self._optimizable_positions @property - def optimizable_objects(self) -> list[Object]: + def optimizable_objects(self) -> list[ObjectBase]: """List of optimizable objects (excludes anchors).""" return self._optimizable_objects @property - def anchor_objects(self) -> set[Object]: + def anchor_objects(self) -> set[ObjectBase]: """Set of anchor objects (fixed during optimization).""" return self._anchor_objects - def get_position(self, obj: Object | ObjectReference) -> torch.Tensor: + def get_position(self, obj: ObjectBase) -> torch.Tensor: """Get current position for an object. Args: obj: The object to get position for. Returns: - Position tensor (x, y, z). + Position tensor of shape (N, 3). Raises: KeyError: If object is not tracked by this state. @@ -96,28 +122,31 @@ def get_position(self, obj: Object | ObjectReference) -> torch.Tensor: """ idx = self._obj_to_idx[obj] if idx in self._anchor_indices: - return self._anchor_positions[idx] + return self._anchor_positions[idx].unsqueeze(0).expand(self._num_envs, 3) if self._optimizable_positions is None: raise RuntimeError(f"No optimizable positions available for object '{obj.name}'") opt_idx = self._optimizable_indices.index(idx) - return self._optimizable_positions[opt_idx] + return self._optimizable_positions[:, opt_idx, :] def get_all_positions_snapshot(self) -> list[tuple[float, float, float]]: """Get detached copy of all positions for history tracking. Returns: - List of (x, y, z) positions for each object (in original order). + List of (x, y, z) positions for each object (in original order). Uses env 0. """ - return [tuple(self.get_position(obj).detach().tolist()) for obj in self._all_objects] + return [tuple(self.get_position(obj)[0].detach().tolist()) for obj in self._all_objects] - def get_final_positions_dict(self) -> dict[Object | ObjectReference, tuple[float, float, float]]: - """Get final positions as a dictionary mapping objects to positions. 
+ def get_final_positions(self) -> list[dict[ObjectBase, tuple[float, float, float]]]: + """Get final positions as a list of dicts, one per env. Returns: - Dictionary with object instances as keys and (x, y, z) tuples as values. + List of dictionaries with object instances as keys and (x, y, z) tuples as values. """ - result: dict[Object | ObjectReference, tuple[float, float, float]] = {} - for obj in self._all_objects: - pos = self.get_position(obj).detach().tolist() - result[obj] = (pos[0], pos[1], pos[2]) - return result + out = [] + for e in range(self._num_envs): + d: dict[ObjectBase, tuple[float, float, float]] = {} + for obj in self._all_objects: + pos = self.get_position(obj)[e].detach().tolist() + d[obj] = (pos[0], pos[1], pos[2]) + out.append(d) + return out diff --git a/isaaclab_arena/relations/relations.py b/isaaclab_arena/relations/relations.py index f64167881..d64559011 100644 --- a/isaaclab_arena/relations/relations.py +++ b/isaaclab_arena/relations/relations.py @@ -14,8 +14,7 @@ from isaaclab_arena.utils.pose import PoseRange if TYPE_CHECKING: - from isaaclab_arena.assets.object import Object - from isaaclab_arena.assets.object_reference import ObjectReference + from isaaclab_arena.assets.object_base import ObjectBase class Side(Enum): @@ -41,10 +40,10 @@ class RelationBase: class Relation(RelationBase): """Base class for spatial relationships between objects.""" - def __init__(self, parent: Object | ObjectReference, relation_loss_weight: float = 1.0): + def __init__(self, parent: ObjectBase, relation_loss_weight: float = 1.0): """ Args: - parent: The parent asset in the relationship (Object or ObjectReference). + parent: The parent asset in the relationship. relation_loss_weight: Weight for the relationship loss function. 
""" self.parent = parent @@ -62,7 +61,7 @@ class NextTo(Relation): def __init__( self, - parent: Object | ObjectReference, + parent: ObjectBase, relation_loss_weight: float = 1.0, distance_m: float = 0.05, side: Side = Side.POSITIVE_X, @@ -102,7 +101,7 @@ class On(Relation): def __init__( self, - parent: Object | ObjectReference, + parent: ObjectBase, relation_loss_weight: float = 1.0, clearance_m: float = 0.01, ): @@ -126,15 +125,16 @@ class NoCollision(Relation): Note: Loss computation is handled by NoCollisionLossStrategy in relation_loss_strategies.py. - NOTE: If both A.add_relation(NoCollision(B)) and B.add_relation(NoCollision(A)) are present, - the solver will compute the loss twice and the relation graph becomes cyclic, which can cause - issues during environment creation. Deduplication or cycle detection should be addressed at a - higher level. + NOTE: RelationSolver._compute_total_loss iterates every relation on every object with no + deduplication. If both A.add_relation(NoCollision(B)) and B.add_relation(NoCollision(A)) + are present, loss is computed twice. Bidirectional NoCollision can also make the relation + graph cyclic and cause issues when creating the environment. Deduplication and/or + higher-level handling of symmetric relations to be addressed in a future commit. """ def __init__( self, - parent: Object | ObjectReference, + parent: ObjectBase, relation_loss_weight: float = 1.0, clearance_m: float = 0.01, ): @@ -345,7 +345,7 @@ def __init__( self.relation_loss_weight = relation_loss_weight -def get_anchor_objects(objects: list[Object | ObjectReference]) -> list[Object | ObjectReference]: +def get_anchor_objects(objects: list[ObjectBase]) -> list[ObjectBase]: """Get all anchor objects from a list of objects. 
Anchor objects are marked with IsAnchor() relation and serve as diff --git a/isaaclab_arena/tests/test_no_collision_loss.py b/isaaclab_arena/tests/test_no_collision_loss.py index 77768d768..6c3e485fd 100644 --- a/isaaclab_arena/tests/test_no_collision_loss.py +++ b/isaaclab_arena/tests/test_no_collision_loss.py @@ -190,7 +190,7 @@ def test_relation_solver_no_collision_produces_separated_positions(): solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) solver = RelationSolver(params=solver_params) - result = solver.solve(objects=objects, initial_positions=initial_positions) + result = solver.solve(objects=objects, initial_positions=[initial_positions])[0] pos_a = result[box_a] pos_b = result[box_b] @@ -208,12 +208,52 @@ def test_relation_solver_no_collision_same_inputs_reproducible(): solver_params = RelationSolverParams(max_iters=50) solver1 = RelationSolver(params=solver_params) - result1 = solver1.solve(objects=[table1, box_a1, box_b1], initial_positions=initial_positions1) + result1 = solver1.solve(objects=[table1, box_a1, box_b1], initial_positions=[initial_positions1])[0] table2, box_a2, box_b2 = _create_no_collision_scene() initial_positions2 = {table2: initial[0], box_a2: initial[1], box_b2: initial[2]} solver2 = RelationSolver(params=solver_params) - result2 = solver2.solve(objects=[table2, box_a2, box_b2], initial_positions=initial_positions2) + result2 = solver2.solve(objects=[table2, box_a2, box_b2], initial_positions=[initial_positions2])[0] assert result1[box_a1] == result2[box_a2], "box_a positions should match" assert result1[box_b1] == result2[box_b2], "box_b positions should match" + + +def test_no_collision_loss_multi_env_shape_and_values(): + """Test that NoCollision with batched (N,3) input returns (N,) loss with correct per-env values.""" + box_a = _create_box("box_a") + box_b = _create_box("box_b") + relation = NoCollision(box_b, clearance_m=0.0) + strategy = NoCollisionLossStrategy(slope=10.0) + + child_pos = 
torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.1, 0.0]]) + parent_world_bbox = AxisAlignedBoundingBox( + min_point=torch.tensor([[1.0, 0.0, 0.0], [0.05, 0.05, 0.0]]), + max_point=torch.tensor([[1.2, 0.2, 0.2], [0.25, 0.25, 0.2]]), + ) + + loss = strategy.compute_loss(relation, child_pos, box_a.bounding_box, parent_world_bbox) + assert loss.shape == (2,) + assert torch.isclose(loss[0], torch.tensor(0.0), atol=1e-5) + assert loss[1] > 0.0 + + +def test_relation_solver_multi_env_returns_list_of_dicts(): + """Test that solver returns list[dict] when given list[dict] input.""" + table, box_a, box_b = _create_no_collision_scene() + objects = [table, box_a, box_b] + initial_positions = [ + {table: (0.0, 0.0, 0.0), box_a: (0.2, 0.2, 0.11), box_b: (0.25, 0.25, 0.11)}, + {table: (0.0, 0.0, 0.0), box_a: (0.3, 0.3, 0.11), box_b: (0.6, 0.6, 0.11)}, + ] + + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + solver = RelationSolver(params=solver_params) + result = solver.solve(objects=objects, initial_positions=initial_positions) + + assert isinstance(result, list) + assert len(result) == 2 + for d in result: + assert isinstance(d, dict) + assert box_a in d + assert box_b in d diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 81daa8bf6..892d96cbf 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -8,11 +8,12 @@ from isaaclab_arena.assets.dummy_object import DummyObject from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams +from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from 
isaaclab_arena.relations.relations import IsAnchor, NextTo, On, Side from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, get_random_pose_within_bounding_box -from isaaclab_arena.utils.pose import Pose +from isaaclab_arena.utils.pose import Pose, PosePerEnv def _create_test_objects() -> tuple[DummyObject, DummyObject, DummyObject]: @@ -62,14 +63,14 @@ def test_relation_solver_same_inputs_produces_identical_result(): initial_positions1 = {desk1: desk_pos, box1_run1: fixed_box1_pos, box2_run1: fixed_box2_pos} solver1 = RelationSolver(params=solver_params) - result1 = solver1.solve(objects=[desk1, box1_run1, box2_run1], initial_positions=initial_positions1) + result1 = solver1.solve(objects=[desk1, box1_run1, box2_run1], initial_positions=[initial_positions1])[0] # Run 2 (fresh objects, same initial positions) desk2, box1_run2, box2_run2 = _create_test_objects() initial_positions2 = {desk2: desk_pos, box1_run2: fixed_box1_pos, box2_run2: fixed_box2_pos} solver2 = RelationSolver(params=solver_params) - result2 = solver2.solve(objects=[desk2, box1_run2, box2_run2], initial_positions=initial_positions2) + result2 = solver2.solve(objects=[desk2, box1_run2, box2_run2], initial_positions=[initial_positions2])[0] # Compare by name (different object instances) for obj1 in result1: @@ -128,3 +129,112 @@ def test_object_placer_different_seeds_produce_different_results(): break assert any_different, "Different seeds should produce different results" + + +def test_object_placer_multi_env_returns_multi_env_result(): + """Test that ObjectPlacer.place with num_envs>1 returns MultiEnvPlacementResult.""" + num_envs = 4 + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params)) + result = placer.place(objects, num_envs=num_envs) + + assert isinstance(result, 
MultiEnvPlacementResult) + assert len(result.results) == num_envs + for r in result.results: + assert isinstance(r, PlacementResult) + assert box1 in r.positions + assert box2 in r.positions + assert len(r.positions[box1]) == 3 + assert len(r.positions[box2]) == 3 + + +def test_object_placer_multi_env_produces_different_positions(): + """Test that multi-env placement produces different positions across environments.""" + num_envs = 4 + solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params)) + result = placer.place(objects, num_envs=num_envs) + + assert isinstance(result, MultiEnvPlacementResult) + # At least one pair of envs should have different positions for a non-anchor object. + positions_box1 = [result.results[e].positions[box1] for e in range(num_envs)] + any_different = any(positions_box1[i] != positions_box1[j] for i in range(num_envs) for j in range(i + 1, num_envs)) + assert any_different, "Multi-env placement should produce different positions across environments" + + +def test_relation_solver_multi_env_batched_positions(): + """Test that solver with list[dict] input returns list[dict] output.""" + solver_params = RelationSolverParams(max_iters=50) + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + + initial_positions = [ + {desk: (0.0, 0.0, 0.0), box1: (0.2, 0.2, 0.11), box2: (0.5, 0.5, 0.11)}, + {desk: (0.0, 0.0, 0.0), box1: (0.3, 0.3, 0.11), box2: (0.6, 0.6, 0.11)}, + ] + + solver = RelationSolver(params=solver_params) + result = solver.solve(objects=objects, initial_positions=initial_positions) + + assert isinstance(result, list) + assert len(result) == 2 + for d in result: + assert isinstance(d, dict) + for obj in objects: + assert obj in d + assert len(d[obj]) == 3 + + +def 
test_object_placer_result_per_env_false_returns_single_result():
+    """Test that place(num_envs>1, result_per_env=False) returns PlacementResult."""
+    num_envs = 4
+    solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3)
+    desk, box1, box2 = _create_test_objects()
+    objects = [desk, box1, box2]
+    placer = ObjectPlacer(params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params))
+    result = placer.place(objects, num_envs=num_envs, result_per_env=False)
+
+    assert isinstance(result, PlacementResult), "result_per_env=False should return PlacementResult"
+    assert not isinstance(result, MultiEnvPlacementResult)
+    assert box1 in result.positions
+    assert box2 in result.positions
+    assert len(result.positions[box1]) == 3
+    assert len(result.positions[box2]) == 3
+
+
+def test_object_placer_result_per_env_false_applies_pose_not_pose_per_env():
+    """Test that result_per_env=False with apply_positions_to_objects=True sets a Pose (not PosePerEnv) on each box."""
+    num_envs = 4
+    solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3)
+    desk, box1, box2 = _create_test_objects()
+    objects = [desk, box1, box2]
+    placer = ObjectPlacer(
+        params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=True)
+    )
+    placer.place(objects, num_envs=num_envs, result_per_env=False)
+
+    for obj in [box1, box2]:  # only the boxes are checked; the desk is not inspected
+        pose = obj.get_initial_pose()
+        assert isinstance(pose, Pose), f"{obj.name} should have a Pose, got {type(pose).__name__}"
+        assert not isinstance(pose, PosePerEnv)
+
+
+def test_object_placer_result_per_env_true_applies_pose_per_env():
+    """Test that result_per_env=True (default) sets a PosePerEnv holding num_envs poses on each box when num_envs>1."""
+    num_envs = 4
+    solver_params = RelationSolverParams(max_iters=200, convergence_threshold=1e-3)
+    desk, box1, box2 = _create_test_objects()
+    objects = [desk, box1, box2]
+    placer = ObjectPlacer(
+        params=ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=True)
+    )
+    placer.place(objects, num_envs=num_envs, result_per_env=True)
+
+    for obj in [box1, box2]:  # only the boxes are checked; the desk is not inspected
+        pose = obj.get_initial_pose()
+        assert isinstance(pose, PosePerEnv), f"{obj.name} should have PosePerEnv, got {type(pose).__name__}"
+        assert len(pose.poses) == num_envs
diff --git a/isaaclab_arena/tests/test_pose.py b/isaaclab_arena/tests/test_pose.py
index 718fc8c76..8a3a5c3bc 100644
--- a/isaaclab_arena/tests/test_pose.py
+++ b/isaaclab_arena/tests/test_pose.py
@@ -3,7 +3,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from isaaclab_arena.utils.pose import Pose
+from isaaclab_arena.utils.pose import Pose, PosePerEnv
 
 
 def test_pose_composition():
@@ -14,3 +14,18 @@ def test_pose_composition():
 
     assert T_C_A.position_xyz == (3.0, 0.0, 0.0)
     assert T_C_A.rotation_wxyz == (1.0, 0.0, 0.0, 0.0)
+
+
+def test_pose_per_env_stores_poses():
+    """Test that PosePerEnv keeps the given list of Pose objects intact (same length, same order)."""
+    poses = [
+        Pose(position_xyz=(1.0, 2.0, 3.0)),
+        Pose(position_xyz=(4.0, 5.0, 6.0)),
+        Pose(position_xyz=(7.0, 8.0, 9.0)),
+    ]
+    pose_per_env = PosePerEnv(poses=poses)
+
+    assert len(pose_per_env.poses) == 3
+    assert pose_per_env.poses[0].position_xyz == (1.0, 2.0, 3.0)
+    assert pose_per_env.poses[1].position_xyz == (4.0, 5.0, 6.0)
+    assert pose_per_env.poses[2].position_xyz == (7.0, 8.0, 9.0)
diff --git a/isaaclab_arena/tests/test_relation_loss_strategies.py b/isaaclab_arena/tests/test_relation_loss_strategies.py
index d7221d0ea..f2463c167 100644
--- a/isaaclab_arena/tests/test_relation_loss_strategies.py
+++ b/isaaclab_arena/tests/test_relation_loss_strategies.py
@@ -229,3 +229,38 @@ def test_next_to_zero_distance_raises():
 
     with pytest.raises(AssertionError, match="Distance must be positive"):
         NextTo(parent_obj, side=Side.POSITIVE_X, distance_m=0.0)
+
+
+def test_on_loss_strategy_multi_env_shape_and_values():
+    """Test that On with batched (N,3) input returns (N,) loss with correct per-env values."""
+    table = _create_table()
+    box = _create_box()
+    relation = On(table, clearance_m=0.01)
+    strategy = OnLossStrategy(slope=10.0)
+
+    child_pos = torch.tensor([[0.4, 0.4, 0.11], [0.4, 0.4, 0.5]])  # z: env 0 = 0.1 + clearance_m, env 1 = 0.5
+    parent_world_bbox = AxisAlignedBoundingBox(
+        min_point=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),  # identical parent box in both envs
+        max_point=torch.tensor([[1.0, 1.0, 0.1], [1.0, 1.0, 0.1]]),  # parent top face at z = 0.1
+    )
+
+    loss = strategy.compute_loss(relation, child_pos, box.bounding_box, parent_world_bbox)
+    assert loss.shape == (2,)  # one loss value per env
+    assert torch.isclose(loss[0], torch.tensor(0.0), atol=1e-4)  # env 0 is expected to satisfy On
+    assert loss[1] > 0.0  # env 1 is expected to be penalised
+
+
+def test_next_to_loss_strategy_multi_env_shape_and_values():
+    """Test that NextTo with batched (N,3) input returns (N,) loss with correct per-env values."""
+    parent_obj = _create_table()
+    child_obj = _create_box()
+    relation = NextTo(parent_obj, side=Side.POSITIVE_X, distance_m=0.05)
+    strategy = NextToLossStrategy(slope=10.0)
+
+    # Env 0: perfectly placed. Env 1: wrong side (x = -0.5 although the relation uses Side.POSITIVE_X).
+    child_pos = torch.tensor([[1.05, 0.4, 0.0], [-0.5, 0.5, 0.0]])
+
+    loss = strategy.compute_loss(relation, child_pos, child_obj.bounding_box, parent_obj.bounding_box)
+    assert loss.shape == (2,)  # one loss value per env
+    assert torch.isclose(loss[0], torch.tensor(0.0), atol=1e-4)  # env 0 is expected to satisfy NextTo
+    assert loss[1] > 0.0  # env 1 is expected to be penalised
diff --git a/isaaclab_arena/utils/pose.py b/isaaclab_arena/utils/pose.py
index 57babdc58..088f1e311 100644
--- a/isaaclab_arena/utils/pose.py
+++ b/isaaclab_arena/utils/pose.py
@@ -74,6 +74,14 @@ def compose_poses(T_C_B: Pose, T_B_A: Pose) -> Pose:
     return Pose(position_xyz=tuple(t_C_A.tolist()), rotation_wxyz=tuple(q_C_A.tolist()))
 
 
+@dataclass
+class PosePerEnv:
+    """Per-environment poses (one Pose per env, used for batched placement); the list is not validated here."""
+
+    poses: list[Pose]
+    """One Pose per environment."""
+
+
 @dataclass
 class PoseRange:
     """Range of poses.
diff --git a/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py b/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py
index 3e0428fc1..f3750fddc 100644
--- a/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py
+++ b/isaaclab_arena_environments/gr1_table_multi_object_no_collision_environment.py
@@ -10,7 +10,7 @@
 
 Example:
     python isaaclab_arena/evaluation/policy_runner.py --policy_type zero_action --num_steps 500 \\
-        --num_envs 1 --enable_cameras gr1_table_multi_object_no_collision --embodiment gr1_joint
+        --num_envs 16 --env_spacing 4.0 --enable_cameras gr1_table_multi_object_no_collision --embodiment gr1_joint
 """
 
 import argparse
@@ -19,13 +19,15 @@
 
 DEFAULT_TABLE_OBJECTS = [
     "cracker_box",
-    "mustard_bottle",
     "sugar_box",
     "tomato_soup_can",
-    "mug",
-    "brown_box",
     "dex_cube",
-]  # Default objects on table (On + pairwise NoCollision)
+    "power_drill",
+    "red_container",
+]
+# NOTE: The gradient-based solver does not guarantee collision-free placement for all
+# objects in this list. Better initialization strategies, and constraints on pose
+# dimensions that should stay unchanged, are needed in the near future.
 
 
 class GR1TableMultiObjectNoCollisionEnvironment(ExampleEnvironmentBase):