Description
Version
2.3.0
On which installation method(s) does this occur?
Source
Describe the issue
Hi PhysicsNeMo team! I found an issue while running the Turbulent Channel example with CUDA graphs enabled - in brief, the custom aggregator's loss weights get stuck with whatever weights were in effect when the gradient step was recorded into the CUDA graph.
The issue is caused by Trainer._cuda_graph_training_step passing step as an int to the self.compute_gradients function. When the graph is captured after graph warmup (e.g. at step 20), the graph is always replayed with the step value it was captured with, so the entire gradient calculation is always done as if it were step 20. This affects both Aggregators and loss functions. The solution I have is to pass step as a Tensor instead and update it in-place each step.
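A minimal sketch of the idea (not the actual Trainer code; the step-dependent weighting and names are made up for illustration): because the captured graph re-reads the step tensor's GPU buffer on every replay, an in-place fill_() before replay is enough to change the weighting, whereas a Python int would be baked in at capture time.

```python
import torch

def compute_weighted_loss(losses: torch.Tensor, step: torch.Tensor) -> torch.Tensor:
    # Hypothetical step-dependent weighting. With a Python int this value would
    # be frozen at capture time; with a tensor, each replay reads the updated buffer.
    weight = torch.exp(-step.float() / 100.0)
    return (weight * losses).sum()

device = "cuda"
losses = torch.ones(3, device=device)
step = torch.zeros((), dtype=torch.int64, device=device)  # lives on GPU, updated in-place

# Warmup on a side stream before capture (standard CUDA graph warmup pattern).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        compute_weighted_loss(losses, step)
torch.cuda.current_stream().wait_stream(s)

# Capture the loss computation once.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_loss = compute_weighted_loss(losses, step)

# Replay: updating `step` in-place changes the weighting seen inside the graph.
for i in range(5):
    step.fill_(i)
    graph.replay()
```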
CUDA graphs also can't handle dynamic control flow, which is used in a few aggregators (GradNorm, LRAnnealing, etc.). These fail when run with CUDA graphs; replacing the ifs with torch.where should work - see the sketch below.
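For example (a hypothetical clamp, not the actual GradNorm/LRAnnealing code), a data-dependent if forces a GPU-to-CPU sync that graph capture can't handle, while torch.where expresses the same branch as a tensor op so the same kernels run on every replay:

```python
import torch

def clamp_weight(loss_ratio: torch.Tensor, max_weight: float = 100.0) -> torch.Tensor:
    # Graph-unsafe version (data-dependent Python branch):
    # if loss_ratio > max_weight:
    #     return torch.full_like(loss_ratio, max_weight)
    # return loss_ratio

    # Graph-safe version: the branch becomes a tensor op, so no host-side
    # decision depends on the tensor's value during capture or replay.
    return torch.where(
        loss_ratio > max_weight,
        torch.full_like(loss_ratio, max_weight),
        loss_ratio,
    )
```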
I've made some edits to my local version already, and if the proposals above sound good, I'd like to contribute those changes.
Thanks!
Minimum reproducible example
Relevant log output
Environment details
Other/Misc.
No response