Skip to content

Commit 14ddb43

Browse files
LucienXianpre-commit-ci[bot]ksivaman
authored
Fix meta device check failure when passing torch.device objects (#2519)
* Fix meta device check failure when passing torch.device objects Signed-off-by: LucienXian <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: LucienXian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 442513c commit 14ddb43

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,8 @@ def __init__(
702702
if self.primary_weights_in_fp8:
703703
self.init_fp8_metadata(num_gemms=self.num_gemms)
704704

705-
self.reset_parameters(defer_init=device == "meta")
705+
is_meta = torch.device(device).type == "meta"
706+
self.reset_parameters(defer_init=is_meta)
706707

707708
if self.wgrad_store.delay_wgrad_compute():
708709
for name, param in self.named_parameters():

0 commit comments

Comments
 (0)