
大家好,我是吴师兄。
上周群里有位同学在大模型面试现场被问到这样一道题:
“请手写一个多头注意力(Multi-Head Attention)的实现,支持 KV Cache,并说说 RoPE 应该加在哪?”
听到这道题,他脑子嗡地一下,平时看的论文、看的源码都浮现在眼前,但真让他手写实现,尤其要把 KV 缓存机制和相对位置编码讲明白,还真不知道从哪下手。
所以,今天这篇文章我就从实战视角出发,带你完整走一遍这道高频大模型面试题的答题思路,包含代码实现、关键细节、知识讲解和面试话术。
如果你在准备大模型相关岗位,比如推理系统、模型部署、架构优化方向,建议收藏本文,认真过一遍。
一、多头注意力机制(MHA)回顾
Multi-Head Attention 是 Transformer 架构的核心组件。它的目标是通过多个注意力头并行计算,让模型能从不同子空间捕捉信息。
在面试中,如果让你手写实现,一般是基于 PyTorch。
我们直接上手代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, hidden_size, num_heads):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.q_linear = nn.Linear(hidden_size, hidden_size)
self.k_linear = nn.Linear(hidden_size, hidden_size)
self.v_linear = nn.Linear(hidden_size, hidden_size)
self.o_linear = nn.Linear(hidden_size, hidden_size)
def forward(self, hidden_state, causal_mask=None, past_key_value=None, use_cache=False):
batch_size = hidden_state.size(0)
query = self.q_linear(hidden_state)
key = self.k_linear(hidden_state)
value = self.v_linear(hidden_state)
# 多头拆分
query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 拼接缓存
if past_key_value isnotNone:
past_key, past_value = past_key_value
key = torch.cat([past_key, key], dim=2)
value = torch.cat([past_value, value], dim=2)
new_past_key_value = (key, value) if use_cache elseNone
# 注意力打分
attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
if causal_mask isnotNone:
attention_scores += causal_mask * -1e9
attention_probs = F.softmax(attention_scores, dim=-1)
output = torch.matmul(attention_probs, value)
# 合并多头
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
output = self.o_linear(output)
return (output, new_past_key_value) if use_cache else output
这一段实现了标准的 MHA 结构,同时加入了 KV Cache(past_key_value参数)和 causal mask。
二、KV Cache 是怎么加进去的?
很多刚入门的同学对 KV Cache 的理解还停留在“优化”的层面。但在实际推理中,KV Cache 是影响吞吐量的关键机制。
在这段代码中:
if past_key_value is not None:
past_key, past_value = past_key_value
key = torch.cat([past_key, key], dim=2)
value = torch.cat([past_value, value], dim=2)
这段逻辑做的是把过去生成的 key/value 缓存起来,然后当前这轮推理时直接拼接,不用重复计算。
为什么这样做?
因为在生成式任务中,每生成一个 token,前面的 token 的 key/value 是不变的,完全可以缓存下来复用,节省大量计算和内存访问。
三、RoPE 应该加在哪?
这个问题也是面试高频考点:RoPE(Rotary Positional Embedding)相对位置编码究竟加在什么位置?为什么?
一句话总结:RoPE 作用于 Q、K,位置在 attention score 计算之前。
RoPE 的目标是通过旋转操作嵌入“相对位置信息”。其数学表达为:
=
其中角度 θi=1100002k/dtheta_i = frac{1}{10000^{2k/d}},kk 为维度索引。
你可以把 RoPE 想象成一个无参数的“相对位置信息混入器”,它不需要学习参数,但能在注意力机制中显著提升模型感知长依赖关系的能力。
所以在上述 MHA 实现中,应该在如下位置插入 RoPE 编码:
# 在 query 和 key 上添加 RoPE 编码
query = apply_rope(query)
key = apply_rope(key)
这一步是在 softmax 之前,让打分矩阵天然带有位置感知能力。
四、如果在面试中遇到该题,应该怎么答?
建议你分 4 个层次回答,越往后越体现工程落地能力:
-
功能结构:说明 MHA 的整体组成和各组件的作用; -
KV Cache 的意义和实现方式:减少重复计算,加快推理速度; -
RoPE 的原理和添加位置:无需引入参数,位置插入 Q/K 之间,提升模型相对位置信息处理能力; -
代码层级思维:说清楚 forward 中各步的操作以及为什么这么做。
你可以不写完整代码,但思路得清晰。如果你能一边说一边写出来,面试官绝对对你刮目相看。
五、小结:这道题到底在考什么?
这道题不是在为难你,而是想看你对底层实现是否真的动过手:
-
你是否理解 MHA 的 forward 流程? -
你是否知道推理阶段要加 KV Cache? -
你是否知道 RoPE 是现在主流大模型中的默认位置编码方式? -
你是否写过类似代码或者调试过 HuggingFace 模型结构?
说白了,这是考察“你到底是读过代码,还是只听过讲解”。
最后的话
这道题虽然是面试问题,但背后体现的是 你是否真的参与过大模型推理系统的构建和调优。
真正的工程落地不只是在笔记本上跑 demo,而是要知道:
-
怎样才能让生成更快; -
怎样才能复用已有缓存; -
怎样处理长文本性能下降问题。
这些,才是你在秋招中脱颖而出的底气。
如果你还想系统掌握这些底层机制,我的大模型冲刺营已经帮助大量同学掌握从 KV Cache、RoPE、Flash Attention 到推理部署的全链路知识。
我们不讲“学术定义”,而是手把手带你复现源码、优化模型、跑出性能。
秋招已经进入关键节点,不如直接入局。
