06 | Infer Optimization¶
约 1066 个字 4 行代码 16 张图片 预计阅读时间 4 分钟
正在施工中👷..
Key-value cache¶
- Key-value cache
原理 ¶
在 Decoder 阶段,使用 Auto Regressive 机制
由于有 Mask 机制,每次有新的 token 加入的时候,只需要做 \(Q_{new}\) 和 \(K_{old}\) 的注意力计算,而不用重新计算整个序列。
所以,我们只需要保存 \(K_{old}\) 和 \(V_{old}\) ( 因为只用到了 KV),就可以实现高效的增量生成。
值得注意的是,KV 缓存的大小通常和模型本身大小是同一级别,也是一种空间换时间的策略
pie
title Memory Usage of 13B LLM on A100-40GB
"Parameters" : 65
"KV Cache" : 30
"Others" : 5
Paged Attention¶
为什么需要 ¶
操作系统
操作系统需要给进程预先分配内存吗
每个页 4K
原理 ¶
- 不预分配,按需调用
- 按块 Block 分配内存,碎片更小
- 虚拟内存:逻辑内存是连续的,通过映射表链接到物理内存(实际分配不连续
) ;方便调用
Share KV Cache¶
copy on write机制:引用大于 1 的时候,不能直接写入,必须拷贝一份,再写入
还可以优化 beam-search
Flash Attention¶
[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Fast
- Memory-Efficient
- Exact
为什么需要 ¶
Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length.
SRAM 读取快,HBM 读取慢
- Compute-bound: (数据等算力)
- 大的矩阵乘法,多 channel 卷积
- IO-bound:(算力等数据)
- 按位操作:Relu,Dropout
- 规约操作:sum、softmax
一般使用 fusion 融合操作,算结果时候只读取一次 HBM
原始 Attention 的实现 ¶
矩阵 \(Q\), \(K\), \(V \in \mathbb{R}^{N\times d}\) 存储在 HBM
- 从
HBM
加载 \(Q\), \(K\) 到SRAM
- 计算出 \(S = QK^T\)
- 将 \(S\) 写到
HBM
- 将 \(S\) 加载到
SRAM
- 计算 \(P = softmax(S)\)
- 将 \(P\) 写出到
HBM
- 从
HBM
加载 \(P\) 和 \(V\) 到SRAM
- 计算 \(O = PV\)
- 把 \(O\) 写出到
HBM
- 返回 \(O\)
tiling softmax¶
减少 IO 量
让 Attention 的所有计算都符合加法结合律
- 通过分块计算,融合多个操作,减少中间结果缓存
- 反向传播等时候,重新计算结果
softmax精度问题
\(e\) 的指数项可能超过精度,比如 65536
使用指数项可能会爆精,所以使用 safe_softmax
即如果计算了左侧的 softmax,右侧的 softmax 如何计算整体的
KV 在外循环 Q 在内循环
对于整体来讲
Q.shape[:-1] = (1, 1, 6)
[..., None]
会在最后增加一个维度,相当于:
(1, 1, 6) → (1, 1, 6, 1)
所以:
l.shape
=(1, 1, Q_LEN, 1)
m.shape
=(1, 1, Q_LEN, 1)
为什么是 (1, 1, Q_LEN, 1)
而不是 (1, 1, Q_LEN)
?
作用:方便广播运算
在注意力计算时,l
和 m
是针对每个 query 位置存储的:
m
→ 这个位置的当前最大 logit(数值稳定 softmax 用)l
→ 这个位置的 softmax 分母(sum(exp(...))
)
在后续更新中,会用到像:
torch.exp(m_block_ij - mi_new)
这里的 m_block_ij
形状通常是 (1, 1, block_size, 1)
,
如果 l
和 m
也有最后一个 1
维度,就可以无额外 reshape 直接广播。
另外一个原因:与 V 对齐
注意力输出是:
output = sum(softmax(QK^T) * V)
V
的形状是 (1, 1, KV_LEN, dim)
,
而 l
、m
只存每个 query 的一个标量,所以最后一维是 1
,
这样在计算时既能和 (1, 1, Q_LEN, dim)
广播,也能和 (1, 1, Q_LEN, 1)
对齐。
需要额外存储
反向传播 recomputation ¶
前向的时候,会保存 softmax 统计值,\(m\) 和 \(l\)
StreamLLM¶
在 nvidia-smi 中可以看到所有 GPU 的利用率会直接冲到 100%,直到这个超卡的请求全部生成完,才会恢复正常。这不就是典型的优先 prefill 暂停 decode 么,解决办法就是 chunked prefill size 啊,deepseek 都告诉你了。