
get-flash-attn

A script that parses your Python environment and installs the correct flash-attention wheel.

Tri Dao et al. created flash-attention, which is released under a BSD-3 license. This script retrieves and installs pre-built flash-attention wheels.

How to

This script works on systems with NVIDIA GPUs. AMD GPUs and CPUs are not supported, as pre-built wheels for them are not available on the flash-attn GitHub page.

# setup
git clone git@github.com:cathalobrien/get-flash-attn.git
cd get-flash-attn

# install flash attn
source /path/to/venv/bin/activate #activate your venv/conda/uv env
./get-flash-attn
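
To confirm the install succeeded, you can import the package and print its version (a minimal sanity check; flash_attn is the import name used by the flash-attention project):

# optional: verify that flash-attn imports correctly
python -c "import flash_attn; print(flash_attn.__version__)"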

Command line args

The following command line args are supported:

-s|--source $source # Which provider to download the wheels from. Possible values are "all", "tridao" and "naco". "all" dynamically selects a provider based on the versions available
-v|--flash-attn-version $version # Which version of flash-attention to install. Defaults to '2.7.4.post1'
-l|--list # Lists the wheels available
--offline # Prints instructions to install on an air-gapped system e.g. MN5
--dryrun # Dry run; prints commands instead of running them
--get-wheel # Downloads the wheel to your cwd and quits
--uv # Installs into a uv env
--force-reinstall # Forces pip to reinstall flash-attn
--verbose # Verbose mode; all commands are printed before execution, and wget and pip are not silenced

For legacy reasons, the env vars 'UV', 'OFFLINE' and 'DRYRUN' can be set to '1' to enable their corresponding flags.
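
For example, a dry run that pins a specific source and version might look like the following (the values shown are illustrative, not defaults):

./get-flash-attn -s tridao -v 2.8.3 --dryrun
#DRYRUN=1 ./get-flash-attn -s tridao -v 2.8.3 # legacy env var form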

UV

To install into a uv env, run the script as shown:

./get-flash-attn --uv
#UV=1 ./get-flash-attn #legacy

Offline installs

The script works on systems without internet access. It will automatically detect when the internet is not available and print the URL of the correct wheel for your system. You can then follow the example below to install it on your system.

#offline demo

# On the system without internet, run:
./get-flash-attn
# prints 'wget https://github.com/Dao-AILab/flash-attention/releases/download/...whl'

# on a system with internet, run:
wget https://github.com/Dao-AILab/flash-attention/releases/download/...whl
scp ...whl system_without_internet:

# On the system without internet, run:
pip install ...whl

Building wheels

The repo includes Slurm scripts to build x86 and aarch64 wheels.
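
The exact script names and paths are not documented here, so the following is only a sketch of how such a build job might be submitted on a Slurm cluster (the filename is hypothetical):

# submit a wheel build job (script name is hypothetical)
sbatch etc/build_wheel_aarch64.slurm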

Displaying available wheels

To plot the wheels available from a given source, you can use:

./get-flash-attn -s naco -l | etc/show_wheel_availability.awk
Python 3.11 - linux_aarch64:
Flash\Torch        2.6   2.7   2.8
-----------------------------------
2.7.4               0     0     0
2.7.4.post1         0     0     0
2.8.3               0     0     0

Python 3.11 - linux_x86_64:
Flash\Torch        2.6   2.7   2.8
-----------------------------------
2.7.4               1     0     0
2.7.4.post1         0     1     1
2.8.3               1     1     1

Python 3.12 - linux_aarch64:
Flash\Torch        2.6   2.7   2.8
-----------------------------------
2.7.4               1     1     1
2.7.4.post1         1     1     0
2.8.3               1     1     1

Python 3.12 - linux_x86_64:
Flash\Torch        2.6   2.7   2.8
-----------------------------------
2.7.4               0     0     0
2.7.4.post1         0     0     0
2.8.3               0     0     0

TODO

  • modify built wheel names to include cuda, torch and ABI information at build time
  • add an option to build with ABI true/false
  • add wheel coverage matrix
  • build ABI true
  • store cuda minor version in naco wheels
