Set dtype default to float32 #4778
Conversation
This reverts commit 9158319.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
For the record, QLoRA with DPO will still force the model dtype to fp32:

```python
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    dtype="float32",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
trainer = DPOTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=LoraConfig(),
)
trainer.train()
```

but in my opinion it's fine; it will be fixed by #3906
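The fp32 forcing mentioned above comes from the k-bit preparation step, which upcasts the remaining floating-point parameters of a quantized model to float32 while leaving the packed 4-bit weights alone. A minimal, dependency-free sketch of that upcasting logic (the `Param` class and `prepare_for_kbit` helper here are illustrative stand-ins, not the real peft API):

```python
from dataclasses import dataclass

@dataclass
class Param:
    name: str
    dtype: str  # e.g. "float16", "bfloat16", "float32", or "uint8" (packed 4-bit storage)

def prepare_for_kbit(params):
    """Mimic (roughly) what k-bit preparation does: upcast non-quantized
    half-precision parameters to float32; quantized storage is untouched."""
    for p in params:
        if p.dtype in ("float16", "bfloat16"):
            p.dtype = "float32"
    return params

params = [
    Param("lm_head.weight", "float16"),   # trainable-adjacent param, gets upcast
    Param("layers.0.weight", "uint8"),    # 4-bit quantized weight, kept as-is
]
prepared = prepare_for_kbit(params)
print([p.dtype for p in prepared])  # → ['float32', 'uint8']
```

This is why the model ends up in fp32 regardless of the `dtype` passed to `from_pretrained` when a quantization config plus a PEFT config is used.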
Thanks for your review, @qgallouedec. Although the CI was green, I see you set the float32 dtype in some tests but not in others. I am wondering what criteria you used.
Set dtype default to float32.
Follow-up to: