The function `train_step` corresponds to a forward and backward pass through a 3-layer NequIP model, implemented with `e3nn-jax`, acting on a simple Tetris dataset. Thanks to @ameya98 and @mariogeiger for the code!
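For orientation, a `train_step` of this shape can be sketched in plain JAX. The model below is a tiny MLP stand-in, not the repository's NequIP implementation; all names here (`mlp_apply`, `loss_fn`, the layer sizes) are illustrative:

```python
import jax
import jax.numpy as jnp

def mlp_apply(params, x):
    # Toy 3-layer MLP standing in for the 3-layer NequIP model.
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, batch):
    preds = mlp_apply(params, batch["x"])
    return jnp.mean((preds - batch["y"]) ** 2)

@jax.jit
def train_step(params, batch, lr=1e-2):
    # One forward + backward pass, followed by a plain SGD update.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

key = jax.random.PRNGKey(0)
sizes = [4, 32, 32, 8]  # small, Tetris-scale shapes
params = []
for m, n in zip(sizes[:-1], sizes[1:]):
    k1, key = jax.random.split(key)
    params.append((jax.random.normal(k1, (m, n)) * 0.1, jnp.zeros(n)))

batch = {"x": jnp.ones((16, 4)), "y": jnp.zeros((16, 8))}
params, loss = train_step(params, batch)
print(float(loss))
```

Profiling this jitted function is what produces the kernel timeline discussed below.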
Here's a brief summary of the under-the-hood story:

- XLA is unable to pattern-match or generate a small set of fused kernels for the computation (see arXiv:2301.13062 for how XLA works). Instead, it is left with ~300 kernels (half of which are cuBLAS/CUTLASS calls) that it needs to execute at runtime (the small chunks below `Thunk:#hlo_op` in the `TSL` row).
- This makes the compiler fall back to CUDA Graphs, which batches the execution of these kernels. However, the execution graph needs to be updated with new inputs at runtime (~30% runtime overhead before `Graph 7` is launched on the GPU). This overhead (notice the `CUDA API` row) increases with the size of the computation graph.
- Ideally, the compiler (or a human) should give us one fused forward kernel and one fused backward kernel for this computation (see FlashAttention).
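To see how many fused kernels XLA actually emits for a jitted function, you can dump the optimized HLO through JAX's ahead-of-time (AOT) lowering API. The toy function below is illustrative, not the NequIP model; each `fusion` computation in the printed module corresponds to one fused kernel (on GPU, the matmuls typically remain separate cuBLAS calls):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x, w1, w2):
    # Two matmuls with an elementwise nonlinearity in between.
    return jnp.tanh(x @ w1) @ w2

x = jnp.ones((8, 16))
w1 = jnp.ones((16, 32))
w2 = jnp.ones((32, 4))

# Lower and compile ahead of time, then inspect the optimized HLO text.
hlo = f.lower(x, w1, w2).compile().as_text()
print(hlo.count("fusion"))  # count is backend-dependent
```

Running this on the real `train_step` makes the ~300-kernel problem visible without a profiler.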
To reproduce the profile shown above, install NVIDIA Nsight Systems, run `pip install -r requirements.txt`, and then run the following command (borrowed from JAX-Toolbox):

```
nsys profile --capture-range=cudaProfilerApi --cuda-graph-trace=node --capture-range-end=stop -o nequip_profile_disable_cudagraph -f true python train.py
```

- Add an MLP equivalent to show what non-CUDAGraph fusion should look like
- More profiling:
  - Add `TensorProduct`, `TensorProductLinear` and `TensorProductLinearGate`
  - Add Allegro-JAX and MACE-JAX
