
πŸ›[BUG]: Loss Functions and Aggregators break with CUDA graphsΒ #279

@jasooney23

Description

Version

2.3.0

On which installation method(s) does this occur?

Source

Describe the issue

Hi PhysicsNeMo team! I found an issue while running the Turbulent Channel example with CUDA graphs enabled. In brief, the custom aggregator's loss weights get stuck at whatever values were in use when the gradient step was recorded into the CUDA graph.

The issue is caused by Trainer._cuda_graph_training_step passing step as a Python int to the self.compute_gradients function. Once the graph is captured after warmup (e.g. at step 20), every replay reuses the step value that was baked in at capture time, so the entire gradient calculation is always done as if it were step 20. This affects both aggregators and loss functions. The solution I have is to pass step as a Tensor instead and update it in-place each step, as in the sketch below.
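A minimal sketch of what I mean, following the standard torch.cuda.graph capture pattern from the PyTorch docs; step_weighted_loss here is a hypothetical stand-in for the aggregator/loss, not PhysicsNeMo code:

```python
import torch

def step_weighted_loss(step: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical step-dependent weighting, standing in for an aggregator.
    weight = 0.5 + 0.5 * torch.tanh(step / 100.0)
    return (weight * x * x).sum()

device = torch.device("cuda")
x = torch.randn(16, device=device, requires_grad=True)
step = torch.zeros((), device=device)  # step lives on the GPU as a tensor

# Warm up on a side stream, as recommended before graph capture.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        step_weighted_loss(step, x).backward()
torch.cuda.current_stream().wait_stream(side)
x.grad = None  # let capture allocate the static gradient buffer

# Capture one forward/backward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_loss = step_weighted_loss(step, x)
    static_loss.backward()

# Replay: updating `step` in place changes the weighting on every replay.
# A Python int argument would have stayed at whatever value was captured.
for i in range(100):
    step.fill_(float(i))
    graph.replay()
```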

CUDA graphs also can't handle dynamic control flow, which a few aggregators (GradNorm, LRAnnealing, etc.) rely on. These fail when run with CUDA graphs; replacing the Python-level ifs with torch.where should work.
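Roughly what I have in mind (the weight-update logic below is illustrative, not the actual GradNorm/LRAnnealing code):

```python
import torch

def update_weight_with_if(weight, grad_norm, mean_norm):
    # Data-dependent Python branch: reading the tensor's truth value forces a
    # GPU->CPU sync, which is not allowed during CUDA graph capture.
    if grad_norm > mean_norm:
        return weight * 0.9
    return weight * 1.1

def update_weight_graph_safe(weight, grad_norm, mean_norm):
    # Both branches are computed and torch.where selects between them on the
    # GPU, so there is no host-side control flow and capture can proceed.
    return torch.where(grad_norm > mean_norm, weight * 0.9, weight * 1.1)
```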

I've made some edits to my local version already, and if the proposals above sound good I'd like to contribute those changes 😃

Thanks!

Minimum reproducible example

Relevant log output

Environment details

Other/Misc.

No response

Labels

? - Needs Triage, bug
