# FlashAttention
This repository provides the official implementation of FlashAttention and
FlashAttention-2 from the
following papers.
**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
![FlashAttention](assets/flashattn_banner.jpg)
**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
Tri Dao
Paper: https://tridao.me/publications/flash2/flash2.pdf
![FlashAttention-2](assets/flashattention_logo.png)
## Usage
We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.
## Installation and features
Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)), but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via GitHub issue.
We recommend the
[PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
container from Nvidia, which has all the required tools to install FlashAttention.
To install:
1. Make sure that PyTorch is installed.
2. Make sure that `packaging` is installed (`pip install packaging`)
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
compiling can take a very long time (2h) since it does not use multiple CPU
cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
4. Then:
```sh
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```sh
python setup.py install
```
If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
run too many parallel compilation jobs and exhaust the available RAM. To
limit the number of parallel compilation jobs, you can set the environment
variable `MAX_JOBS`:
```sh
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```
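Once the install finishes, a quick import is enough to confirm that the package and its compiled CUDA extension load correctly. A minimal sketch (the printed version string depends on the release you installed):
```python
# Post-install smoke test: importing flash_attn also loads the compiled CUDA extension,
# so a successful import and version print indicates the build is usable.
import flash_attn

print(flash_attn.__version__)
```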
Interface: `src/flash_attention_interface.py`
FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
GPUs for now.
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. (A rough runtime check is sketched below.)
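These constraints can be checked at runtime before calling into FlashAttention. Below is a minimal sketch; the `supports_flash_attn` helper is hypothetical and not part of the library:
```python
import torch

def supports_flash_attn(device: torch.device, dtype: torch.dtype, headdim: int) -> bool:
    """Rough pre-flight check mirroring the support matrix above (hypothetical helper)."""
    major, minor = torch.cuda.get_device_capability(device)
    # Ampere (sm80), Ada (sm89), and Hopper (sm90) are compute capability 8.0 or newer.
    if (major, minor) < (8, 0):
        return False
    # Only fp16 and bf16 are supported (bf16 already implies Ampere or newer here).
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    # Head dim > 192 additionally needs A100/A800 or H100/H800 for the backward pass,
    # which compute capability alone does not capture.
    return headdim <= 256

print(supports_flash_attn(torch.device("cuda"), torch.bfloat16, headdim=128))
```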
## How to use FlashAttention
The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```python
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```
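For reference, the computation these functions perform is equivalent to the following slow, memory-hungry pure-PyTorch attention. This sketch ignores dropout, causal masking, sliding windows, and ALiBi:
```python
import math
import torch

def attention_reference(q, k, v, softmax_scale=None):
    # q, k, v: (batch_size, seqlen, nheads, headdim)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])  # default scaling, as in the docstrings below
    # Attention scores: (batch_size, nheads, seqlen_q, seqlen_k)
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    attn = torch.softmax(scores.float(), dim=-1).to(v.dtype)
    # Weighted sum of values, back to (batch_size, seqlen, nheads, headdim)
    return torch.einsum("bhqk,bkhd->bqhd", attn, v)
```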
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
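For instance, a packed QKV tensor of the documented shape can be passed directly. The tensor sizes here are arbitrary and only illustrative:
```python
import torch
from flash_attn import flash_attn_qkvpacked_func

batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
# Q, K, V stacked along dim 2, as expected by the packed interface.
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device="cuda")
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
print(out.shape)  # (batch_size, seqlen, nheads, headdim)
```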
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
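The MQA/GQA case from the docstring (6 query heads sharing 2 KV heads) looks like this in practice, with illustrative tensor sizes:
```python
import torch
from flash_attn import flash_attn_func

batch_size, seqlen, headdim = 2, 1024, 64
nheads_q, nheads_kv = 6, 2  # GQA: 6 query heads share 2 KV heads (nheads_q must be divisible by nheads_kv)
q = torch.randn(batch_size, seqlen, nheads_q, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # (batch_size, seqlen, nheads_q, headdim)
```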
```python
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True,
alibi_slopes=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume the batch indices are [0, 1, ..., batch_size - 1].
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
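A typical single-step decoding call, sketched from the description above: the cache is pre-allocated up to a maximum length, `cache_seqlens` records how many tokens each sequence already holds, and the kernel writes the new keys/values into the cache in place. Tensor sizes are illustrative:
```python
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, headdim, max_seqlen = 2, 16, 64, 4096
dtype, device = torch.float16, "cuda"

# Pre-allocated KV cache; flash_attn_with_kvcache updates it in place.
k_cache = torch.zeros(batch_size, max_seqlen, nheads, headdim, dtype=dtype, device=device)
v_cache = torch.zeros(batch_size, max_seqlen, nheads, headdim, dtype=dtype, device=device)
# Number of tokens already stored in the cache for each sequence (illustrative value).
cache_seqlens = torch.full((batch_size,), 128, dtype=torch.int32, device=device)

# One decoding step: a single new query token plus its new key/value.
q = torch.randn(batch_size, 1, nheads, headdim, dtype=dtype, device=device)
k_new = torch.randn(batch_size, 1, nheads, headdim, dtype=dtype, device=device)
v_new = torch.randn(batch_size, 1, nheads, headdim, dtype=dtype, device=device)

out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
print(out.shape)  # (batch_size, 1, nheads, headdim)
```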