2766 字
14 分钟
次浏览
大模型推理显存拆解 — 一步步算清你的显存去哪了

0. 前言#

你有一张 RTX 4090, 24GB 显存. 你想跑 Llama-3-8B.

问题来了: 8B 参数的模型, 在 BF16 精度下, 光是加载权重就要 16GB. 你还有 8GB 的余量. “够了! 跑吧!”

然后跑起来发现 OOM (Out Of Memory).

为什么? 因为权重只是冰山一角. 推理过程中还有KV Cache激活值在悄悄吃显存. 这篇文章我们就来一笔笔算清楚, 每个公式都从原理出发推导出来, 而不是直接扔给你一个数字.

最终你会发现: 显存的去向其实完全可以精确计算, 而且每一笔都有优化的办法.


1. 模型参数: 第一笔账#

这比账最好算.

1.1 基本公式#

模型参数占用的显存 = 参数数量 × 每个参数的字节数:

Mparams=Nparams×bparamM_{\text{params}} = N_{\text{params}} \times b_{\text{param}}

其中 bparamb_{\text{param}} 取决于精度:

  • FP32: 4 字节
  • BF16 / FP16: 2 字节
  • INT8: 1 字节
  • INT4: 0.5 字节

所以对于 8B 参数的模型:

精度MparamsM_{\text{params}}计算过程
BF1616 GB8×109×2=16×1098 \times 10^9 \times 2 = 16 \times 10^9 字节
INT88 GB8×109×1=8×1098 \times 10^9 \times 1 = 8 \times 10^9 字节
INT44 GB8×109×0.5=4×1098 \times 10^9 \times 0.5 = 4 \times 10^9 字节

1.2 更精确地算: 参数从哪来?#

“8B 参数”这个数字到底是怎么组成的? 我们以 Llama-3-8B 为例拆一下.

Llama-3-8B 的结构参数:

参数符号
层数LL32
隐藏维度dmodeld_{\text{model}}4096
FFN 中间维度dffd_{\text{ff}}14336
注意力头数nheadsn_{\text{heads}}32
KV 头数nkvn_{\text{kv}}8 (GQA)
词表大小VV128000

逐层细分:

1. Embedding 层 (词嵌入):

Membed=V×dmodel=128000×4096524MM_{\text{embed}} = V \times d_{\text{model}} = 128000 \times 4096 \approx 524\text{M}

2. 每层 Transformer (共 32 层):

注意力部分:

  • Q 投影: dmodel×(nheads×dhead)=4096×4096=16.8Md_{\text{model}} \times (n_{\text{heads}} \times d_{\text{head}}) = 4096 \times 4096 = 16.8\text{M}
  • K 投影: dmodel×(nkv×dhead)=4096×1024=4.2Md_{\text{model}} \times (n_{\text{kv}} \times d_{\text{head}}) = 4096 \times 1024 = 4.2\text{M}
  • V 投影: 同 K, 4.2M4.2\text{M}
  • O 投影: (nheads×dhead)×dmodel=4096×4096=16.8M(n_{\text{heads}} \times d_{\text{head}}) \times d_{\text{model}} = 4096 \times 4096 = 16.8\text{M}

每层注意力合计: 16.8+4.2+4.2+16.8=42M16.8 + 4.2 + 4.2 + 16.8 = 42\text{M}

FFN 部分 (SwiGLU, 3个矩阵):

  • gate_proj: dmodel×dff=4096×14336=58.7Md_{\text{model}} \times d_{\text{ff}} = 4096 \times 14336 = 58.7\text{M}
  • up_proj: 同上, 58.7M58.7\text{M}
  • down_proj: dff×dmodel=14336×4096=58.7Md_{\text{ff}} \times d_{\text{model}} = 14336 \times 4096 = 58.7\text{M}

每层 FFN 合计: 58.7×3=176.1M58.7 \times 3 = 176.1\text{M}

每层 Transformer 合计: 42+176.1=218.1M42 + 176.1 = 218.1\text{M} 32 层: 218.1×327.0B218.1 \times 32 \approx 7.0\text{B}

3. RMS Norm (每层有 2 个, 加上最后的):

每个 RMS Norm 只有 dmodeld_{\text{model}} 个可训练参数 (= 4096), 可以忽略.

4. LM Head (输出层):

dmodel×V=4096×128000=524Md_{\text{model}} \times V = 4096 \times 128000 = 524\text{M}

总计: 524M(embed)+7.0B(32层)+524M(head)8.0B524\text{M} (\text{embed}) + 7.0\text{B} (\text{32层}) + 524\text{M} (\text{head}) \approx 8.0\text{B}

这个计算验证了: 8B 参数不是凭空说的, 每一层、每个矩阵的贡献都可以精确计算. 当有人告诉你”这是一个 8B 模型”时, 你可以快速心算: 大概要占 16GB (BF16) / 8GB (INT8) / 4GB (INT4).


2. KV Cache: 被严重低估的显存杀手#

KV Cache 的原理我在之前的博客 A Series on LLMs (II) 中已经详细介绍过了, 这里简单回顾一下核心思路, 重点放在”占多少显存”的计算上.

2.1 从注意力公式出发#

先回顾一下 Transformer Decoder 的注意力计算. 在第 tt 步, 模型需要计算:

Attention(Qt,Kt,Vt)=softmax(QtKtTdk)Vt\text{Attention}(Q_t, K_{\le t}, V_{\le t}) = \text{softmax}\left(\frac{Q_t K_{\le t}^T}{\sqrt{d_k}}\right) V_{\le t}

这里 QtQ_t当前 token的 query (大小 1×dk1 \times d_k), 而 KtK_{\le t}VtV_{\le t}所有历史位置的 key 和 value (大小 t×dkt \times d_k).

你可以选择:

  • 方案 A: 每次重新算 KtK_{\le t}VtV_{\le t} — 第 tt 步的计算量是 O(t×dk)O(t \times d_k), 累计 O(T2×dk)O(T^2 \times d_k), 序列长了完全不可接受.
  • 方案 B: 把之前每一步算好的 KiK_i, ViV_i 存起来, 每次只要算当前 token 的 KtK_t, VtV_t, 然后拼到缓存里.

方案 B 就是 KV Cache.

2.2 KV Cache 的精确公式#

每层需要缓存 K 和 V 两份. 每个 token 每层每头需要的空间是 dhead×bparamd_{\text{head}} \times b_{\text{param}}.

所以 KV Cache 的总大小:

Mkv=2×L×nkv×dhead×bparam×TM_{\text{kv}} = 2 \times L \times n_{\text{kv}} \times d_{\text{head}} \times b_{\text{param}} \times T

其中 TT 是序列长度.

为什么要乘以 nkvn_{\text{kv}} 而不是 nheadsn_{\text{heads}}?

这里取决于注意力机制:

  • MHA (Multi-Head Attention): 每个 query head 有独立的 K, V head → nkv=nheadsn_{\text{kv}} = n_{\text{heads}}
  • GQA (Grouped-Query Attention): 多个 query head 共享一组 K, V → nkv<nheadsn_{\text{kv}} < n_{\text{heads}}
  • MQA (Multi-Query Attention): 所有 query head 共享同一组 K, V → nkv=1n_{\text{kv}} = 1

Llama-3-8B 用了 GQA, nkv=8n_{\text{kv}} = 8. 如果它用 MHA (nkv=32n_{\text{kv}} = 32), KV Cache 会大 4 倍!

2.3 具体数字#

以 Llama-3-8B 为例 (L=32L=32, nkv=8n_{\text{kv}}=8, dhead=128d_{\text{head}}=128, BF16, bparam=2b_{\text{param}}=2):

Mkv=2×32×8×128×2×TM_{\text{kv}} = 2 \times 32 \times 8 \times 128 \times 2 \times T

简化: Mkv=131072×T 字节=128×T KBM_{\text{kv}} = 131072 \times T \ \text{字节} = 128 \times T \ \text{KB}

TTMkvM_{\text{kv}}占权重的比例
51264 MB0.4%
2,048256 MB1.6%
4,096512 MB3.1%
8,1921 GB6.3%
32,7684 GB25%
128,00016 GB100%

可以看到: 当序列长度达到 128K 时, KV Cache 的显存开销已经和模型权重本身一样大!

这就是为什么长上下文推理如此吃显存. 跑 128K 上下文意味着你需要双倍的显存——一份装权重, 一份装 KV Cache.

2.4 与 batch size 的关系#

上面的计算假设 batch size = 1. 如果同时处理 BB 个请求:

Mkv(total)=Mkv(T)×BM_{\text{kv}}(\text{total}) = M_{\text{kv}}(T) \times B

KV Cache 随着 batch size 线性增长. 如果有 8 个并发请求, 每个 32K 上下文, KV Cache 就要 32GB——已经超过了大多数消费级显卡.

这就是为什么 vLLM 等推理框架如此重要——它们通过 PagedAttention 让多个请求共用显存, 消除了内部碎片.


3. 激活值: 临时工#

与模型参数和 KV Cache 不同, 激活值是临时占用的——每步前向传播后就会被释放.

3.1 激活值从哪来?#

在推理的每一步, 数据流过每一层 Transformer:

输入 (hidden_states)
→ RMS Norm → QKV 投影 → Attention 计算 → 残差连接
→ RMS Norm → FFN (gate/up/down) → 残差连接
→ 输出到下一层

每一层的中间结果都需要占用显存. 具体来说:

  • Attention 部分: Q, K, V 投影后的矩阵, attention score (T×TT \times T), attention output
  • FFN 部分: gate 输出, up 输出, 中间激活, down 输出
  • 残差连接: 需要保留输入向量用于加法

3.2 估算公式#

有个经验公式可以快速估算:

Mact(34×dmodel+5×dff)×T×B×bparamM_{\text{act}} \approx ( 34 \times d_{\text{model}} + 5 \times d_{\text{ff}} ) \times T \times B \times b_{\text{param}}

这个 34 和 5 是怎么来的? 来自每层中各种中间矩阵的大小之和. 对于 Llama-3-8B (dmodel=4096d_{\text{model}}=4096, dff=14336d_{\text{ff}}=14336):

Mact(34×4096+5×14336)×T×B×2M_{\text{act}} \approx (34 \times 4096 + 5 \times 14336) \times T \times B \times 2

B=1B=1 时:

TTMactM_{\text{act}}
512206MB\approx 206\text{MB}
4,0961.6GB\approx 1.6\text{GB}
32,76813GB\approx 13\text{GB}

注意: 长序列时激活值的占用也接近模型权重了! 这是因为 T×dmodelT \times d_{\text{model}} 的乘积在变大.

3.3 为什么激活值容易被忽略#

KV Cache 和模型参数是常驻显存的——加载后直到推理结束才释放. 激活值是临时的——每算完一层就释放一部分.

所以很多人只关注常驻部分. 但问题在于峰值时刻: 当长序列且没有 Flash Attention 时, 完整的 T×TT \times T 注意力矩阵 (对 32K 序列就是 32K×32K×2bytes2GB32K \times 32K \times 2\text{bytes} \approx 2\text{GB}) 可能瞬间撑爆显存.

这就是为什么 Flash Attention 如此重要——它通过分块计算避免了一次性创建完整的注意力矩阵.


4. 总账本#

对 Llama-3-8B (BF16, batch=1) 做个总账:

4.1 短序列 (512 tokens)#

项目大小占比
模型参数16.0 GB96.5%
KV Cache0.064 GB0.4%
激活值 (峰值)0.2 GB1.2%
总计~16.3 GB

→ 24GB 显卡轻松跑. 主要瓶颈是模型权重.

4.2 中等序列 (8K tokens)#

项目大小占比
模型参数16.0 GB74%
KV Cache1.0 GB5%
激活值 (峰值)3.2 GB15%
其他开销~1.3 GB6%
总计~21.5 GB

→ 24GB 显卡刚好够用, 但快满了.

4.3 长序列 (32K tokens)#

项目大小占比
模型参数16.0 GB44%
KV Cache4.0 GB11%
激活值 (峰值)13 GB36%
其他开销~3 GB9%
总计~36 GB

→ 24GB 显卡完全不够! 必须优化.


5. 每项都能优化#

既然知道了每一笔账, 就可以针对性地”省钱”.

5.1 模型参数: 量化#

量化就是把 BF16 降到更低位宽:

Mparams(INT4)=14Mparams(BF16)M_{\text{params}}(\text{INT4}) = \frac{1}{4} M_{\text{params}}(\text{BF16})

8B 模型: 16 GB → 4 GB, 省 12GB.

代价? 理论上少量精度损失, 实践中 INT4 的 MMLU 损失通常在 1% 以内. 值不值? 对于部署来说, 太值了.

5.2 KV Cache: 三个方向#

方向 1: GQA / MQA (架构层面)

从 MHA 换成 GQA 或 MQA, 直接减少 nkvn_{\text{kv}}:

MkvGQAMkvMHA=nkvGQAnkvMHA\frac{M_{\text{kv}}^{\text{GQA}}}{M_{\text{kv}}^{\text{MHA}}} = \frac{n_{\text{kv}}^{\text{GQA}}}{n_{\text{kv}}^{\text{MHA}}}

Llama-3-8B 用 GQA (nkv=8n_{\text{kv}}=8) 而非 MHA (nkv=32n_{\text{kv}}=32), KV Cache 直接省 4 倍.

方向 2: KV Cache 量化 (数值层面)

把 KV Cache 从 FP16 (2 字节) 存成 INT8 (1 字节):

Mkv(INT8)=12Mkv(FP16)M_{\text{kv}}(\text{INT8}) = \frac{1}{2} M_{\text{kv}}(\text{FP16})

32K 上下文: 4 GB → 2 GB.

方向 3: PagedAttention (系统层面)

KV Cache 按固定大小的 page 分配, 类似操作系统虚拟内存分页. 主要收益:

  • 消除内部碎片 (不同序列长度导致的不连续分配)
  • 方便内存共享 (如 beam search 的多个候选共用前缀)

vLLM 用 PagedAttention 宣称能省 60-80% 的 KV Cache 显存——这个数字来自碎片消除 + 共享前缀 + 按需分配的综合效果.

5.3 激活值: Flash Attention#

标准的注意力实现需要构建 T×TT \times T 的注意力矩阵:

Mattn=T2×bparamM_{\text{attn}} = T^2 \times b_{\text{param}}

32K 序列: 327682×22GB32768^2 \times 2 \approx 2\text{GB}

Flash Attention 把计算分块, 让注意力矩阵的子块在 SRAM 中处理, 然后累加结果进 HBM. 这样在 HBM 层面只需要 O(T×d)O(T \times d) 的显存, 而不是 O(T2)O(T^2).

收益: 长序列时激活值显存从 O(T2)O(T^2) 变成 O(T)O(T)——对于 32K 序列, 可以省数十 GB.

5.4 组合优化的效果#

对 Llama-3-8B, 32K 上下文, BF16→INT4, 加上各种优化:

优化模型参数KV Cache激活值总计备注
原始 (BF16)16 GB4 GB13 GB~36 GB不行
+ INT4 量化4 GB4 GB13 GB~24 GB勉强
+ KV Cache INT84 GB2 GB13 GB~22 GBOK
+ Flash Attention4 GB2 GB~1 GB~7 GB轻松

这就是 LLM 推理优化的魔力——通过理解每一笔开销的数学原理, 你可以有针对性地节省几十 GB 的显存.


6. 实际部署经验#

公式都推导清楚了, 实际操作就简单了:

短上下文 (<2K): 主要瓶颈是模型参数 → 先量化 长上下文 (>8K): 主要瓶颈是 KV Cache → GQA + PagedAttention 大批量推理: 激活值和 KV Cache 都线性增长 → Flash Attention + PagedAttention

一个具体的配置:

# vLLM 自动处理大部分优化
from vllm import LLM, SamplingParams
llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
max_model_len=8192, # 限制最大序列长度
gpu_memory_utilization=0.9, # 使用 90% 显存
kv_cache_dtype="fp8", # KV Cache 量化
)

HuggingFace 原生:

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B",
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="flash_attention_2", # 省激活值
)

7. 总结#

推理显存的数学很简单, 就是加乘:

项目公式关键参数
模型参数Nparams×bparamN_{\text{params}} \times b_{\text{param}}参数量和精度
KV Cache2×L×nkv×dhead×bparam×T2 \times L \times n_{\text{kv}} \times d_{\text{head}} \times b_{\text{param}} \times T层数、头数、序列长度
激活值34×dmodel×T×B×bparam\approx 34 \times d_{\text{model}} \times T \times B \times b_{\text{param}}模型宽度、序列长度

关键在于: 每一项你都能精确算出, 每算出来一项, 就知道应该从哪下手优化.

下次别人说”这个 8B 模型跑不起来”, 你可以问: 上下文多长? 精度用什么? 用 Flash Attention 了吗? ——而且每一问你都知道他差在哪里.


参考资料#

  1. Kwon et al., Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023. arXiv:2309.06180vLLM 核心论文
  2. Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention. NeurIPS 2022. arXiv:2205.14135
  3. Shazeer, Fast Transformer Decoding: One Write-Head is All You Need. 2019. arXiv:1911.02150MQA
  4. Ainslie et al., GQA: Training Generalized Multi-Query Transformer Models. 2023. arXiv:2305.13245
  5. Meta, The Llama 3 Herd of Models. 2024. arXiv:2407.21783Llama-3 架构细节
大模型推理显存拆解 — 一步步算清你的显存去哪了
https://xuchenhui.cc/posts/2026-05-16-llm-memory-breakdown/
作者
CHENHUI
发布于
2026-05-16
许可协议
CC BY-NC-SA 4.0
📖 目录