
Commit 219e925

Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel

2 parents: ede33b4 + d9152b0

75 files changed: +7799 additions, -1204 deletions


examples/jax/encoder/run_test_multiprocessing_encoder.sh

Lines changed: 0 additions & 4 deletions

@@ -11,10 +11,6 @@ TEST_CASES=(
     "test_te_current_scaling_fp8"
     "test_te_mxfp8"
     "test_te_nvfp4"
-    "test_te_bf16_shardy"
-    "test_te_delayed_scaling_fp8_shardy"
-    "test_te_current_scaling_fp8_shardy"
-    "test_te_nvfp4_shardy"
 )
 
 : ${TE_PATH:=/opt/transformerengine}

examples/jax/encoder/test_model_parallel_encoder.py

Lines changed: 0 additions & 68 deletions

@@ -239,7 +239,6 @@ def check_fp8(state, var_collect, inputs, masks, labels):
 def train_and_evaluate(args):
     """Execute model training and evaluation loop."""
     print(args)
-    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
 
     train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
 
@@ -474,9 +473,6 @@ def encoder_parser(args):
     parser.add_argument(
         "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
     )
-    parser.add_argument(
-        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
-    )
 
     return parser.parse_args(args)
 
@@ -559,70 +555,6 @@ def test_te_nvfp4_with_sp(self):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.40 and actual[1] > 0.82
 
-    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
-    def test_te_bf16_shardy(self):
-        """Test Transformer Engine with BF16"""
-        self.args.enable_shardy = True
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.36 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_fp8_supported, fp8_reason)
-    def test_te_delayed_scaling_fp8_shardy(self):
-        """Test Transformer Engine with DelayedScaling FP8"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "DelayedScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.362 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_fp8_supported, fp8_reason)
-    def test_te_delayed_scaling_fp8_with_sp_shardy(self):
-        """Test Transformer Engine with DelayedScaling FP8 + SP"""
-        self.args.enable_shardy = True
-        self.args.enable_sp = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "DelayedScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.362 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
-    def test_te_mxfp8_shardy(self):
-        """Test Transformer Engine with MXFP8"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "MXFP8BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.36 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
-    def test_te_nvfp4_shardy(self):
-        """Test Transformer Engine with NVFP4"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "NVFP4BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.40 and actual[1] > 0.82
-
-    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
-    def test_te_mxfp8_with_sp_shardy(self):
-        """Test Transformer Engine with MXFP8 + SP"""
-        self.args.enable_shardy = True
-        self.args.enable_sp = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "MXFP8BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.36 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
-    def test_te_nvfp4_with_sp_shardy(self):
-        """Test Transformer Engine with NVFP4"""
-        self.args.enable_shardy = True
-        self.args.enable_sp = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "NVFP4BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.40 and actual[1] > 0.82
-
 
 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))
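
The deletions in this file (and in the two encoder examples below) all unwind the same opt-in wiring: an experimental --enable-shardy CLI flag threaded from argparse into JAX's global jax_use_shardy_partitioner config before training. A minimal sketch of that now-removed pattern, reconstructed from the deleted lines above; the real example files define many more arguments and a full training loop, both elided here:

# Sketch of the pattern removed by this commit, pieced together from the
# deleted lines; not an exact copy of any one example file.
import argparse

import jax


def encoder_parser(args):
    parser = argparse.ArgumentParser()
    # The experimental opt-in flag each example exposed (deleted above).
    parser.add_argument(
        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
    )
    return parser.parse_args(args)


def train_and_evaluate(args):
    # Each example toggled the Shardy partitioner globally before training.
    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
    ...  # training loop elided

With the flag and the config call removed, the examples now always run under JAX's default partitioner.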

examples/jax/encoder/test_multigpu_encoder.py

Lines changed: 0 additions & 47 deletions

@@ -249,7 +249,6 @@ def replace_params(x):
 def train_and_evaluate(args):
     """Execute model training and evaluation loop."""
     print(args)
-    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
     train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
 
     num_gpu = jax.local_device_count()
@@ -438,9 +437,6 @@ def encoder_parser(args):
         default="DelayedScaling",
         help="Use FP8 recipe (default: DelayedScaling)",
     )
-    parser.add_argument(
-        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
-    )
 
     return parser.parse_args(args)
 
@@ -494,49 +490,6 @@ def test_te_nvfp4(self):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.52 and actual[1] > 0.74
 
-    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
-    def test_te_bf16_shardy(self):
-        """Test Transformer Engine with BF16"""
-        self.args.enable_shardy = True
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.51 and actual[1] > 0.75
-
-    @unittest.skipIf(not is_fp8_supported, fp8_reason)
-    def test_te_delayed_scaling_fp8_shardy(self):
-        """Test Transformer Engine with DelayedScaling FP8"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "DelayedScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.51 and actual[1] > 0.75
-
-    @unittest.skipIf(not is_fp8_supported, fp8_reason)
-    def test_te_current_scaling_fp8_shardy(self):
-        """Test Transformer Engine with CurrentScaling FP8"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "Float8CurrentScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.51 and actual[1] > 0.749
-
-    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
-    def test_te_mxfp8_shardy(self):
-        """Test Transformer Engine with MXFP8"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "MXFP8BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.51 and actual[1] > 0.75
-
-    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
-    def test_te_nvfp4_shardy(self):
-        """Test Transformer Engine with NVFP4"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "NVFP4BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
-
 
 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))

examples/jax/encoder/test_multiprocessing_encoder.py

Lines changed: 1 addition & 44 deletions

@@ -359,7 +359,6 @@ def replace_params(x):
 def train_and_evaluate(args):
     """Execute model training and evaluation loop."""
     print(args)
-    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
     if args.process_id == 0:
         nltk.download("punkt_tab")
 
@@ -605,9 +604,6 @@ def encoder_parser(args):
         default=0,
         help="the ID number of the current process (default: 0)",
     )
-    parser.add_argument(
-        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
-    )
 
     return parser.parse_args(args)
 
@@ -616,7 +612,7 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
 
-    def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
+    def exec(self, use_fp8, fp8_recipe):
         """Run 5 epochs for testing"""
         args = encoder_parser(["--epochs", "5"])
 
@@ -632,7 +628,6 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
         args.num_process = num_gpu
         args.process_id = self.process_id
         args.fp8_recipe = fp8_recipe
-        args.enable_shardy = enable_shardy
 
         return train_and_evaluate(args)
 
@@ -674,44 +669,6 @@ def test_te_nvfp4(self):
         result = self.exec(True, "NVFP4BlockScaling")
         assert result[0] < 0.451 and result[1] > 0.787
 
-    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
-    def test_te_bf16_shardy(self):
-        """Test Transformer Engine with BF16"""
-        result = self.exec(False, None, enable_shardy=True)
-        assert result[0] < 0.43 and result[1] > 0.80
-
-    @unittest.skipIf(
-        not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
-    )
-    def test_te_delayed_scaling_fp8_shardy(self):
-        """Test Transformer Engine with DelayedScaling FP8"""
-        result = self.exec(True, "DelayedScaling", enable_shardy=True)
-        assert result[0] < 0.43 and result[1] > 0.80
-
-    @unittest.skipIf(
-        not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
-    )
-    def test_te_current_scaling_fp8_shardy(self):
-        """Test Transformer Engine with CurrentScaling FP8"""
-        result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
-        assert result[0] < 0.432 and result[1] > 0.80
-
-    @unittest.skipIf(
-        not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
-    )
-    def test_te_mxfp8_shardy(self):
-        """Test Transformer Engine with MXFP8"""
-        result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
-        assert result[0] < 0.43 and result[1] > 0.80
-
-    @unittest.skipIf(
-        not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4"
-    )
-    def test_te_nvfp4_shardy(self):
-        """Test Transformer Engine with NVFP4"""
-        result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
-        assert result[0] < 0.451 and result[1] > 0.787
-
 
 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))
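
With the enable_shardy keyword gone from the exec helper, the remaining multiprocessing tests drive it with just the two positional arguments, as in the kept NVFP4 test above:

# Usage after this commit, taken from the retained tests in this file.
result = self.exec(True, "NVFP4BlockScaling")
assert result[0] < 0.451 and result[1] > 0.787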
