Skip to content

Batch Support for Relation Solver#512

Open
zhx06 wants to merge 12 commits intozxiao/bbox_batch_supportfrom
zxiao/solver_batch_support_from_batch_bbox
Open

Batch Support for Relation Solver#512
zhx06 wants to merge 12 commits intozxiao/bbox_batch_supportfrom
zxiao/solver_batch_support_from_batch_bbox

Conversation

@zhx06
Copy link
Copy Markdown
Collaborator

@zhx06 zhx06 commented Mar 30, 2026

Summary

Add batched multi-env placement to relation solver (1 -> num_envs)

Detailed description

Solver

  • RelationSolverState positions stored as (num_envs, num_optimizable, 3); get_position() returns (num_envs, 3)
  • RelationSolver._compute_total_loss() accumulates (num_envs,) loss per env, returns mean of losses; per-env loss exposed via last_loss_per_env
  • All loss strategies (On, NextTo, NoCollision, AtPosition) updated to accept (num_envs, 3) with (3,) backward compat

ObjectPlacer

  • place() accepts num_envs and result_per_env; returns MultiEnvPlacementResult (one PlacementResult per env) or a single PlacementResult
  • Pool-based placement: generates max_placement_attempts * num_results candidates in one batched solver call, sorts by (valid-first, lowest-loss), and selects the best num_results
  • result_per_env=False solves a single layout applied to all environments
  • _validate_placement skips On-related pairs in overlap check (fix for false-positive collisions)
  • mean_loss calculation filters out non-finite values

Unified Placement Path

  • New PosePerEnv dataclass holds a list of Pose objects, one per environment
  • ObjectBase._init_event_cfg handles PosePerEnv via set_object_pose_per_env, so each object registers its own reset event for both single-env and multi-env cases
  • Removed placement_events.py — its functionality is now handled by object-level event registration
  • ArenaEnvBuilder._solve_relations() simplified: no longer builds a separate placement event cfg or distinguishes single/multi-env code paths

Example Update

  • Update default objects in gr1_table_multi_object_no_collision_environment, add TODO for potential object collision issues after running the solver

@zhx06 zhx06 changed the title Zxiao/solver batch support from batch bbox Zxiao/batch support for relation solver Mar 30, 2026
@zhx06 zhx06 changed the title Zxiao/batch support for relation solver Batch Support for Relation Solver Mar 30, 2026
Copy link
Copy Markdown
Collaborator

@alexmillane alexmillane left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good work.

I have some suggestions for how we might improve things.

Comment on lines +109 to +112
positions_all_envs_by_name = [
{obj.name: result.results[e].positions[obj] for obj in result.results[0].positions}
for e in range(len(result.results))
]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wdyt about moving this into MultiEnvPlacementResult? It's a bit difficult to decipher. Perhaps would be nicers to hide inside a method?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, currently it does not build layout dicts

Comment on lines +106 to +121
num_envs = self.args.num_envs
result = placer.place(objects_with_relations, num_envs=num_envs)
if isinstance(result, MultiEnvPlacementResult) and result.results:
positions_all_envs_by_name = [
{obj.name: result.results[e].positions[obj] for obj in result.results[0].positions}
for e in range(len(result.results))
]
object_names = [obj.name for obj in objects_with_relations]
anchor_names = [a.name for a in get_anchor_objects(objects_with_relations)]
self._placement_event_cfg = make_placement_event_cfg(
positions_all_envs_by_name,
object_names,
anchor_names,
)
else:
self._placement_event_cfg = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see what the motivation is here, but we have two quite different code paths for the single environment and multi-environment case.

I think we should unify these code paths.

Both code paths (single and multi placement) should run the solver to produce a result (of type PlacementResult or MultiEnvPlacementResult).

This result should be passed to ObjectPlacer._apply_positions.

The placement should be passed to the object through it's Object.set_initial_pose, which will have to be expanded to take something like a PosePerEnv or something like that.

Right now, depending on whether your optimizing for a single or multiple environments results in two very different paths. In the single env case, the responsibility for applying the placement is in the ObjectPlacer and Object registers the event. In the multi-env case the placement is effectively applied in the ArenaEnvBuilder, which also registers the event.

I notice you have a note somewhere that perhaps these options could be unified. I suggest we unify now.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We unified these code paths in the current code. Each object's _init_event_cfg generates the appropriate reset event.

task.get_events_cfg(),
)
]
placement_event = getattr(self, "_placement_event_cfg", None)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer to set this in the constructor (__init__) as None unconditionally, such that the object is garunteed to have the member, rather than checking for the attribute every time

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is removed in the new version of code, we no longer have placement_event

Comment on lines +188 to +191
placement_event = getattr(self, "_placement_event_cfg", None)
if placement_event is not None:
events_sources.append(placement_event)
events_cfg = combine_configclass_instances("EventsCfg", *events_sources)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above. I think that the placement result should be registered through the objects, as it is the single env case.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Both single-env and multi-env paths now register placement in object levels.

self,
objects: list[Object | ObjectReference],
) -> PlacementResult:
num_envs: int = 1,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about the case that we have num_envs > 1, but we want a placement that is the same across all environments. I feel that that might be a capability we'd like to preserve. Wdyt @cvolkcvolk?

If we do want to preserve that functionality we should introduce another parameter result_per_env: bool = True

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is see you've solved this! Thanks!

Comment on lines +139 to +147
# Per env: use best valid if any, else best-by-loss fallback
final_per_env: list[dict] = [
(
best_valid_positions_per_env[e]
if best_valid_positions_per_env[e] is not None
else best_any_positions_per_env[e]
)
for e in range(num_envs)
]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is that we return some result, even if some of the environments do not result in a valid placement?

Do you think that that's a safe approach.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one no longer exists when we have the pool-based selection.

final_loss=best_loss,
attempts=attempt + 1,
)
# TODO(@zhx06): Consider applying via event for consistency with multi_env.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do this in this MR. See comments above.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is now implemented in this MR.

Comment on lines +118 to +124
valid = self._validate_placement(positions_per_env[e])
if valid and loss_e < best_valid_loss_per_env[e]:
best_valid_loss_per_env[e] = loss_e
best_valid_positions_per_env[e] = positions_per_env[e]
if loss_e < best_any_loss_per_env[e]:
best_any_loss_per_env[e] = loss_e
best_any_positions_per_env[e] = positions_per_env[e]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing that is inefficient about this loop is that it if a single environment (say "A)) fails N times, by chance, the whole thing fails. Even if environment "B" has 10 successful solutions. Because the environments are homogenous, "A" could just use the solution from "B".

I wonder if we shouldn't just calculate FACTOR * num_env solutions, and take the best num_envs. There's no need to associate a particular solution to a particular environment (because they're homogeneous).

Copy link
Copy Markdown
Collaborator Author

@zhx06 zhx06 Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. We now generate max_placement_attempts * num_envs candidates in a single batched solver call

Comment on lines +83 to +84
if not hasattr(asset, "write_root_pose_to_sim"):
continue
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

under which conditions would an object in out list not have this method? Is that something we want to handle?

Comment on lines +76 to +77
ordered_names = [n for n in object_names if n in anchor_set]
ordered_names += [n for n in object_names if n not in anchor_set]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do the names have to be ordered. Consider writing a comment.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is outdated and is replaced by the pool-based approach

Copy link
Copy Markdown
Collaborator

@alexmillane alexmillane left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for addressing the last comments. It's look a lot better (and simpler) now. Great!

I just have some smaller stylistic comments now.

Comment on lines +169 to +174
events_sources = [
embodiment.get_events_cfg(),
self.arena_env.scene.get_events_cfg(),
task.get_events_cfg(),
)
]
events_cfg = combine_configclass_instances("EventsCfg", *events_sources)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion just to revert this change (it has no effect).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. This is reverted

self,
objects: list[Object | ObjectReference],
) -> PlacementResult:
num_envs: int = 1,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is see you've solved this! Thanks!

Comment on lines +103 to +109
rng_state = None
if self.params.placement_seed is not None:
rng_state = torch.get_rng_state()
torch.manual_seed(self.params.placement_seed + candidate_idx)
initial_positions.append(self._generate_initial_positions(objects, anchor_objects_set, init_bounds))
if rng_state is not None:
torch.set_rng_state(rng_state)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion to abstract this logic into a function. Maybe _generate_initial_positions could be expanded to handle the single and multiple placement cases.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now the new _generate_initial_positions handles both cases.

Comment on lines +112 to +116
all_losses = (
self._solver.last_loss_per_env.cpu().tolist()
if self._solver.last_loss_per_env is not None
else [float("inf")] * num_candidates
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Seems like if we've run the solve, the last_loss_per_env should not be None or am I missing something?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is removed from the updated code.

torch.set_rng_state(rng_state)
all_candidates: list[tuple[float, dict, bool]] = []
for idx in range(num_candidates):
loss = all_losses[idx] if idx < len(all_losses) else float("inf")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this check: idx < len(all_losses). Shouldn't all_losses be defined to have the correct length?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, This extra check is removed.

Comment on lines +128 to +147
n_valid = sum(1 for c in selected if c[2])
if self.params.verbose:
total_valid = sum(1 for c in all_candidates if c[2])
finite_losses = [c[0] for c in all_candidates if math.isfinite(c[0])]
mean_loss = sum(finite_losses) / len(finite_losses) if finite_losses else float("inf")
print(
f"Solved {num_candidates} candidates in one batch: mean loss = {mean_loss:.6f},"
f" {total_valid} valid, selected best {num_results} ({n_valid} valid)"
)

return PlacementResult(
success=success,
positions=best_positions,
final_loss=best_loss,
attempts=attempt + 1,
)
final_per_env: list[dict] = [c[1] for c in selected]
results_per_env = [
PlacementResult(
success=c[2],
positions=c[1],
final_loss=c[0],
attempts=self.params.max_placement_attempts,
)
for c in selected
]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above. This accesses of c make things a bit difficult to read. Consider using a named composite.

Comment on lines +260 to +261
if (id(a), id(b)) in on_pairs:
continue
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add:

# Pairs related by an OnRelation are excluded from the check.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. This comment has been added.

"""
for obj, pos in positions.items():
num_envs = len(positions_per_env)
for obj in positions_per_env[0]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a bit weird because of the [0]. Could be clearer to split into two lines:

# Objects are the same for every environment. Extract them.
objects = [obj for obj in positions_per_env[0]]
# Apply pose for each object.
for obj in objects:
    ...

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code has been updated following your suggestion.

Comment on lines +123 to +129
initial_positions: (
dict[Object | ObjectReference, tuple[float, float, float]]
| list[dict[Object | ObjectReference, tuple[float, float, float]]]
),
) -> (
dict[Object | ObjectReference, tuple[float, float, float]]
| list[dict[Object | ObjectReference, tuple[float, float, float]]]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These types are getting a bit awkward.

One way we can simply is: Object | ObjectReference -> ObjectBase (their parent class).

Do you see any opportunity to just use a list of length 1 for the single env case?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ll type annotations now use ObjectBase.

solve() and RelationSolverState now always use list[dict].

Comment on lines +45 to +46
if isinstance(initial_positions, dict):
initial_positions = [initial_positions]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like maybe we should just use a list no matter what? What do you think?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is removed and it takes list[dict] for all cases.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants