Batch Support for Relation Solver#512
Batch Support for Relation Solver#512zhx06 wants to merge 12 commits intozxiao/bbox_batch_supportfrom
Conversation
alexmillane
left a comment
There was a problem hiding this comment.
Good work.
I have some suggestions for how we might improve things.
| positions_all_envs_by_name = [ | ||
| {obj.name: result.results[e].positions[obj] for obj in result.results[0].positions} | ||
| for e in range(len(result.results)) | ||
| ] |
There was a problem hiding this comment.
Wdyt about moving this into MultiEnvPlacementResult? It's a bit difficult to decipher. Perhaps would be nicers to hide inside a method?
There was a problem hiding this comment.
Done, currently it does not build layout dicts
| num_envs = self.args.num_envs | ||
| result = placer.place(objects_with_relations, num_envs=num_envs) | ||
| if isinstance(result, MultiEnvPlacementResult) and result.results: | ||
| positions_all_envs_by_name = [ | ||
| {obj.name: result.results[e].positions[obj] for obj in result.results[0].positions} | ||
| for e in range(len(result.results)) | ||
| ] | ||
| object_names = [obj.name for obj in objects_with_relations] | ||
| anchor_names = [a.name for a in get_anchor_objects(objects_with_relations)] | ||
| self._placement_event_cfg = make_placement_event_cfg( | ||
| positions_all_envs_by_name, | ||
| object_names, | ||
| anchor_names, | ||
| ) | ||
| else: | ||
| self._placement_event_cfg = None |
There was a problem hiding this comment.
I can see what the motivation is here, but we have two quite different code paths for the single environment and multi-environment case.
I think we should unify these code paths.
Both code paths (single and multi placement) should run the solver to produce a result (of type PlacementResult or MultiEnvPlacementResult).
This result should be passed to ObjectPlacer._apply_positions.
The placement should be passed to the object through it's Object.set_initial_pose, which will have to be expanded to take something like a PosePerEnv or something like that.
Right now, depending on whether your optimizing for a single or multiple environments results in two very different paths. In the single env case, the responsibility for applying the placement is in the ObjectPlacer and Object registers the event. In the multi-env case the placement is effectively applied in the ArenaEnvBuilder, which also registers the event.
I notice you have a note somewhere that perhaps these options could be unified. I suggest we unify now.
There was a problem hiding this comment.
We unified these code paths in the current code. Each object's _init_event_cfg generates the appropriate reset event.
| task.get_events_cfg(), | ||
| ) | ||
| ] | ||
| placement_event = getattr(self, "_placement_event_cfg", None) |
There was a problem hiding this comment.
Prefer to set this in the constructor (__init__) as None unconditionally, such that the object is garunteed to have the member, rather than checking for the attribute every time
There was a problem hiding this comment.
This is removed in the new version of code, we no longer have placement_event
| placement_event = getattr(self, "_placement_event_cfg", None) | ||
| if placement_event is not None: | ||
| events_sources.append(placement_event) | ||
| events_cfg = combine_configclass_instances("EventsCfg", *events_sources) |
There was a problem hiding this comment.
See comment above. I think that the placement result should be registered through the objects, as it is the single env case.
There was a problem hiding this comment.
Done. Both single-env and multi-env paths now register placement in object levels.
| self, | ||
| objects: list[Object | ObjectReference], | ||
| ) -> PlacementResult: | ||
| num_envs: int = 1, |
There was a problem hiding this comment.
What do you think about the case that we have num_envs > 1, but we want a placement that is the same across all environments. I feel that that might be a capability we'd like to preserve. Wdyt @cvolkcvolk?
If we do want to preserve that functionality we should introduce another parameter result_per_env: bool = True
There was a problem hiding this comment.
Is see you've solved this! Thanks!
| # Per env: use best valid if any, else best-by-loss fallback | ||
| final_per_env: list[dict] = [ | ||
| ( | ||
| best_valid_positions_per_env[e] | ||
| if best_valid_positions_per_env[e] is not None | ||
| else best_any_positions_per_env[e] | ||
| ) | ||
| for e in range(num_envs) | ||
| ] |
There was a problem hiding this comment.
The idea here is that we return some result, even if some of the environments do not result in a valid placement?
Do you think that that's a safe approach.
There was a problem hiding this comment.
This one no longer exists when we have the pool-based selection.
| final_loss=best_loss, | ||
| attempts=attempt + 1, | ||
| ) | ||
| # TODO(@zhx06): Consider applying via event for consistency with multi_env. |
There was a problem hiding this comment.
I think we should do this in this MR. See comments above.
There was a problem hiding this comment.
It is now implemented in this MR.
| valid = self._validate_placement(positions_per_env[e]) | ||
| if valid and loss_e < best_valid_loss_per_env[e]: | ||
| best_valid_loss_per_env[e] = loss_e | ||
| best_valid_positions_per_env[e] = positions_per_env[e] | ||
| if loss_e < best_any_loss_per_env[e]: | ||
| best_any_loss_per_env[e] = loss_e | ||
| best_any_positions_per_env[e] = positions_per_env[e] |
There was a problem hiding this comment.
One thing that is inefficient about this loop is that it if a single environment (say "A)) fails N times, by chance, the whole thing fails. Even if environment "B" has 10 successful solutions. Because the environments are homogenous, "A" could just use the solution from "B".
I wonder if we shouldn't just calculate FACTOR * num_env solutions, and take the best num_envs. There's no need to associate a particular solution to a particular environment (because they're homogeneous).
There was a problem hiding this comment.
Done. We now generate max_placement_attempts * num_envs candidates in a single batched solver call
| if not hasattr(asset, "write_root_pose_to_sim"): | ||
| continue |
There was a problem hiding this comment.
under which conditions would an object in out list not have this method? Is that something we want to handle?
| ordered_names = [n for n in object_names if n in anchor_set] | ||
| ordered_names += [n for n in object_names if n not in anchor_set] |
There was a problem hiding this comment.
Why do the names have to be ordered. Consider writing a comment.
There was a problem hiding this comment.
This part is outdated and is replaced by the pool-based approach
alexmillane
left a comment
There was a problem hiding this comment.
Thank you for addressing the last comments. It's look a lot better (and simpler) now. Great!
I just have some smaller stylistic comments now.
| events_sources = [ | ||
| embodiment.get_events_cfg(), | ||
| self.arena_env.scene.get_events_cfg(), | ||
| task.get_events_cfg(), | ||
| ) | ||
| ] | ||
| events_cfg = combine_configclass_instances("EventsCfg", *events_sources) |
There was a problem hiding this comment.
Suggestion just to revert this change (it has no effect).
There was a problem hiding this comment.
Done. This is reverted
| self, | ||
| objects: list[Object | ObjectReference], | ||
| ) -> PlacementResult: | ||
| num_envs: int = 1, |
There was a problem hiding this comment.
Is see you've solved this! Thanks!
| rng_state = None | ||
| if self.params.placement_seed is not None: | ||
| rng_state = torch.get_rng_state() | ||
| torch.manual_seed(self.params.placement_seed + candidate_idx) | ||
| initial_positions.append(self._generate_initial_positions(objects, anchor_objects_set, init_bounds)) | ||
| if rng_state is not None: | ||
| torch.set_rng_state(rng_state) |
There was a problem hiding this comment.
Suggestion to abstract this logic into a function. Maybe _generate_initial_positions could be expanded to handle the single and multiple placement cases.
There was a problem hiding this comment.
Now the new _generate_initial_positions handles both cases.
| all_losses = ( | ||
| self._solver.last_loss_per_env.cpu().tolist() | ||
| if self._solver.last_loss_per_env is not None | ||
| else [float("inf")] * num_candidates | ||
| ) |
There was a problem hiding this comment.
Why is this needed? Seems like if we've run the solve, the last_loss_per_env should not be None or am I missing something?
There was a problem hiding this comment.
This one is removed from the updated code.
| torch.set_rng_state(rng_state) | ||
| all_candidates: list[tuple[float, dict, bool]] = [] | ||
| for idx in range(num_candidates): | ||
| loss = all_losses[idx] if idx < len(all_losses) else float("inf") |
There was a problem hiding this comment.
Do we need this check: idx < len(all_losses). Shouldn't all_losses be defined to have the correct length?
There was a problem hiding this comment.
Correct, This extra check is removed.
| n_valid = sum(1 for c in selected if c[2]) | ||
| if self.params.verbose: | ||
| total_valid = sum(1 for c in all_candidates if c[2]) | ||
| finite_losses = [c[0] for c in all_candidates if math.isfinite(c[0])] | ||
| mean_loss = sum(finite_losses) / len(finite_losses) if finite_losses else float("inf") | ||
| print( | ||
| f"Solved {num_candidates} candidates in one batch: mean loss = {mean_loss:.6f}," | ||
| f" {total_valid} valid, selected best {num_results} ({n_valid} valid)" | ||
| ) | ||
|
|
||
| return PlacementResult( | ||
| success=success, | ||
| positions=best_positions, | ||
| final_loss=best_loss, | ||
| attempts=attempt + 1, | ||
| ) | ||
| final_per_env: list[dict] = [c[1] for c in selected] | ||
| results_per_env = [ | ||
| PlacementResult( | ||
| success=c[2], | ||
| positions=c[1], | ||
| final_loss=c[0], | ||
| attempts=self.params.max_placement_attempts, | ||
| ) | ||
| for c in selected | ||
| ] |
There was a problem hiding this comment.
See comment above. This accesses of c make things a bit difficult to read. Consider using a named composite.
| if (id(a), id(b)) in on_pairs: | ||
| continue |
There was a problem hiding this comment.
Add:
# Pairs related by an OnRelation are excluded from the check.
There was a problem hiding this comment.
Done. This comment has been added.
| """ | ||
| for obj, pos in positions.items(): | ||
| num_envs = len(positions_per_env) | ||
| for obj in positions_per_env[0]: |
There was a problem hiding this comment.
This looks a bit weird because of the [0]. Could be clearer to split into two lines:
# Objects are the same for every environment. Extract them.
objects = [obj for obj in positions_per_env[0]]
# Apply pose for each object.
for obj in objects:
...
There was a problem hiding this comment.
The code has been updated following your suggestion.
| initial_positions: ( | ||
| dict[Object | ObjectReference, tuple[float, float, float]] | ||
| | list[dict[Object | ObjectReference, tuple[float, float, float]]] | ||
| ), | ||
| ) -> ( | ||
| dict[Object | ObjectReference, tuple[float, float, float]] | ||
| | list[dict[Object | ObjectReference, tuple[float, float, float]]] |
There was a problem hiding this comment.
These types are getting a bit awkward.
One way we can simply is: Object | ObjectReference -> ObjectBase (their parent class).
Do you see any opportunity to just use a list of length 1 for the single env case?
There was a problem hiding this comment.
ll type annotations now use ObjectBase.
solve() and RelationSolverState now always use list[dict].
| if isinstance(initial_positions, dict): | ||
| initial_positions = [initial_positions] |
There was a problem hiding this comment.
Seems like maybe we should just use a list no matter what? What do you think?
There was a problem hiding this comment.
This is removed and it takes list[dict] for all cases.
Summary
Add batched multi-env placement to relation solver (
1 -> num_envs)Detailed description
Solver
RelationSolverStatepositions stored as(num_envs, num_optimizable, 3);get_position()returns(num_envs, 3)RelationSolver._compute_total_loss()accumulates(num_envs,)loss per env, returns mean of losses; per-env loss exposed vialast_loss_per_envOn,NextTo,NoCollision,AtPosition) updated to accept(num_envs, 3)with(3,)backward compatObjectPlacer
place()acceptsnum_envsandresult_per_env; returnsMultiEnvPlacementResult(onePlacementResultper env) or a singlePlacementResultmax_placement_attempts * num_resultscandidates in one batched solver call, sorts by (valid-first, lowest-loss), and selects the bestnum_resultsresult_per_env=Falsesolves a single layout applied to all environments_validate_placementskipsOn-related pairs in overlap check (fix for false-positive collisions)mean_losscalculation filters out non-finite valuesUnified Placement Path
PosePerEnvdataclass holds a list ofPoseobjects, one per environmentObjectBase._init_event_cfghandlesPosePerEnvviaset_object_pose_per_env, so each object registers its own reset event for both single-env and multi-env casesplacement_events.py— its functionality is now handled by object-level event registrationArenaEnvBuilder._solve_relations()simplified: no longer builds a separate placement event cfg or distinguishes single/multi-env code pathsExample Update
gr1_table_multi_object_no_collision_environment, add TODO for potential object collision issues after running the solver