Description
I am working with the Qwen/Qwen2.5-7B-Instruct model and trying to enable flash_attention_2 for better performance. However, when I run my code, I encounter the following warnings:
Warning 1: Dtype Mismatch
```
(WorkerDict pid=95771) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the with torch.autocast(device_type='torch_device'): decorator, or load the model with the torch_dtype argument. Example: model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)
```
Warning 2: Device Initialization
```
(WorkerDict pid=95523) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').
```
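For context, I believe the loading pattern that triggers these warnings boils down to something like the sketch below. My real code goes through the worker setup, so this is only a simplified approximation: no torch_dtype is passed, so the weights are materialized in float32 on CPU even though flash_attention_2 is requested.

```python
from transformers import AutoModelForCausalLM

# Simplified approximation of the pattern the warnings seem to describe:
# no torch_dtype -> model is built in float32, and it is initialized on CPU.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    attn_implementation="flash_attention_2",
)
```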
What is the practical impact of these warnings? My main concern is: do they mean that Flash Attention 2.0 is being silently disabled and the model is falling back to the slower, standard attention implementation? That would explain why performance is lower than I expected.
Could you provide a definitive code snippet for loading the Qwen2.5-7B-Instruct model that ensures Flash Attention 2.0 is correctly and optimally utilized, avoiding these warnings?
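In case it helps to have a concrete starting point, here is my current best guess at a corrected loader, based purely on the suggestions in the warning messages. The choice of torch.bfloat16 and device_map="cuda" are my assumptions rather than something I have verified with this repo:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-7B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,              # FA2 requires fp16/bf16; bf16 is my assumption for a 7B model
    attn_implementation="flash_attention_2",
    device_map="cuda",                       # initialize directly on GPU (requires accelerate) to avoid the CPU-init warning
)
model.eval()
```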
Thank you for your time and assistance.