Skip to content

Commit 7918b15

Browse files
Deprecate dtype per sub config (#42990)
* deprecate dtypes per sub config
* fix test
* Update tests/utils/test_modeling_utils.py
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
* address comment
---------
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
1 parent acb0f22 commit 7918b15

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

src/transformers/modeling_utils.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,6 @@ def _get_dtype(
759759
2. Else, use the dtype provided as a dict or str
760760
"""
761761
is_sharded = sharded_metadata is not None
762-
asked_dtype = dtype
763762

764763
if dtype is not None:
765764
if isinstance(dtype, str):
@@ -806,24 +805,21 @@ def _get_dtype(
806805
if isinstance(dtype, dict):
807806
main_dtype = dtype.get("", torch.get_default_dtype())
808807
main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
808+
809+
logger.warning_once(
810+
"Using different dtypes per module is deprecated and will be removed in future versions "
811+
"Setting different dtypes per backbone model might cause device errors downstream, therefore "
812+
f"setting the dtype={main_dtype} for all modules."
813+
)
814+
809815
else:
810816
main_dtype = dtype
811817

812818
# Set it on the config and subconfigs
813819
config.dtype = main_dtype
814820
for sub_config_key in config.sub_configs:
815821
if (sub_config := getattr(config, sub_config_key)) is not None:
816-
# The dtype was "auto" -> try to read the subconfig dtype value if any
817-
if asked_dtype == "auto":
818-
sub_dtype = getattr(sub_config, "dtype", main_dtype)
819-
sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
820-
# The dtype was provided as a dict, try to see if we match the subconfig name
821-
elif isinstance(dtype, dict):
822-
sub_dtype = dtype.get(sub_config_key, main_dtype)
823-
sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
824-
else:
825-
sub_dtype = main_dtype
826-
sub_config.dtype = sub_dtype
822+
sub_config.dtype = main_dtype
827823

828824
return config, main_dtype
829825

tests/utils/test_modeling_utils.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ def test_model_from_config_dtype_composite(self):
568568
"""
569569
Test that from_pretrained works with dtype being as a dict per each sub-config in composite config
570570
Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
571+
Note, this is a deprecated feature and we fallback to main dtype in all cases below. This test checks
572+
if the dtype fallback works correctly.
571573
"""
572574
# Load without dtype specified
573575
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
@@ -586,42 +588,42 @@ def test_model_from_config_dtype_composite(self):
586588
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
587589
self.assertIsInstance(model.config.dtype, torch.dtype)
588590

589-
# should be able to set dtype as a dict for each sub-config
591+
# should be able to accept dtype as a dict for each sub-config
590592
model = LlavaForConditionalGeneration.from_pretrained(
591593
TINY_LLAVA, dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
592594
)
593-
self.assertEqual(model.model.language_model.dtype, torch.float32)
594-
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
595+
self.assertEqual(model.model.language_model.dtype, torch.bfloat16)
596+
self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
595597
self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
596598
self.assertIsInstance(model.config.dtype, torch.dtype)
597599

598-
# should be able to set the values as torch.dtype (not str)
600+
# should be able to accept the values as torch.dtype (not str)
599601
model = LlavaForConditionalGeneration.from_pretrained(
600602
TINY_LLAVA, dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
601603
)
602-
self.assertEqual(model.model.language_model.dtype, torch.float32)
603-
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
604+
self.assertEqual(model.model.language_model.dtype, torch.bfloat16)
605+
self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
604606
self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
605607
self.assertIsInstance(model.config.dtype, torch.dtype)
606608

607-
# should be able to set the values in configs directly and pass it to `from_pretrained`
609+
# should be able to accept the values in configs directly and pass it to `from_pretrained`
608610
config = copy.deepcopy(model.config)
609611
config.text_config.dtype = torch.float32
610612
config.vision_config.dtype = torch.bfloat16
611613
config.dtype = torch.float16
612614
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto")
613-
self.assertEqual(model.model.language_model.dtype, torch.float32)
614-
self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
615+
self.assertEqual(model.model.language_model.dtype, torch.float16)
616+
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
615617
self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
616618
self.assertIsInstance(model.config.dtype, torch.dtype)
617619

618620
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
619621
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
620622
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto")
621623
self.assertEqual(
622-
model.model.language_model.dtype, torch.float32
624+
model.model.language_model.dtype, torch.float16
623625
) # remember config says float32 for text_config
624-
self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
626+
self.assertEqual(model.model.vision_tower.dtype, torch.float16)
625627
self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
626628
self.assertIsInstance(model.config.dtype, torch.dtype)
627629

0 commit comments

Comments (0)