Attention

Supported attention modules in Axolotl

SDP Attention

This uses scaled dot-product attention (SDPA), the attention implementation built into PyTorch, and is the default.

sdp_attention: true
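Conceptually, SDPA computes softmax(QKᵀ / √d) · V. A minimal NumPy sketch of that math (illustrative only — the real PyTorch kernel is fused and runs on the GPU):

```python
import numpy as np

def sdp_attention(q, k, v):
    """Scaled dot-product attention: softmax(q @ k.T / sqrt(d)) @ v."""
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)
    # numerically stable softmax over the key dimension
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

q = np.random.randn(4, 8)  # (seq_len, head_dim)
k = np.random.randn(4, 8)
v = np.random.randn(4, 8)
out = sdp_attention(q, k, v)
print(out.shape)  # (4, 8)
```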

For more details: PyTorch docs

Flash Attention

Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically based on your installed packages and GPU.

flash_attention: true

For more details: Flash Attention

Flash Attention 2

Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
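The GPU requirement above can be expressed as a compute-capability check (Turing is SM 7.5, Ampere 8.0/8.6, Ada 8.9, Hopper 9.0). A hypothetical sketch — not Axolotl's actual detection code:

```python
def fa2_supported(compute_capability):
    """Sketch of the FA2 requirement: Ampere (8.0) through Hopper (9.0);
    Turing (7.5) or lower is not supported.
    Hypothetical helper, not Axolotl's actual detection logic."""
    return (8, 0) <= compute_capability <= (9, 0)

# With PyTorch, the capability would come from
# torch.cuda.get_device_capability().
print(fa2_supported((7, 5)))  # False: Turing
print(fa2_supported((8, 6)))  # True: Ampere
print(fa2_supported((9, 0)))  # True: Hopper
```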

pip install flash-attn --no-build-isolation
Tip

If you get an undefined symbol error while training, ensure you installed PyTorch before installing Axolotl. Alternatively, try reinstalling flash-attn or downgrading it by a version.

Flash Attention 3

Requirements: Hopper GPUs only; CUDA 12.8 recommended

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper

python setup.py install

Flash Attention 4

Requirements: Hopper or Blackwell GPUs

pip install flash-attn-4

Or from source:

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute

pip install -e .

# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
Note

Hopper (SM90) users: The backward kernel is not yet included in the pip package. To use FA4 for training on Hopper, install from source using the instructions above.

Warning

FA4 only supports head dimensions up to 128 (d ≤ 128). The DeepSeek shape (192, 128) is also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions and falls back to FA2/3.
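The head-dimension rule above can be sketched as a small check. A hypothetical helper mirroring the documented limits — not Axolotl's actual fallback code:

```python
def fa4_supports(head_dim_qk, head_dim_v, is_blackwell):
    """Documented FA4 head-dim limits: d <= 128, plus the DeepSeek
    shape (192, 128) on Blackwell only. Hypothetical helper."""
    if head_dim_qk <= 128 and head_dim_v <= 128:
        return True
    # the DeepSeek shape is supported, but only on Blackwell
    if (head_dim_qk, head_dim_v) == (192, 128) and is_blackwell:
        return True
    return False  # Axolotl would fall back to FA2/3 here

print(fa4_supports(128, 128, is_blackwell=False))  # True
print(fa4_supports(192, 128, is_blackwell=False))  # False
print(fa4_supports(192, 128, is_blackwell=True))   # True
```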

For more details: flash-attention/flash_attn/cute

AMD

Requirements: ROCm 6.0 and above.

See Flash Attention AMD docs.

Flex Attention

A flexible PyTorch API for defining attention variants, used in combination with torch.compile.

flex_attention: true

# recommended
torch_compile: true
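The core idea of FlexAttention is a user-supplied score_mod function applied to each QK score before softmax. A NumPy sketch of that idea (illustrative only; the real API lives in torch.nn.attention and compiles to a fused kernel):

```python
import numpy as np

def flex_attention(q, k, v, score_mod):
    """Illustrative NumPy version of the FlexAttention idea: run a
    user-defined score_mod over every QK score, then softmax."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)
    for q_idx in range(scores.shape[0]):
        for kv_idx in range(scores.shape[1]):
            scores[q_idx, kv_idx] = score_mod(scores[q_idx, kv_idx], q_idx, kv_idx)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

# causal masking expressed as a score_mod
def causal(score, q_idx, kv_idx):
    return score if kv_idx <= q_idx else -np.inf

q = k = v = np.random.randn(4, 8)
out = flex_attention(q, k, v, causal)
print(out.shape)  # (4, 8)
```

Because the first query may only attend to the first key, its output equals the first value row — a quick sanity check on the causal mask.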
Note

We recommend using the latest stable version of PyTorch for best performance.

For more details: PyTorch docs

SageAttention

Attention kernels that use INT8 quantization for QK and an FP16 accumulator for PV.

sage_attention: true

Requirements: Ampere, Ada, or Hopper GPUs

pip install sageattention==2.2.0 --no-build-isolation
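To give a feel for the QK INT8 part: symmetric per-tensor INT8 quantization keeps the QK scores close to their FP32 values while allowing integer matmuls. A NumPy sketch (illustrative only, not SageAttention's actual kernel):

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-tensor INT8 quantization (illustrative)."""
    scale = np.abs(x).max() / 127.0
    return np.clip(np.round(x / scale), -127, 127).astype(np.int8), scale

q = np.random.randn(4, 8).astype(np.float32)
k = np.random.randn(4, 8).astype(np.float32)

q8, q_scale = quantize_int8(q)
k8, k_scale = quantize_int8(k)

# integer matmul, then dequantize with the product of the two scales
scores_int8 = (q8.astype(np.int32) @ k8.astype(np.int32).T) * (q_scale * k_scale)
scores_fp32 = q @ k.T

max_err = np.abs(scores_int8 - scores_fp32).max()
```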
Warning

Only LoRA/QLoRA is recommended at the moment. We found the loss drops to 0 with full fine-tuning. See GitHub Issue.

For more details: Sage Attention

Note

We do not support SageAttention 3 at the moment. If you are interested in adding it or improving the SageAttention integration, please open an Issue.

xFormers

xformers_attention: true
Tip

We recommend using xFormers with Turing GPUs or older (such as on Colab).

For more details: xFormers

Shifted Sparse Attention

Warning

We plan to deprecate this! If you use this feature, we recommend switching to one of the methods above.

Requirements: LLaMA model architecture

flash_attention: true
s2_attention: true
Tip

Sample packing is not supported!