Commit 0906e63

Merge branch 'main' into pingtian/add_linear_wgrad_compute_param_api
2 parents fbae781 + 59f6f38 commit 0906e63

86 files changed, +8108 -100 lines changed

README.rst

Lines changed: 10 additions & 3 deletions
@@ -175,15 +175,22 @@ For example to use the NGC PyTorch container interactively,
 
 .. code-block:: bash
 
-    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3
+    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:26.01-py3
 
 For example to use the NGC JAX container interactively,
 
 .. code-block:: bash
 
-    docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.08-py3
+    docker run --gpus all -it --rm nvcr.io/nvidia/jax:26.01-py3
 
-Where 25.08 (corresponding to August 2025 release) is the container version.
+Where 26.01 (corresponding to January 2026 release) is the container version.
+
+We recommend updating to the latest NGC container available here:
+
+* https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
+* https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax
+
+If you run any examples, please ensure you are using a matching version of TransformerEngine. TransformerEngine is pre-built and packaged inside the containers with examples available at ``/opt/transformerengine`` or ``/opt/transformer-engine``. If you would like to use examples from TE main branch and are running into import errors, please try the latest pip package or building from source, although NGC containers are recommended for ease-of-use for most users.
 
 **Benefits of using NGC containers:**
 
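
Note on this hunk: the added README text asks readers to run examples against a matching TransformerEngine version. A minimal sketch of how one might check what is installed before running them, assuming the distribution is named `transformer_engine` (an assumption, not confirmed by this diff; the helper name `installed_version` is hypothetical):

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str = "transformer_engine") -> Optional[str]:
    """Return the installed version of dist_name, or None if it is not installed.

    The default distribution name is an assumption for illustration.
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version()
    if version is None:
        print("TransformerEngine does not appear to be installed.")
    else:
        print(f"Installed TransformerEngine version: {version}")
```

Comparing the printed version against the container release notes is left to the reader; this sketch only reports what the current environment provides.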

build_tools/utils.py

Lines changed: 6 additions & 2 deletions
@@ -228,9 +228,10 @@ def nvcc_path() -> Tuple[str, str]:
 def get_cuda_include_dirs() -> Tuple[str, str]:
     """Returns the CUDA header directory."""
 
+    force_wheels = bool(int(os.getenv("NVTE_BUILD_USE_NVIDIA_WHEELS", "0")))
     # If cuda is installed via toolkit, all necessary headers
     # are bundled inside the top level cuda directory.
-    if cuda_toolkit_include_path() is not None:
+    if not force_wheels and cuda_toolkit_include_path() is not None:
         return [cuda_toolkit_include_path()]
 
     # Use pip wheels to include all headers.
@@ -239,7 +240,10 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
     except ModuleNotFoundError as e:
         raise RuntimeError("CUDA not found.")
 
-    cuda_root = Path(nvidia.__file__).parent
+    if nvidia.__file__ is not None:
+        cuda_root = Path(nvidia.__file__).parent
+    else:
+        cuda_root = Path(nvidia.__path__[0]) # namespace
     return [
         subdir / "include"
         for subdir in cuda_root.iterdir()
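
Note on this hunk: the change does two things. It lets `NVTE_BUILD_USE_NVIDIA_WHEELS` skip the toolkit headers, and it handles the case where `nvidia` is a namespace package, whose `__file__` is `None` because there is no `__init__.py`. The sketch below illustrates both ideas in isolation. It is not the project's code: `force_wheels_requested`, `package_root`, and `make_demo_namespace_package` are hypothetical helper names, and the demo package is a throwaway directory created only for illustration.

```python
import importlib
import os
import sys
import tempfile
from pathlib import Path
from types import ModuleType


def force_wheels_requested() -> bool:
    """Parse the flag as the diff does: unset or "0" is False, other integers are True."""
    return bool(int(os.getenv("NVTE_BUILD_USE_NVIDIA_WHEELS", "0")))


def package_root(module: ModuleType) -> Path:
    """Return the directory of a package, including namespace packages."""
    if getattr(module, "__file__", None) is not None:
        return Path(module.__file__).parent
    # A namespace package has no __init__.py, so __file__ is None;
    # __path__ still lists the directories that make up the package.
    return Path(list(module.__path__)[0])


def make_demo_namespace_package(name: str = "demo_ns_pkg") -> ModuleType:
    """Create and import a throwaway namespace package (hypothetical name)."""
    base = Path(tempfile.mkdtemp())
    # No __init__.py is written, which is what makes this a namespace package.
    (base / name / "sub" / "include").mkdir(parents=True)
    sys.path.insert(0, str(base))
    importlib.invalidate_caches()
    return importlib.import_module(name)


if __name__ == "__main__":
    demo = make_demo_namespace_package()
    root = package_root(demo)
    # Same comprehension shape as the hunk's return value.
    print([subdir / "include" for subdir in root.iterdir()])
```

Note that `bool(int(...))` raises `ValueError` for non-integer values such as `"true"`, so the flag is expected to be set to an integer like `1`.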
Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+/* Diagram color definitions for Transformer Engine documentation */
+
+/* High precision (BF16/FP16) elements */
+.hp {
+    fill: #ede7f6;
+    stroke: #673ab7;
+    stroke-width: 2;
+}
+
+/* FP8 precision elements */
+.fp8 {
+    fill: #fff8e1;
+    stroke: #ffa726;
+    stroke-width: 2;
+}
+
+/* GEMM/computation operations */
+.gemm {
+    fill: #ffe0b2;
+    stroke: #fb8c00;
+    stroke-width: 2.5;
+}
+
+/* Quantization operations */
+.quantize {
+    fill: #e8f5e9;
+    stroke: #66bb6a;
+    stroke-width: 2;
+}
+
+/* Amax computation operations */
+.amax {
+    fill: #e1f5fe;
+    stroke: #039be5;
+    stroke-width: 2;
+}
+
+/* Text styles */
+.text {
+    font-family: 'Segoe UI', Arial, sans-serif;
+    font-size: 14px;
+    text-anchor: middle;
+    fill: #212121;
+}
+
+.small-text {
+    font-family: 'Segoe UI', Arial, sans-serif;
+    font-size: 14px;
+    text-anchor: middle;
+    fill: #757575;
+}
+
+.label {
+    font-family: 'Segoe UI', Arial, sans-serif;
+    font-size: 14px;
+    text-anchor: middle;
+    fill: #424242;
+}
+
+.title {
+    font-family: 'Segoe UI', Arial, sans-serif;
+    font-size: 18px;
+    font-weight: 600;
+    text-anchor: middle;
+    fill: #212121;
+}
+
+.section-title {
+    font-family: 'Segoe UI', Arial, sans-serif;
+    font-size: 15px;
+    font-weight: 600;
+    text-anchor: middle;
+}
+
+/* Arrows */
+/* Note: marker-end references #arrowhead marker which must be defined in each SVG's <defs> section */
+.arrow {
+    stroke: #616161;
+    stroke-width: 2;
+    fill: none;
+    marker-end: url(#arrowhead);
+}
+
+/* Additional box and element styles */
+.box-blue {
+    fill: #e3f2fd;
+    stroke: #1976d2;
+    stroke-width: 2;
+}
+
+.box-orange {
+    fill: #fff3e0;
+    stroke: #f57c00;
+    stroke-width: 2;
+}
+
+.box-green {
+    fill: #c8e6c9;
+    stroke: #388e3c;
+    stroke-width: 2;
+}
+
+.box-dashed {
+    stroke-dasharray: 5,5;
+}
+
+/* LayerNorm specific */
+.layernorm {
+    fill: #b3e5fc;
+    stroke: #0277bd;
+    stroke-width: 2.5;
+}
+
+/* Fused layers */
+.fused {
+    fill: #b2dfdb;
+    stroke: #00695c;
+    stroke-width: 3;
+}
+
+/* Generic computation blocks */
+.computation {
+    fill: #f5f5f5;
+    stroke: #757575;
+    stroke-width: 2;
+}
+
+/* FP32 precision (alternative red) */
+.fp32 {
+    fill: #ffcdd2;
+    stroke: #d32f2f;
+    stroke-width: 2.5;
+}
+

docs/_static/css/sphinx_tabs.css

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+/* Custom styling for sphinx-tabs */
+
+.sphinx-tabs {
+    margin-bottom: 1rem;
+}
+
+.sphinx-tabs-tab {
+    background-color: #f4f4f4;
+    border: 1px solid #ccc;
+    border-bottom: none;
+    padding: 0.5rem 1rem;
+    margin-right: 0.5rem;
+    cursor: pointer;
+    font-weight: 500;
+    transition: background-color 0.2s;
+}
+
+.sphinx-tabs-tab:hover {
+    background-color: #e0e0e0;
+}
+
+.sphinx-tabs-tab[aria-selected="true"] {
+    background-color: #76b900; /* NVIDIA green */
+    color: white;
+    border-color: #76b900;
+    margin-right: 0.5rem;
+}
+
+.sphinx-tabs-panel {
+    border: 1px solid #ccc;
+    padding: 1rem;
+    background-color: #f9f9f9;
+}
+
+/* Dark mode support for RTD theme */
+.rst-content .sphinx-tabs-tab {
+    color: #333;
+}
+
+.rst-content .sphinx-tabs-tab[aria-selected="true"] {
+    color: white;
+}
+
+
+
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+/* Responsive styling for SVG images */
+
+/* Make all SVG images responsive */
+.document svg,
+.document object[type="image/svg+xml"],
+.rst-content svg {
+    max-width: 100%;
+    height: auto;
+    display: block;
+    margin: 1em auto;
+}
+
+/* For raw HTML embedded SVGs */
+.document .raw-html svg {
+    max-width: 100%;
+    height: auto;
+    width: 100%;
+}
+
+/* Ensure container doesn't overflow */
+.document .raw-html {
+    max-width: 100%;
+    overflow-x: auto;
+}
+
+/* Figure containers with captions */
+.svg-figure {
+    text-align: center;
+    margin: 20px auto;
+}
+
+.svg-figure img {
+    display: block;
+    margin: 0 auto;
+    height: auto;
+}
+
+/* Different width classes for figures */
+.svg-figure.width-70 img {
+    width: 70%;
+    max-width: 100%;
+}
+
+.svg-figure.width-80 img {
+    width: 80%;
+    max-width: 100%;
+}
+
+.svg-figure.width-90 img {
+    width: 90%;
+    max-width: 100%;
+}
+
+.svg-figure.width-100 img {
+    width: 100%;
+}
+
+/* Figure captions */
+.svg-caption {
+    font-style: italic;
+    margin-top: 10px;
+    color: #555;
+    font-size: 0.95em;
+    line-height: 1.4;
+}
+
+
+
+
+
+
+

docs/_templates/layout.html

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,10 @@
     overflow: visible !important;
 }
 
+.quant {
+    background-color: yellow !important;
+}
+
 </style>
 <style>
 a:link, a:visited {

docs/conf.py

Lines changed: 4 additions & 1 deletion
@@ -84,8 +84,11 @@
 html_css_files = [
     "css/nvidia_font.css",
     "css/nvidia_footer.css",
-    "css/rtabs.css",
     "css/output-style.css",
+    "css/diagram-colors.css",
+    "css/sphinx_tabs.css",
+    "css/svg-responsive.css",
+    "css/rtabs.css",
 ]
 
 html_theme_options = {
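
Note on this hunk: three stylesheets are added to `html_css_files`, and `css/rtabs.css` moves to the end of the list. Entries in `html_css_files` are resolved relative to the Sphinx static path, so a missing file is easy to overlook. A minimal sketch of a sanity check one could run locally, assuming the static directory is `docs/_static` (as the `docs/_static/css/sphinx_tabs.css` path in this commit suggests); `missing_css_files` and `demo` are hypothetical helpers, not part of the repository:

```python
import tempfile
from pathlib import Path
from typing import Iterable, List


def missing_css_files(css_files: Iterable[str], static_dir: Path) -> List[str]:
    """Return the entries of css_files that do not exist under static_dir."""
    return [name for name in css_files if not (static_dir / name).is_file()]


def demo() -> List[str]:
    """Build a throwaway static directory with one stylesheet deliberately absent."""
    static_dir = Path(tempfile.mkdtemp())
    (static_dir / "css").mkdir()
    for present in ("diagram-colors.css", "sphinx_tabs.css", "rtabs.css"):
        (static_dir / "css" / present).write_text("/* placeholder */\n")
    css_files = [
        "css/diagram-colors.css",
        "css/sphinx_tabs.css",
        "css/svg-responsive.css",  # intentionally not created above
        "css/rtabs.css",
    ]
    return missing_css_files(css_files, static_dir)


if __name__ == "__main__":
    print(demo())
```

In a real checkout the call would look like `missing_css_files(html_css_files, Path("docs/_static"))`, with `html_css_files` taken from `docs/conf.py`.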

docs/debug/1_getting_started.rst

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
 - log the statistics for each of the tensors in every matrix multiply (GEMM) operation,
 - run selected GEMMs in higher precision,
 - run current scaling - with one scaling factor per tensor - for particular GEMMs,
-- test new precisions and integrate them with FP8 training,
+- test new precisions and integrate them with quantized training (FP8, NVFP4, etc.),
 - ... and many more.
 
 There are 4 things one needs to do to use Transformer Engine debug features:

docs/debug/3_api_features.rst

Lines changed: 5 additions & 2 deletions
@@ -8,7 +8,10 @@ Debug features
 
 .. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
 .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
-.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
-.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
+.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats
+.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM
+.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer
 .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
 .. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
+.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
+.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
