@@ -568,6 +568,8 @@ def test_model_from_config_dtype_composite(self):
568568 """
569569 Test that from_pretrained works with dtype being as a dict per each sub-config in composite config
570570 Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
571+ Note: per-sub-config dtype is a deprecated feature and we fall back to the main dtype in all cases below.
572+ This test checks that the dtype fallback works correctly.
571573 """
572574 # Load without dtype specified
573575 model = LlavaForConditionalGeneration .from_pretrained (TINY_LLAVA )
@@ -586,42 +588,42 @@ def test_model_from_config_dtype_composite(self):
586588 self .assertEqual (model .model .vision_tower .dtype , torch .float16 )
587589 self .assertIsInstance (model .config .dtype , torch .dtype )
588590
589- # should be able to set dtype as a dict for each sub-config
591+ # should be able to accept dtype as a dict for each sub-config
590592 model = LlavaForConditionalGeneration .from_pretrained (
591593 TINY_LLAVA , dtype = {"text_config" : "float32" , "vision_config" : "float16" , "" : "bfloat16" }
592594 )
593- self .assertEqual (model .model .language_model .dtype , torch .float32 )
594- self .assertEqual (model .model .vision_tower .dtype , torch .float16 )
595+ self .assertEqual (model .model .language_model .dtype , torch .bfloat16 )
596+ self .assertEqual (model .model .vision_tower .dtype , torch .bfloat16 )
595597 self .assertEqual (model .model .multi_modal_projector .linear_1 .weight .dtype , torch .bfloat16 )
596598 self .assertIsInstance (model .config .dtype , torch .dtype )
597599
598- # should be able to set the values as torch.dtype (not str)
600+ # should be able to accept the values as torch.dtype (not str)
599601 model = LlavaForConditionalGeneration .from_pretrained (
600602 TINY_LLAVA , dtype = {"text_config" : torch .float32 , "vision_config" : torch .float16 , "" : torch .bfloat16 }
601603 )
602- self .assertEqual (model .model .language_model .dtype , torch .float32 )
603- self .assertEqual (model .model .vision_tower .dtype , torch .float16 )
604+ self .assertEqual (model .model .language_model .dtype , torch .bfloat16 )
605+ self .assertEqual (model .model .vision_tower .dtype , torch .bfloat16 )
604606 self .assertEqual (model .model .multi_modal_projector .linear_1 .weight .dtype , torch .bfloat16 )
605607 self .assertIsInstance (model .config .dtype , torch .dtype )
606608
607- # should be able to set the values in configs directly and pass it to `from_pretrained`
609+ # should be able to accept the values in configs directly and pass it to `from_pretrained`
608610 config = copy .deepcopy (model .config )
609611 config .text_config .dtype = torch .float32
610612 config .vision_config .dtype = torch .bfloat16
611613 config .dtype = torch .float16
612614 model = LlavaForConditionalGeneration .from_pretrained (TINY_LLAVA , config = config , dtype = "auto" )
613- self .assertEqual (model .model .language_model .dtype , torch .float32 )
614- self .assertEqual (model .model .vision_tower .dtype , torch .bfloat16 )
615+ self .assertEqual (model .model .language_model .dtype , torch .float16 )
616+ self .assertEqual (model .model .vision_tower .dtype , torch .float16 )
615617 self .assertEqual (model .model .multi_modal_projector .linear_1 .weight .dtype , torch .float16 )
616618 self .assertIsInstance (model .config .dtype , torch .dtype )
617619
618620 # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
619621 LlavaForConditionalGeneration ._keep_in_fp32_modules = ["multi_modal_projector" ]
620622 model = LlavaForConditionalGeneration .from_pretrained (TINY_LLAVA , config = config , dtype = "auto" )
621623 self .assertEqual (
622- model .model .language_model .dtype , torch .float32
624+ model .model .language_model .dtype , torch .float16
623625 ) # config says float32 for text_config, but per-sub-config dtype is deprecated, so it falls back to the main dtype
624- self .assertEqual (model .model .vision_tower .dtype , torch .bfloat16 )
626+ self .assertEqual (model .model .vision_tower .dtype , torch .float16 )
625627 self .assertEqual (model .model .multi_modal_projector .linear_1 .weight .dtype , torch .float32 )
626628 self .assertIsInstance (model .config .dtype , torch .dtype )
627629
0 commit comments