LLM 常见手撕代码

零、张量变换与重塑

0.1 教程说明

操作 作用 是否复制内存 常见用途
view() 重塑形状 分头、展平
reshape() 重塑形状 必要时是 安全重塑
transpose() 交换两维 注意力计算
permute() 重排所有维 复杂维度重排
contiguous() 内存连续化 必要时是 view前确保连续
unsqueeze() 插入维度 广播准备
squeeze() 删除维度 去除冗余维
expand() 扩展维度 只读广播
repeat() 复制扩展 需要独立副本
cat() 拼接 视情况 合并张量
stack() 堆叠 视情况 增维合并
split() 按大小分割 分离特征
chunk() 按份数分割 均分张量
where() 条件选择 视情况 掩码、索引
  1. 优先使用 view() 而非 reshape()(如果确定张量连续)
  2. transpose() 后如需 view(),必须先 contiguous()
  3. 只读扩展用 expand(),需要副本用 repeat()
  4. 使用 -1 让 PyTorch 自动推断维度大小
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch

x = torch.arange(12)
# tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])


# 1. view 返回新视图, 必须共享内存, 张量必须连续
# reshape 返回新视图, 尽量共享内存, 必要时复制
y = x.view(3, 4)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
y[0, 0] = 999 # 共享内存, 此时 x[0] = 999

yt = y.t() # 转置后内存不连续
yt.is_contiguous() # Fasle
# tensor([[ 0, 4, 8],
# [ 1, 5, 9],
# [ 2, 6, 10],
# [ 3, 7, 11]])
z = yt.reshape(4, 3)
# tensor([[ 0, 4, 8],
# [ 1, 5, 9],
# [ 2, 6, 10],
# [ 3, 7, 11]])


# 2. transpose 交换两个维度
# permute 重排所有维度
x = torch.randn(2, 3, 4) # torch.Size([2, 3, 4])
y = x.transpose(0, 1) # torch.Size([3, 2, 4])
z = x.permute(2, 0, 1) # torch.Size([4, 2, 3])


# 3. contiguous 内存连续化, 返回内存连续的副本(若原来不连续)
x = torch.arange(12).view(3, 4)
xt = x.transpose(-2, -1)
xt.is_contiguous() # Fasle

y = xt.view(-1) # 报错 size 不兼容
y = xt.contiguous().view(-1) # 正确执行, 会有新副本, y 与xt 不共享内存


# 4. squeeze 删除所有/指定位置大小为1的维度 [去除多余维度]
# unsqueeze 指定位置插入大小为1的维度 [扩展维度以便广播]
x = torch.randn(3, 4)
x.unsqueeze(0) # torch.Size([1, 3, 4])
x.unsqueeze(-1) # torch.Size([3, 4, 1])

y = torch.randn(1, 3, 1, 4)
y.squeeze() # torch.Size([3, 4])
y.squeeze(0) # torch.Size([3, 1, 4])


# 5. expend 广播扩展, 不复制内存, 只能扩展大小为1的维度, 创建新视图, 共享内存
# repeat 复制扩展, 复制内存, 独立内存
x = torch.randn(1, 3, 1, 4)
y = x.expend(2, 3, 5, 4) # torch.Size([2, 3, 5, 4])
y[0, 0, 0, 0] = 999 # 此时 y[1, 0, 0, 0] 和 x[0, 0, 0, 0] 都是 999.0

x = torch.tensor([[1, 2]]) # torch.Size([1, 2])
# tensor([[1, 2]])
y = x.repeat(2, 3) # torch.Size([2, 6]) 第0维重复2次, 第1维重复3次
# tensor([[1, 2, 1, 2, 1, 2],
# [1, 2, 1, 2, 1, 2]])


# 6. cat 在已有维度上拼接
# stack 创建新维度堆叠
a = torch.arange(6).view(2, 3) # torch.Size([2, 3])
b = torch.arange(6, 12).view(2, 3) # torch.Size([2, 3])
# tensor([[0, 1, 2],
# [3, 4, 5]])
# tensor([[ 6, 7, 8],
# [ 9, 10, 11]])

c = torch.cat([a, b], dim=0) # torch.Size([4, 3])
d = torch.cat([a, b], dim=1) # torch.Size([2, 6])
# tensor([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]])
# tensor([[ 0, 1, 2, 6, 7, 8],
# [ 3, 4, 5, 9, 10, 11]])

e = torch.stack([a, b], dim=0) # torch.Size([2, 2, 3])
# tensor([[[ 0, 1, 2],
# [ 3, 4, 5]],
#
# [[ 6, 7, 8],
# [ 9, 10, 11]]])


# 7. split 指定大小, 尽量分割, 不直接返回 tensor
# chunk 指定份数, 平均分割, 返回 tensor
x = torch.arange(10)
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
res = torch.split(x, 3, dim=0) # 每份3个, 最后一份可能不足
# (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
a, b = torch.split(x, [2, 8], dim=0) # 按指定大小分
# tensor([0, 1])
# tensor([2, 3, 4, 5, 6, 7, 8, 9])

x = torch.arange(12).view(3, 4)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
a, b = torch.chunk(x, 2, dim=1) # 分成2份
# tensor([[0, 1],
# [4, 5],
# [8, 9]])
# tensor([[ 2, 3],
# [ 6, 7],
# [10, 11]])


# 8. where 条件选择
# 三元选择 where(cond, x, y) 相当于 cond ? x : y
# 索引查找 where(cond) 返回满足条件的索引元组
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([10, 20, 30, 40, 50])
cond = x > 3
# tensor([False, False, False, True, True])
res = torch.where(cond, x, y)
# tensor([10, 20, 30, 4, 5])

mask = torch.tensor([[True, False, True],
[False, True, False]])
indices = torch.where(mask) # 返回一个满足条件的元素所在 (行, 列) 元组
# (tensor([0, 0, 1]), tensor([0, 2, 1]))

0.2 完整示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import math

def multi_head_attention(q, k, v, num_heads):
'''
简化版多头注意力
输入 q, k, v : [B, S, D]
输出 : [B, S, D]
'''
# 输入: [B, S, D]
batch_size, seq_len, model_dim = q.shape
head_dim = model_dim // num_heads

# q, k, v 分头 : [B, H, S, head_dim]
q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# 计算注意力分数 : [B, H, S, S]
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

# 通过 softmax 得到注意力权重 : [B, H, S, S]
atten_weight = torch.softmax(scores, dim=-1)

# 加权求和 : [B, H, S, head_dim]
context = torch.matmul(atten_weight, v)

# 合并多头返回结果 :
output = context.transpose(1, 2).contiguous().view(batch_size, seq_len, model_dim)

return output

# 使用示例
B, S, D, H = 2, 10, 64, 8
q = torch.randn(B, S, D)
k = torch.randn(B, S, D)
v = torch.randn(B, S, D)

# 输入: torch.Size([2, 10, 64])
output = multi_head_attention(q, k, v, H)
# 输出: torch.Size([2, 10, 64])
# tensor([[[-0.0507, 0.0495, 0.2281, ..., 0.8132, 0.2912, 0.6504],
# [ 0.2156, 0.5946, 0.1377, ..., 0.3467, -0.0341, 0.8923],
# [ 0.0520, 0.5703, 0.1694, ..., 0.5722, 0.2187, 0.6068],
# ...,
# [ 0.6040, 0.6911, 0.6183, ..., -0.2492, 0.7850, -0.6464],
# [ 0.0835, 0.6042, 0.0594, ..., -0.1515, 0.0158, 0.4431],
# [ 0.2226, 0.5818, 0.3697, ..., 0.6246, 0.0199, 0.8018]],

# [[-0.5293, -0.2701, 0.3302, ..., -0.2581, -0.6833, -0.2513],
# [-0.1782, -0.9686, 0.4143, ..., -0.6749, -0.6872, -0.4578],
# [-0.4887, -0.5157, 0.5180, ..., -0.5745, -0.8481, -0.0262],
# ...,
# [-0.7730, -0.1319, 0.5004, ..., -0.4520, -0.9131, -0.0570],
# [-1.1040, -0.2407, 0.3970, ..., -0.6808, 0.1697, 0.0223],
# [-0.5515, 0.3597, 0.2924, ..., -0.7344, -0.5441, -0.0239]]])

一、注意力机制 Attention

1.1 Scaled Dot-Product Attention

缩放点积注意力机制

  • 它的核心功能是计算序列中不同元素之间的相关性,并根据这种相关性重新分配信息
  • 通俗来说,它在做三件事:
  1. 打分(Scores):询问“Query(我想要什么)”和“Key(你有什么)”之间匹配度有多高?
  2. 归一化(Softmax):将匹配度转换成概率(权重),总和为 1。
  3. 加权融合(Sum):根据权重,从“Value(信息内容)”中提取出最相关的部分。
  1. 为什么代码中要“缩放(Scaled)”?
  • 即代码中的 / math.sqrt(head_dim)。这是为了防止点积结果过大。如果数值过大,经过 Softmax 后会进入梯度的饱和区,导致梯度消失,模型就学不动了。
  1. 为什么 Mask 放在计算完 Scores 后?
  • 核心原因:为了在进行 Softmax 归一化之前,彻底“封死”不该看的信息。
  • 逻辑顺序:
    1. Scores (QKTQK^T):算出每个词对其他词的原始“亲密度”。
    2. Mask:把不该看的位置手动强行抹除(设为 -\infty)。
    3. Softmax:数学上,e=0e^{-\infty} = 0。这样归一化后,非法位置的权重就变成了严格的 00
  • 可以放在别的地方吗?
    1. 放在投影前? 不行。投影是在提取特征,此时还没算词与词的关系,你不知道该遮住谁。
    2. 放在 Softmax 后? 理论上可以(手动把某些权重置 0),但会带来数学偏差。如果你在 Softmax 后置 0,剩下位置的权重和就不再是 11 了,这会破坏概率分布的特性,导致训练不稳定。
  1. 为什么 Dropout 放在 attn_weights 后?
  • 它的作用:Dropout 会随机让某些注意力权重变成 00。为了增加注意力的“鲁棒性(Robustness)”。
  • 为什么要这么做?
    • 防止模型过于依赖某一个特定的词。比如翻译“苹果”时,模型可能 90% 的注意力都盯着“Apple”。
    • Dropout 强迫模型:“如果我不让你看 Apple,你能不能通过上下文的其他词(如“红色的”、“多汁的”)也推断出结果?”
  • 可以放在别的地方吗?
    1. 放在 VV 矩阵后:这是很常见的做法。很多实现会在 context = attn_weights @ v 之后再加一个 Dropout
    2. 放在 Q,K,VQ, K, V 投影后:也可以,那是为了防止过拟合特征。
  • 结论:在 attn_weights 上做 Dropout 是 Transformer 论文的原生做法,目的是对“联系”进行随机阻断,而不是对“特征”进行随机阻断。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
"""
缩放点击注意力机制
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(head_dim)) @ V
"""
def __init__(self, dropout_p=0.0):
super().__init__()
self.dropout = dropout_p

def forward(self, q, k, v, mask=None):
head_dim = q.size(-1)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
if mask is not None:
scores = scores.masked_fill(mask==0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
output = torch.matmul(attn_weights, v)

return output, attn_weights

# 维度情况:
# q [batch_size, head_num, seq_len_q, head_dim]
# k [batch_size, head_num, seq_len_k, head_dim]
# v [batch_size, head_num, seq_len_k, head_dim]
# q 和 k 的维度可以不一样
# 若 seq_len_q == seq_len_k 则是自注意力, 词与词之间相互看
# 若 seq_len_q != seq_len_k 则是交叉注意力/检索式, 用一堆东西去检索另一堆东西

if __name__ == "__main__":
batch_size, seq_len, model_dim = 2, 8, 256
head_num = 4
head_dim = model_dim//head_num

q = torch.randn(batch_size, head_num, seq_len, head_dim)
k = torch.randn(batch_size, head_num, seq_len, head_dim)
v = torch.randn(batch_size, head_num, seq_len, head_dim)
mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len) # Casual Mask 下三角1

attention_layer = ScaledDotProductAttention()
output, attn_weights = attention_layer(q, k, v, mask=mask)

print(output.shape, attn_weights.shape)
print(attn_weights[0, 0, 0, :]) # 第一个样本第一个头的第1行(应该仅有第一个值)
print(attn_weights[0, 0, 1, :]) # 第一个样本第一个头的第2行
# torch.Size([2, 4, 8, 64]) torch.Size([2, 4, 8, 8])
# tensor([1., 0., 0., 0., 0., 0., 0., 0.])
# tensor([0.7627, 0.2373, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])

1.2 Multi-Head Attention

多头注意力机制

  • 前面是单次“缩放点积注意力”,多头注意力(Multi-Head Attention, MHA) 就是在这个基础上进行的“并行化升级”。
  • 核心功能是 “分而治之”,具体体现在以下两点:
  1. 捕捉多维度关系:在自然语言中,一个词可能有多重身份。单头注意力很难同时关注这些不同的维度,而多头注意力可以。
  2. 增强模型稳定性:通过并行计算并拼接(Concatenate),模型可以综合多个头的意见,类似于集成学习(Ensemble Learning),减少了对单一注意力权重的依赖。

疑问:

  1. 为什么简单的线性层就能提取信息?
  • 投影的本质是“特征重组”:线性层其实是在对原始维度做加权组合。
  • WW 是静态的知识,而 q=xWq=xW 是动态的结果。
  1. WW 参数是怎么训练出来的?
  • 初始化、前向传播、计算损失、反向传播,经过数万亿次训练,WqW_q 终于学会了:“哦,当我看到名词时,我产生的 qq 应该去寻找它的修饰词”。
  1. Wq 是怎么“分头”的?
  • 这是一个非常重要的工程实现细节。在代码实现中,我们通常不会真的定义 num_heads 个小的线性层,而是定义一个大的线性层,然后通过矩阵切分。
  • 逻辑如下:
    • 假设 model_dim = 512num_heads = 8,那么 head_dim = 64
    • self.w_q 是一个 (512, 512) 的矩阵。
    • 当你做完投影 q = self.w_q(x_query) 后,得到的 q 形状是 [batch, seq_len, 512]。接下来的 view 操作将 512 拆成 8 个 64。WqW_q 的前 64 列:实际上构成了“第 1 个头”的变换矩阵。WqW_q 的第 65-128 列:构成了“第 2 个头”的变换矩阵…
    • 大矩阵 WW 被横向切分成了 HH 块。每一块负责把输入 xx 映射到一个特定的子空间(头)。
  1. 输出映射 w_o 在做什么?
  • 整合多头的意见:在 MultiHeadAttention 中,我们将一个大的维度(比如 512)拆成了 8 个独立的小头(每个 64 维)。
    • 拆分时:每个头独立工作,互不干扰(头 1 在看语法,头 2 在看语义)。
    • 拼接后:我们把这 8 个头的结果横向拼在一起,得到了一个 512 维的向量。
      问题来了: 拼接后的向量,其前 64 维和后 64 维是完全“孤立”的,它们之间没有发生过任何数学上的交互。w_o(输出投影层)的作用就像是一个总编辑:它把 8 个专家的意见汇总起来,通过矩阵乘法让这 8 个子空间的信息进行二次加权融合。没有 w_o,多头注意力就只是“简单堆砌”,而不是“深度集成”。
  • 空间的“恢复”, 保持维度一致性:虽然拼接后的维度已经是 512 了,但这个 512 维的空间分布和输入时的 xx 空间可能已经完全不同了。w_o 提供了一个线性变换的机会,让模型能够把注意力机制提取到的特征,重新映射回主干网络所期望的特征空间里。
  1. 多头注意力全流程工作?
  • 如果我们把 MultiHeadAttention 比作一次专家座谈会:
    • Wq,Wk,WvW_q, W_k, W_v:给每个专家发一份不同侧重点的资料。
    • Scaled Dot-Product:专家们各自关起门来写分析报告。
    • Concat(拼接):把 8 份报告装订成一册。
    • WoW_o (Output Projection):主编阅读这本手册,提炼出最终的决策摘要,发给下一个部门。
  • 如果没有 WoW_o,下一个部门拿到的就是 8 份乱七八糟、各说各话的草稿,无法直接使用。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
'''
多头注意力机制
并行多个缩放点积注意力,最后线性融合即可
'''
def __init__(self, model_dim, head_num, dropout_p=0.1):
super().__init__()
assert model_dim % head_num == 0, "model_dim must be disiviable by head_num"

self.model_dim = model_dim
self.head_num = head_num
self.head_dim = model_dim // head_num

self.w_q = nn.Linear(model_dim, model_dim)
self.w_k = nn.Linear(model_dim, model_dim)
self.w_v = nn.Linear(model_dim, model_dim)

self.dropout = nn.Dropout(dropout_p)
self.w_o = nn.Linear(model_dim, model_dim)

def forward(self, x_query, x_context, mask=None):
'''
可以没有 x_context, 则是自注意力
'''
batch_size = x_query.shape[0]

# 线性投影, 得到 q, k, v
q = self.w_q(x_query)
if x_context is not None:
k = self.w_k(x_context)
v = self.w_v(x_context)
else:
k = self.w_k(x_query)
v = self.w_v(x_query)

# 分头处理
q = q.view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2)

# 缩放点积注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask==0, -1e9)

attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)

context = torch.matmul(attn_weights, v)

# 合并多头
context = context.transpose(1, 2).contiguous()
output = context.view(batch_size, -1, self.model_dim)

# 输出投影
output = self.w_o(output)
return output


if __name__ == "__main__":
# batch_size=2, seq_len=10, model_dim=64, num_heads=8
x = torch.randn(2, 10, 64)

mha = MultiHeadAttention(model_dim=64, head_num=8)
out = mha(x, x) # Self-Attention: x_query=x, x_context=x

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
# Input shape: torch.Size([2, 10, 64])
# Output shape: torch.Size([2, 10, 64])

1.3 Group Query Attention

分组查询注意力

  • 这是对标准多头注意力(MHA)的变体优化。
  • 它旨在解决大模型在生成(Inference)时的一个巨大瓶颈:KV Cache(键值缓存)过大导致的内存带宽受限。
  • 在 LLM 生成单词时,我们需要把每一层已经计算过的 K 和 V 存下来,这就是 KV Cache。
    • MHA(多头)的问题:Q, K, V 的头数一样多。如果模型很大,KV Cache 会占用海量的显存,导致 Batch Size 开不大,推理速度慢。
    • MQA(多查询)的极端:让所有 Q 头共享 同一组 K 和 V。虽然省了空间,但模型表达能力下降太厉害,效果变差。
    • GQA(分组)的折中:把 Q 分成几组,每一组 Q 共享一对 KV。
  • 核心功能:在不显著降低模型效果的前提下,大幅减少显存占用并提升推理速度。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class GroupQueryAttention(nn.Module):
'''
分组查询注意力 GQA
- Q 有 head_num 个头
- K, V 有 head_kv_num 个头 ( head_kv_num < head_num)
- 每几个 Q 共享一个 KV (计算过程中通过 repeat_kv 实现广播扩张用于计算)
'''
def __init__(self, model_dim, head_num, head_kv_num, dropout_p=0.1):
super().__init__()

assert model_dim % head_num == 0, "model_dim must be divisible by head_num"
assert head_num % head_kv_num == 0, "head_num must be divisible by head_kv_num"

self.model_dim = model_dim
self.head_num = head_num
self.head_kv_num = head_kv_num
self.head_dim = model_dim // head_num
self.repeat_num = head_num // head_kv_num

self.w_q = nn.Linear(model_dim, model_dim, bias=False)
self.w_k = nn.Linear(model_dim, head_kv_num * self.head_dim, bias=False)
self.w_v = nn.Linear(model_dim, head_kv_num * self.head_dim, bias=False)

self.dropout = nn.Dropout(dropout_p)

self.w_o = nn.Linear(model_dim, model_dim)

def repeat_kv(self, x, n_rep):
# 先扩展维度, 再合并回去
batch_size, head_kv_num, seq_len, head_dim = x.shape
if n_rep == 1:
return x

x = x[:, :, None, :, :]
x = x.expand(batch_size, head_kv_num, n_rep, seq_len, head_dim)
x = x.reshape(batch_size, head_kv_num * n_rep, seq_len, head_dim)
return x

def forward(self, x, mask=None):
batch_size, seq_len, model_dim = x.shape
# 输入映射
q = self.w_q(x)
k = self.w_k(x)
v = self.w_v(x)

# 分头处理
q = q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.head_kv_num, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.head_kv_num, self.head_dim).transpose(1, 2)

# 共享头扩展
k = self.repeat_kv(k, self.repeat_num)
v = self.repeat_kv(v, self.repeat_num)

# 缩放点积注意力计算
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)

attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)

context = torch.matmul(attn_weights, v)

# 拼接多头与输出映射
context = context.transpose(1, 2).contiguous()
context = context.view(batch_size, seq_len, model_dim)
output = self.w_o(context)

return output


if __name__ == "__main__":
batch_size, seq_len, model_dim = 2, 10, 256
head_num = 8
head_kv_num = 4

x = torch.randn(batch_size, seq_len, model_dim)
gqa = GroupQueryAttention(model_dim, head_num, head_kv_num)
out = gqa(x)

print("x.shape: ", x.shape)
print("out.shape: ", out.shape)
# x.shape: torch.Size([2, 10, 256])
# out.shape: torch.Size([2, 10, 256])

1.4 Multi-Latent Attention

多头潜在注意力

  • 痛点:KV Cache 的矛盾
    • MHA:性能好,但 KV 缓存太大(显存炸了)。
    • GQA:缓存小了,但多少还是牺牲了一些模型性能。
  • MLA 的黑科技功能:
    • 低秩压缩(Compression):通过 kv_down_proj 把高维信息压缩到一个很小的 latent_dim。在推理生成时,我们只需要缓存这个极小的压缩向量,而不是庞大的原始 KV。
    • 解耦 RoPE(Decoupled RoPE):这是 MLA 最天才的设计。传统 RoPE 旋转编码会破坏矩阵分解的特性,导致无法压缩。MLA 把 KV 拆成了 “内容(Content)” 和 “位置(RoPE)” 两部分:
      • 内容部分:被狠狠压缩,节省空间。
      • 位置部分:不压缩,保证模型对顺序的敏感度。
    • 计算时还原(Up-Projection):虽然缓存的是压缩包,但在计算的那一瞬间,通过 kv_up_proj 把它瞬间解压还原出来。
  • 核心逻辑:
    1. KV 的“漏斗形”变换
    • 压缩:从 model_dim 压到很小的 latent_dim
    • 解压:从 latent_dim 还原出多个头的信息
      在推理(Inference)阶段,模型只需要把 kv_latent 存进显存。这比存下所有头的 KV 要小得多。
    1. 三路分离(Content, RoPE, Value)
    • k_content & v_content:这些是从压缩包里解压出来的,用于表达“这是什么词”。
    • k_rope:这是专门用来加旋转位置编码的。
    1. 混合计算
    • 最后计算注意力得分时,模型把带有位置信息的向量和内容向量拼在一起。这样既保证了计算效率,又保证了位置感知的精度。
  • 应用在哪里:目前仅见于 DeepSeek 系列模型(以及效仿它的自研模型)。
  • 完成的工作:
    • 大幅降低推理成本:KV 缓存大小降至 MHA 的几十分之一。
    • 吞吐量翻倍:因为显存省了,可以在同一张显卡上跑更多的并发(Batch Size 大幅提升)。
    • 超长文本:支持极长的上下文,因为 KV Cache 不再是长文本的瓶颈。
  1. 既然计算时还是要通过 up_proj 还原出全维度的 KV,那为什么能省显存?
  • 显存消耗主要在存储,而不在那一瞬间的计算
    • 在生成任务中,模型需要把每一层、每一个 Token 的 KV 永久保留。
    • MLA 把原本需要存 512 维的东西,变成了只存 128 维(举例)。虽然计算那一刻要变回 512 维,但计算完就扔掉了,不需要长久占用显存。
  • 这就像:
    • MHA 是把所有快递都拆开,平铺在仓库里。
    • MLA 是把快递全部真空压缩叠放。只有当你要看某个快递的那一秒,你才给它充气,看完立马放气存回去。
  1. 为什么 Q 也需要压缩(q_down_proj)?KV 压缩是为了省显存(KV Cache),但 Q 又不需要缓存,为什么要多此一举压一遍再解压呢?
  • 不是为了省显存,而是为了节省计算量(训练和推理时的算力)
    • 低秩特性:如果 KV 是低秩的(即可以用很小的维度表达核心信息),那么 Q 理论上也可以是低秩的。
    • 计算对齐:通过让 Q 也经过一个“压缩-解压”的过程,模型可以学习在一个更紧凑的“潜在特征空间”里进行匹配。
    • 减少参数量:如果直接从 model_dim (如 512) 映射到 num_heads * head_dim (如 8 * 128 = 1024),参数量很大。通过先压到 128 再弹回去,中间的参数量会显著减少。
  1. K 和 V 下采样到共享同一个潜在向量(Latent Vector),上采样才映射到各自向量?
  • 是的。K 和 V 在存储(Cache)阶段是合二为一的(共享 kv_latent),只有在计算注意力的一瞬间,才通过 up_proj 临时变回各自的样子。
    • K 是为了告诉 Query:“我是什么标签”。
    • V 是为了告诉 Query:“我携带什么内容”。
  • 虽然它们功能不同,但它们都源自同一个输入 xx。DeepSeek 团队认为,既然它们都来自同一个 xx,那么它们的信息一定是高度冗余的。既然高度冗余,我干脆把 xx 压缩成一个全能的“潜在压缩包”(kv_latent)。这个压缩包既包含了做 Key 的潜力,也包含了做 Value 的信息。
  1. 最后输出 model_dim 可以完全不等于 num_heads * head_dim 是吗?
  • 是的,但在工程实践中,我们为了“效率”和“对齐”,绝大多数模型都强行让它们相等。
  • 为什么主流架构(如 GPT, Llama, BERT)都要求相等?
    • 在经典的 Transformer 论文中,有一个核心设计原则:残差连接(Residual Connection)。
    • 因此,这是维度对齐的硬性要求和计算逻辑的对齐。
  1. q_rope, k_rope = self.rope(q_rope, k_rope)是进行什么操作,会变换维度吗?
  • 简单来说:RoPE 是一种“给向量注入位置感”的数学变换,它不会改变维度,但会改变向量指向的方向。
  • 传统的编码(如 BERT)是直接把位置向量加到词向量上。而 RoPE 是让向量旋转。
    • 想象每个 head_dim 里的元素成对组成一个二维平面上的点:
    • 第 1 个词:向量旋转 θ\theta 角度。
    • 第 2 个词:向量旋转 2θ2\theta 角度。
    • nn 个词:向量旋转 nθn\theta 角度。
  • 为什么要旋转?
    • 因为旋转有一个神奇的数学特性:两个向量点积的结果,只取决于它们之间的相对角度。
    • 这意味着,当 Query 和 Key 进行点积时,模型能自动感知到它们之间相隔了多少个词(相对位置),而不仅仅是它们各自在什么位置(绝对位置)。
  • RoPE 是一种**逐元素(Element-wise)**的变换。它在保持向量长度(模长)不变的情况下,通过正弦和余弦函数改变了分量的数值。
  1. 旋转位置编码(RoPE)是“加”还是“拼”?
  • 传统做法(如 Llama/GPT):它们不分 content 和 rope。它们是直接对整个 head_dim 进行旋转。你可以理解为把整个特征向量丢进一个旋转矩阵里“搅匀”了。
  • MLA 的做法:拼接(Concatenate)。
    • q_content:纯粹的特征,不带位置信息。
    • q_rope:纯粹的位置,不带语义特征(或者是极简特征)。
    • torch.cat:把它们像火车车厢一样接在一起。
  • 为什么要“拼”在后面,而不是“加”或“全旋”?
    1. 保护“低秩压缩”的线性特性(最核心原因)
    • MLA 为了省显存,把 KV 压缩到了一个很小的 latent_dim。数学冲突:RoPE 旋转是一个非线性的三角函数变换。如果你对整个向量进行旋转,那么压缩矩阵 WdownW_{down} 和解压矩阵 WupW_{up} 就无法再通过简单的矩阵分解来还原信息了。
    • 我只压缩 content 部分。因为这部分不旋转,它保持了纯粹的线性,可以完美地被压缩和解压。而 rope 部分我不压缩,让它独立存在,保证位置信息的绝对精准。
    1. 解耦:上帝的归上帝,凯撒的归凯撒
    • 通过拼接,MLA 实现了**语义(Content)与位置(Location)**的彻底解耦:
      • q_content @ k_content:计算的是“这两个词的意思匹配吗?”
      • q_rope @ k_rope:计算的是“这两个词的相对距离合适吗?”
    • 最终 Score:是这两者的求和(因为向量拼接后的点积,等于各部分点积之和)。
  1. 拼接后的点积数学原理为什么 cat 之后做点积能起作用?
  • 看这个数学等式:
    假设 q=[qc,qr]q = [q_c, q_r]k=[kc,kr]k = [k_c, k_r]cc 为内容,rr 为位置):

    qkT=(qckcT)+(qrkrT)q \cdot k^T = (q_c \cdot k_c^T) + (q_r \cdot k_r^T)

    这意味着,模型在计算注意力分数时,实际上是在同时考虑两个独立的维度:

    • 意思对不对? (Content Match)
    • 位置对不对? (Position Match)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# 导入旋转位置编码直接从另一个文件导入
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from position.RotaryEmbedding import RotaryEmbedding


class MultiLatentAttention(nn.Module):
'''
多头潜在注意力
通过低秩投影将 Q 和 KV 压缩到潜在空间,减少 KV 缓存的显存占用
同时结合 RoPE 位置编码保持位置感知能力
'''
def __init__(self, model_dim, head_num, head_dim, latent_dim, rope_dim, dropout_p=0.1):
super().__init__()

assert model_dim % head_num == 0, "model_dim must be divisible by head_num"

self.model_dim = model_dim
self.head_num = head_num
self.head_dim = head_dim
self.latent_dim = latent_dim
self.rope_dim = rope_dim

# KV 投影: 下采样+上采样
self.kv_down_proj = nn.Linear(model_dim, latent_dim, bias=False)
self.kv_up_proj = nn.Linear(latent_dim, head_num * (head_dim + rope_dim + head_dim), bias=False)

# Q 投影: 下采样+上采样
self.q_down_proj = nn.Linear(model_dim, latent_dim, bias=False)
self.q_up_proj = nn.Linear(latent_dim, head_num * (head_dim + rope_dim), bias=False)

# RoPE 旋转位置编码
self.rope = RotaryEmbedding(head_dim=rope_dim)

# 输出映射
self.o_proj = nn.Linear(head_num * head_dim, model_dim, bias=False)
self.dropout = nn.Dropout(dropout_p)

def forward(self, x, mask=None):
batch_size, seq_len, model_dim = x.shape

# KV 映射与分离
kv_latent = self.kv_down_proj(x) # [batch_size, seq_len, latent_dim]
kv_full = self.kv_up_proj(kv_latent) # [batch_size, seq_len, head_num * (head_dim + rope_dim + head_dim)]
kv_full = kv_full.view(batch_size, seq_len, self.head_num, -1)

k_content, k_rope, v_content = torch.split(kv_full, [self.head_dim, self.rope_dim, self.head_dim], dim=-1)

# Q 映射与分离
q_latent = self.q_down_proj(x)
q_full = self.q_up_proj(q_latent)
q_full = q_full.view(batch_size, seq_len, self.head_num, -1)

q_content, q_rope = torch.split(q_full, [self.head_dim, self.rope_dim], dim=-1)

# 应用RoPE, 合并内容
q_rope, k_rope = self.rope(q_rope, k_rope)
q = torch.cat([q_content, q_rope], dim=-1) # [batch_size, seq_len, head_num, head_dim + rope_dim]
k = torch.cat([k_content, k_rope], dim=-1)

# 分头
q.transpose_(1, 2) # [batch_size, head_num, seq_len, head_dim + rope_dim]
k.transpose_(1, 2) # [batch_size, head_num, seq_len, head_dim + rope_dim]
v = v_content.transpose(1, 2) # [batch_size, head_num, seq_len, head_dim]

# 缩放点积注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim + self.rope_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))

attn_weights = F.softmax(scores, dim=-1) # [batch_size, head_num, seq_len, seq_len]
attn_weights = self.dropout(attn_weights)

context = torch.matmul(attn_weights, v) # [batch_size, head_num, seq_len, head_dim]

# 合并头, 映射输出
output = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.head_num * self.head_dim)
output = self.o_proj(output) # [batch_size, seq_len, num_heads * head_dim] -> [batch_size, seq_len, model_dim]

return output


if __name__ == "__main__":
batch_size, seq_len, model_dim = 2, 10, 512
head_num, head_dim = 8, 64
latent_dim = 128 # 【核心】KV 压缩后的维度 (远小于 num_heads * head_dim)
rope_dim = 32

x = torch.randn(batch_size, seq_len, model_dim)
mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

mla = MultiLatentAttention(model_dim, head_num, head_dim, latent_dim, rope_dim)

out = mla(x, mask=mask)

print(f"x.shape: {x.shape}")
print(f"out.shape: {out.shape}")
# x.shape: torch.Size([2, 10, 512])
# out.shape: torch.Size([2, 10, 512])

# KV 缓存节省
mha_mem = batch_size * (2 * model_dim * head_num * head_dim)
mla_mem = batch_size * (1 * model_dim * (latent_dim + rope_dim))
reduction_ratio = mha_mem / mla_mem
# 相当于 2 * head_num * head_dim / 1 * (latent_dim + rope_dim)
print(f"MHA占用: {mha_mem:,}")
print(f"MLA占用: {mla_mem:,}")
print(f"压缩倍率: {reduction_ratio:.1f}")
# MHA占用: 1,048,576
# MLA占用: 163,840
# 压缩倍率: 6.4

二、归一化层 Normalization

  • BatchNorm (BN):是在一个 Batch 之间算均值。
  • LayerNorm (LN):是在每个样本内部(特征维度)算均值。

2.1 LayerNorm

层归一化

  • LayerNorm 的作用就是**“控场”**。它确保神经元的输出不会因为层数太深而变得忽大忽小,维持整个系统的稳定性。
  • 在深度网络中,数据经过每一层都会发生偏移和缩放。LayerNorm 的存在是为了让每一层输出的特征分布都回到一个“标准状态”(均值为 0,方差为 1),防止梯度爆炸或梯度消失。
  • 核心功能可以概括为两步:
    1. 标准化(Standardization):把这一层所有神经元的输出拿出来,算出平均值(Mean)和标准差(Std),然后把大家强行拉回到同一个起跑线上。
    2. 重构(Re-scaling & Shifting):利用代码中的 gamma 和 beta。因为单纯的标准化可能会破坏模型已经学到的有用特征,所以我们给模型两个可学习的参数,让它自己决定:“如果全归一化太死板了,我可以稍微偏移一点点。”
  • 标准化(0, 1):是为了生存(不让梯度爆炸,让训练能跑通)。
  • 仿射变换(β,γ2\beta, \gamma^2):是为了进化(让模型学到更复杂的特征关系)。
  1. 未进行仿射变换(只有标准化)时,还未应用 gammagamma(缩放)和 betabeta(平移),LayerNorm 的输出就是纯粹的标准化结果:
  • 均值 (μ\mu) 为 0:所有 model_dimmodel\_dim 个特征值加起来除以 model_dimmodel\_dim 等于 0。
  • 方差 (σ2\sigma^2) 为 1:这些特征值的离散程度被缩放到单位 1。
    此时的状态:这种状态被称为**“白化(Whitening)”。它保证了无论输入数据的量级(Scale)如何(比如有的层输出数值在 100 左右,有的在 0.1 左右),进入下一层时,它们的分布都是统一的。这极大地解决了内部协变量偏移(Internal Covariate Shift)**问题,让学习率可以开得更大,收敛更快。
  1. 进行仿射变换后,均值和方差的变化一旦乘上 gammagamma 并加上 betabeta,输出的均值和方差就不再是 0 和 1 了。
  • 均值的变化:输出的均值会变成 betabeta
  • 方差的变化:输出的方差会变成 gamma2gamma^2
  • 数学推导简述:
    • 如果 xnormx_{norm} 的均值为 0,方差为 1,那么对于 y=xnormγ+βy = x_{norm} \cdot \gamma + \beta
    • E[y]=E[xnorm]γ+β=0γ+β=βE[y] = E[x_{norm}] \cdot \gamma + \beta = 0 \cdot \gamma + \beta = \beta
    • Var[y]=Var[xnorm]γ2=1γ2=γ2Var[y] = Var[x_{norm}] \cdot \gamma^2 = 1 \cdot \gamma^2 = \gamma^2
  1. 好不容易归一化成了 0 和 1,为什么又要用 gammagammabetabeta 把它破坏掉?
    这是深度学习中一个极其精妙的设计哲学:“保底”与“自适应”。
  • 恢复模型的表达能力(Identity Mapping)
    • 如果模型发现“纯归一化”反而让效果变差了(例如某些激活函数在 0 附近是线性的,失去了非线性表达力),那么模型可以通过学习,让 gamma=1,beta=0gamma=1, beta=0,从而回到原始分布。仿射变换给了模型一个**“后悔药”**,保证 LayerNorm 层的加入最差也不会让模型变笨。
  • 调整激活函数的激活区间
    • 很多激活函数(如 Sigmoid 或 Tanh)在 0 附近最敏感(梯度最大),但在远处会饱和(梯度消失)。
    • 通过 betabeta,模型可以把特征移动到激活函数最灵敏的区域。
    • 通过 gammagamma,模型可以控制特征进入非线性区的程度。
  • 特征重要性的重分配
    • 在一个词向量的 512 维中,并不是每一维都同样重要。
    • gammagamma 可以增大某些重要维度的权重,缩小噪音维度的权重。
    • 这本质上是给模型提供了一层逐通道(Per-channel)的缩放控制。
  1. 为什么要这么变?
  • 如果不归一化:如果下一层是一个 ReLU 激活函数,10.0 这个值可能太大了,导致网络对微小的变化不敏感;或者如果是 Sigmoid,10.0 会直接让梯度变成 0(饱和)。
  • 标准化后:数据回到了激活函数最敏感的“黄金地带”(0 附近)。
  • 仿射变换后:模型通过 γ\gammaβ\beta 告诉网络:“虽然大家都要在 0 附近,但我觉得第二维信息比较重要,我要把它拉长一点;第三维信息有点吵,我把它压低一点。”
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
'''
层归一化
在每个样本的特征维度进行归一化, 不依赖 batch 统计量, 适用于变长序列

LayerNorm(x) = (x - mean) / sqrt(var + eps) * gamma + beta
'''
def __init__(self, model_dim, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(model_dim)) # 可学习的缩放参数
self.beta = nn.Parameter(torch.ones(model_dim)) # 可学习的偏移参数

def forward(self, x):
# x: [batch_size, seq_len, model_dim]
# 计算均值、方差
mean = x.mean(-1, keepdim=True) # [batch_size, seq_len, 1]
var = x.var(-1, keepdim=True, unbiased=False) # [batch_size, seq_len, 1]

x_norm = (x - mean) / torch.sqrt(var + self.eps) # 归一化
x_o = x_norm * self.gamma + self.beta # 应用可学习的仿射变换

return x_o


if __name__ == "__main__":
batch_size, seq_len, model_dim = 2, 2, 6 # 每个词向量只有 6 个特征

x = torch.randn(batch_size, seq_len, model_dim)
layernorm = LayerNorm(model_dim)
out = layernorm(x)

print("x:\n", x)
print("out:\n", out)

# x:
# tensor([[[ 1.2764, -0.1716, 0.2268, -0.4756, -1.7076, -0.6254],
# [ 0.7817, -0.4869, 0.8618, -0.5067, 0.0978, -0.3469]],
#
# [[-1.6539, -1.1373, -0.4134, -1.0421, 0.1656, 2.2103],
# [ 1.7681, 0.8299, -0.9012, -0.3458, -0.2297, -1.5390]]])
# out:
# tensor([[[ 2.6883, 1.0827, 1.5245, 0.7456, -0.6205, 0.5795],
# [ 2.2537, 0.0289, 2.3943, -0.0058, 1.0544, 0.2744]],
#
# [[-0.0602, 0.3479, 0.9198, 0.4231, 1.3771, 2.9923],
# [ 2.6848, 1.8246, 0.2376, 0.7467, 0.8533, -0.3471]]],
# grad_fn=<AddBackward0>)

2.2 RMSNorm

均方根层归一化

  • RMSNorm 的公式极其简单:y=xmean(x2)+ϵγy = \frac{x}{\sqrt{mean(x^2)+\epsilon}} \cdot \gamma
  • 它去掉了 LayerNorm 中的两个关键部分:
    1. 均值中心化(Re-centering):不再减去 μ。
    2. 偏置项(Additive Bias):不再加上 β。
  • 为什么敢这么删?
    研究发现,LayerNorm 的成功主要来自于**重缩放(Re-scaling)**带来的不变性,而不是减去均值带来的平移不变性。删掉均值计算后,不仅计算量减少了约 10%-40%,而且实验证明模型效果几乎没有下降,甚至在某些大规模训练中更加稳定。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
'''
RMS层归一化
LayerNorm 的简化变体,去掉了均值中心化步骤,只用 RMS 归一化
计算量更小, 效果并不减少

RMSNorm(x) = x / RMS(x) * gamma
其中 RMS(x) = sqrt(mean(x^2) + eps)
'''
def __init__(self, model_dim, eps=1e-8):
super().__init__()

self.eps = eps
self.gamma = nn.Parameter(torch.ones(model_dim)) # 可学习的缩放参数

def _norm(self, x):
mean_square = x.float().pow(2).mean(-1, keepdim=True)
rsqrt = torch.rsqrt(mean_square + self.eps)
return x.float() * rsqrt

def forward(self, x):
x_norm = self._norm(x)
return x_norm.type_as(x) * self.gamma


if __name__ == "__main__":
batch_size, seq_len, model_dim = 2, 2, 6 # 每个词向量只有 6 个特征

x = torch.randn(batch_size, seq_len, model_dim)
rmsnorm = RMSNorm(model_dim)
out = rmsnorm(x)

print("x:\n", x)
print("out:\n", out)
# x:
# tensor([[[ 0.0966, 0.9523, 1.3856, -0.5899, -0.0432, -0.8740],
# [-0.1539, 2.0977, -1.1415, -0.1885, 0.8163, 0.4271]],
#
# [[-1.6803, -1.3997, 0.1455, 0.0500, -0.0594, 0.5897],
# [-0.1727, -0.9613, -0.1581, -1.0434, -0.2535, 0.0951]]])
# out:
# tensor([[[ 0.1190, 1.1737, 1.7078, -0.7270, -0.0533, -1.0772],
# [-0.1466, 1.9984, -1.0874, -0.1796, 0.7776, 0.4069]],
#
# [[-1.8124, -1.5097, 0.1570, 0.0539, -0.0641, 0.6361],
# [-0.2891, -1.6093, -0.2646, -1.7467, -0.4244, 0.1592]]],
# grad_fn=<MulBackward0>)

三、前馈网络

3.1 FFN

  • FFN的功能
    • 知识存储器:研究表明,Transformer 中的大部分“知识”(比如:巴黎是法国的首都)其实是存储在 FFN 的权重里的,而 Attention 更多是负责逻辑路由。
    • 非线性变换:Attention 本质上是加权求和(线性操作)。如果没有 FFN 里的 ReLU 激活函数,整个模型无论叠加多少层,在数学上都只是一个巨大的线性矩阵。FFN 赋予了模型处理复杂逻辑的非线性能力。
    • 维度跳跃(升维再降维):
      • 上投影(WupW_{up}):把特征从 512 维拉伸到 2048 维。在一个更高维的空间里,特征更容易被分开和处理。
      • 下投影(WdownW_{down}):处理完后,再压回 512 维,以便进行残差连接。
  • 时代眼泪 ReLU 的局限性:ReLU 在 x<0x < 0 时输出全为 0,这会导致所谓的“神经元死亡”现象(Dead ReLU)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
import torch.nn.functional as F

class FFN(nn.Module):
'''
前馈神经网络
FFN(x) = relu(x @ w1 + b1) @ w2 + b2
'''
def __init__(self, model_dim, intermediate_dim):
super().__init__()

self.model_dim = model_dim
self.intermediate_dim = intermediate_dim # intermediate_dim 通常是 model_dim 的 4 倍

self.w_up = nn.Linear(model_dim, intermediate_dim)
self.w_down = nn.Linear(intermediate_dim, model_dim)

def forward(self, x):
up = self.w_up(x)
down = F.relu(self.w_down(up))
return down


if __name__ == "__main__":
batch_size, seq_len, model_dim = 2, 3, 6
intermediate_dim = 24

x = torch.randn(batch_size, seq_len, model_dim)
ffn = FFN(model_dim, intermediate_dim)
out = ffn(x)

print(x)
print(out)
# tensor([[[ 1.2860, -0.1088, -0.3646, -1.5348, 0.3474, -0.1418],
# [-0.6025, -0.2117, 1.1263, -1.1982, 0.2312, 1.3653],
# [-0.1740, 1.9831, 1.2118, -0.4681, -0.2985, 1.6933]],
#
# [[ 0.3352, 1.4951, 0.5365, -1.3445, 1.5750, 0.7772],
# [ 2.5872, -0.0512, 0.0086, 0.0346, -1.2601, -2.1188],
# [-2.3334, -2.2904, -0.0671, -0.4802, 0.1850, 1.8944]]])
# tensor([[[0.0617, 0.2117, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.0000, 0.1303, 0.0000, 0.0000, 0.0000, 0.0000]],
#
# [[0.0000, 0.2630, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.2756, 0.0591, 0.3597, 0.3379, 0.0000, 0.1075],
# [0.0000, 0.0000, 0.0000, 0.0000, 0.4051, 0.0000]]],
# grad_fn=<ReluBackward0>)

3.2 SwiGLU

SwiGLU 前馈神经网络

  • 对比刚才的标准 FFN,你会发现 SwiGLU 最直观的变化是:中间层多了一个线性分支。
    • 标准 FFN (ReLU):xLinearReLULinearx \rightarrow \text{Linear} \rightarrow \text{ReLU} \rightarrow \text{Linear}
    • SwiGLU:它把中间层拆成了两路:
      • Gate 路:经过 w_gate 再加上 SiLU 激活函数。
      • Up 路:经过 w_up,但不加激活函数。
      • 合体:两路数据对应位置相乘(Element-wise Product)。
  • 门控机制(Gating Mechanism)
    “GLU”代表 Gated Linear Unit。你可以把 gate 分支想象成一个过滤器或开关:
    • SiLU(w_gate(x)) 算出每个特征的“重要程度”(0 到 1 之间的值)。
    • 用这个程度去乘以 w_up(x) 提取的原始特征。
    • 意义:模型可以根据当前上下文,动态地决定哪些信息需要加强,哪些信息需要抑制。ReLU 只有“开”或“关”,而 SwiGLU 是“精细化调节”。
  • SiLU (Swish) 的平滑性
    代码里用的 F.silu(也就是 xsigmoid(x)x \cdot \text{sigmoid}(x)):
    • 它比 ReLU 更平滑。在 x=0x=0 附近没有生硬的折点,这有助于梯度在深层网络中更稳定地流动。
    • 它在负数区域有一点点“下潜”(小负数),这能保留一些微弱的负向信号,增加模型的容错率。
  • 一个有趣的细节:“intermediate_dim 通常是 model_dim 的 8/3 倍”。
    这可不是随便写的数字,而是一个**“等效替换”**的数学题:
    • 标准 FFN:有 2 个线性层(Wup,WdownW_{up}, W_{down})。
    • SwiGLU:有 3 个线性层(Wgate,Wup,WdownW_{gate}, W_{up}, W_{down})。
      为了让 SwiGLU 的总参数量和标准 FFN 差不多,研究者通常把 intermediate_dim 缩小一点点。
    • 标准 FFN 是 4×model_dim4 \times \text{model\_dim}
    • SwiGLU 通常设为 23×42.67×model_dim\frac{2}{3} \times 4 \approx 2.67 \times \text{model\_dim}(即 8/3 倍)。
  • 关于 Sigmoid 和 SwiGLU 的区别:
    • Sigmoid:$$\text{Sigmoid}(x) = \frac{1}{1 + e^{-x}}$$
      • 它是一个“压制函数”,不管输入多大,输出永远在 (0, 1) 之间。
      • 局限性:当 xx 是负数时,它变成 0;当 xx 是正数时,它趋近 1。它像是一个开关,要么开要么关。
      • 问题:在深度网络中,Sigmoid 在 xx 很大或很小时梯度几乎为 0,这会导致“梯度消失”。
    • SiLU (Swish):$$\text{SiLU}(x) = x \cdot \text{Sigmoid}(x) = \frac{x}{1 + e^{-x}}$$它在 Sigmoid 的基础上乘以了输入本身 xx
      • 优势:当 xx 为正数时,它几乎是线性的(x1xx \cdot 1 \approx x)。当 xx 为负数时,它会有一个轻微的负值下潜,然后再回到 0。
      • 意义:这个小小的“下潜”允许模型在负数区域保持一定的梯度,且因为它具有非单调性(先下后上),它能捕捉到比 ReLU 和 Sigmoid 更复杂的特征。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFNN(nn.Module):
'''
SwiGLU 前馈神经网络
SwiGLUFNN(x) = Down( silu(Gate(x)) * Up(x) )
SiwGLU 可以更好地捕捉复杂的非线性关系
'''
def __init__(self, model_dim, intermediate_dim):
super().__init__()

self.model_dim = model_dim
self.intermediate_dim = intermediate_dim # intermediate_dim 通常是 model_dim 的 8/3 倍

self.w_gate = nn.Linear(model_dim, intermediate_dim)
self.w_up = nn.Linear(model_dim, intermediate_dim)
self.w_down = nn.Linear(intermediate_dim, model_dim)

def forward(self, x):
gate = F.silu(self.w_gate(x))
up = self.w_up(x)
down = self.w_down(gate * up) # 注意,我们这里应该用 * 进行逐元素相乘,不是 @ 矩阵乘法

return down

if __name__ == "__main__":
batch_size, seq_len, model_dim = 2, 3, 6
intermediate_dim = 8

x = torch.randn(batch_size, seq_len, model_dim)
ffn = SwiGLUFNN(model_dim, intermediate_dim)
out = ffn(x)

print(x)
print(out)
# tensor([[[-0.7441, -0.4776, -0.7285, -1.7437, 0.9644, -0.8784],
# [-1.5262, -0.4709, 0.0326, 0.6303, -0.8658, 1.1058],
# [-1.1282, -0.1326, -0.3696, 0.0927, 0.3892, 1.0497]],
#
# [[ 1.3472, 0.7196, 1.0456, 0.6882, 0.3920, -0.3255],
# [ 0.1005, -0.0137, -1.6010, -1.0572, 0.7091, -1.7054],
# [-0.8979, -0.7946, 0.7308, -2.0546, -0.2698, -0.0163]]])
# tensor([[[-0.1308, 0.1688, 0.0388, -0.1201, -0.2233, -0.0011],
# [-0.2041, 0.2709, -0.0328, 0.0195, 0.0214, 0.0874],
# [-0.1009, 0.3554, -0.0249, -0.0675, 0.0372, 0.0095]],
#
# [[-0.2625, 0.2757, -0.1333, 0.1463, 0.1986, 0.1686],
# [-0.0979, 0.1598, 0.0733, -0.1395, -0.2471, 0.0206],
# [-0.0568, 0.2394, 0.0434, -0.1560, -0.0331, -0.0535]]],
# grad_fn=<ViewBackward0>)

3.3 Mixture of Experts

混合专家模型

  • MoE (Mixture of Experts) 的本质是:让模型在参数量上变成“巨无霸”,但在计算量上保持“小清新”。
  • 在标准 Transformer 中,每一层 FFN 都是“全员上阵”。如果你想让模型更聪明,你就得增大 FFN 的维度,但这会让推理速度变慢。
  • MoE 的思维跳跃:
    • 我准备 100 个专家(FFN),每个专家擅长不同的知识(比如:专家 A 懂代码,专家 B 懂法语)。
    • 当一个 Token 进来时(比如“Bonjour”),Router(路由器) 发现这是法语,就只激活专家 B。
    • 结果:模型总参数量是 100 倍,但每次计算只耗费 1 个专家的算力。
  • 这里有一个隐藏的“大坑”:专家负载均衡
    • 在这段纯净的 MoE 代码中,隐藏着一个工业界极其头疼的问题:专家贫富差距。
    • 现象:Router 可能会发现某一个专家特别好用,导致所有的 Token 都往它那里跑(比如专家 1 处理了 99% 的 Token),而其他专家都在“带薪休假”。
    • 后果:
      • 专家 1 所在的显卡会爆显存。
      • 其他专家没有得到训练,模型退化成了单专家模型。
    • 解决方案:在实际的 DeepSeek 或 Llama 代码中,会加入一个 Auxiliary Loss(辅助损失),强迫 Router 把任务均匀地分配给所有专家。
  1. 为什么要“遍历专家”而不是“遍历 Token”?
  • 为什么代码写 for i, expert in enumerate(self.experts) 而不是直接循环每一个词?
  • GPU 效率: 如果你有 100 万个 Token,循环 100 万次会慢死。
  • 批处理 (Batching): 既然 Token A 和 Token C 都选了专家 0,我们把它们拼在一起一次性传给专家 0 的 Linear 层计算,这能利用 GPU 的并行能力。
  1. 函数声明 index_add_(dim, index, source)
  • 含义:带索引的累加。
  • 说明:这是一个“原地(In-place)”操作(函数名末尾的下划线 _ 表示会直接修改原张量)。它把 source 中的值,按照 index 指定的索引,累加到 self(即调用它的张量)中。
  • 关键点:如果多个索引指向同一个位置,它会将所有源值累加在一起。
  • 在 MoE 中的作用:因为很多 Token 可能都选择了同一个专家,它们计算出的结果需要“归位”到最终输出中。index_add_ 完美解决了多对一的映射与累加。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoE(nn.Module):
'''
Mixture of Experts 混合专家模型
一种稀疏激活的神经网络架构,通过路由机制将输入分配给不同的专家网络
只有部分专家参与计算,从而在增加模型容量的同时保持计算效率
1. Router: 计算每个 token 对专家的得分
2. TopK 选择: 选择得分最高的 K 个专家
3. 专家计算: 每个专家对选择自己的 token 进行计算处理
4. 加权融合: 根据路由得分加权融合专家输出
'''
def __init__(self, model_dim, expert_num, top_k):
super().__init__()

self.model_dim = model_dim
self.expert_num = expert_num
self.top_k = top_k

# 路由计算: 每个 token 对专家的喜好
self.router = nn.Linear(model_dim, expert_num, bias=False)

# 专家网络列表
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(model_dim, model_dim * 4),
nn.ReLU(),
nn.Linear(model_dim * 4, model_dim)
) for _ in range(expert_num)
])

def forward(self, x):
batch_size, seq_len, model_dim = x.shape # [batch_size, seq_len, model_dim]
# 展平方便处理
x_flat = x.view(-1, model_dim) # [batch_size * seq_len, model_dim]

# 计算专家路由选择情况
gate_logits = self.router(x_flat) # [batch_size * seq_len, model_dim]
weights, indices = torch.topk(gate_logits, self.top_k, dim=-1) # [batch_size * seq_len, topk]
weights = F.softmax(weights, dim=-1) # [batch_size * seq_len, topk]

# 初始化输出
output = torch.zeros_like(x_flat) # [batch_size * seq_len, model_dim]
for i, expert in enumerate(self.experts):
mask = (indices == i)
token_indices, topk_pos = torch.where(mask)

if token_indices.numel() > 0: # number of element
# 这个专家对选择自己的 token 计算输出
expert_input = x_flat[token_indices] # [N, model_dim]
expert_ouput = expert(expert_input) # [N, model_dim]

# 加权融合
expert_weight = weights[token_indices, topk_pos] # [N]
weighted_expert_ouput = expert_ouput * expert_weight.unsqueeze(-1) # [N, model_dim] = [N, model_dim] * [N, 1]

# 累加到答案板上
output.index_add_(dim=0, index=token_indices, source=weighted_expert_ouput)

# 最后记得恢复形状
output = output.view(batch_size, seq_len, model_dim)
return output


if __name__ == "__main__":
batch_size, seq_len, model_dim = 2, 2, 6
expert_num, top_k = 4, 2

x = torch.randn(batch_size, seq_len, model_dim)
moe = MoE(model_dim, expert_num, top_k)
out = moe(x)

print(x.shape, "\n", x)
print(out.shape, "\n", out)
# torch.Size([2, 2, 6])
# tensor([[[ 0.0503, -0.0135, -0.0667, -0.3115, 0.0631, -0.2993],
# [-1.3180, 0.1936, -0.1778, -1.0257, 1.6202, 0.3969]],
#
# [[ 0.4592, 0.6902, 0.0411, 0.0283, -1.5601, -0.1610],
# [ 0.9260, 0.9828, 1.4012, -1.1552, 0.6786, 1.3436]]])
# torch.Size([2, 2, 6])
# tensor([[[ 0.0353, -0.0668, -0.1204, -0.0040, -0.1293, -0.0500],
# [ 0.1011, 0.1161, 0.0174, 0.1119, -0.1103, 0.0209]],
#
# [[ 0.1264, 0.2935, 0.0548, 0.0006, -0.3465, -0.0703],
# [ 0.1220, 0.3189, -0.2566, 0.1833, 0.2553, -0.0788]]],
# grad_fn=<ViewBackward0>)

四、损失函数

4.1 EntropyLoss

熵损失

  1. Softmax 函数
  • 标准公式:
    将一组得分(Logits)转化为概率分布,所有分量之和为 1。

Softmax(xi)=exij=1nexjSoftmax(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}

  • 数值稳定公式(代码实现版):
    M=max(x)M = \max(x),通过平移避免 exe^x 溢出。

Softmax(xi)=exiMj=1nexjMSoftmax(x_i) = \frac{e^{x_i - M}}{\sum_{j=1}^{n} e^{x_j - M}}

为什么要减去 max_logits?(数值稳定性的真相)

  • 这是面试中出镜率极高的问题。
  • 理论公式:Softmax(xi)=exiexjSoftmax(x_i) = \frac{e^{x_i}}{\sum e^{x_j}}
  • 现实危机:计算机表达浮点数是有极限的。如果 logits 中有一个数是 88.788.7,那么 e88.7e^{88.7} 就会超出单精度浮点数(Float32)的最大范围,变成 inf(无穷大)。一旦出现 inf,整个模型就崩了。
  • 解决方案:利用指数函数的性质:exiexj=exiCexjC\frac{e^{x_i}}{e^{\sum x_j}} = \frac{e^{x_i - C}}{\sum e^{x_j - C}}。通过减去最大值 CC,所有的指数项都会变成 e负数或0e^{\text{负数或0}},结果永远在 (0,1](0, 1] 之间。再大的数也被“驯服”了。
  1. Log Softmax 函数
  • 标准公式:
    Softmax 结果取自然对数。

LogSoftmax(xi)=ln(exij=1nexj)LogSoftmax(x_i) = \ln \left( \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} \right)

  • 数值稳定公式(利用 ln\ln 的性质展开):
    这种写法称为 Log-Sum-Exp 技巧。

LogSoftmax(xi)=xiln(j=1nexj)LogSoftmax(x_i) = x_i - \ln \left( \sum_{j=1}^{n} e^{x_j} \right)

代码中进一步优化为:xi(M+lnexjM)x_i - (M + \ln \sum e^{x_j - M})

Log Softmax:为什么要把它拆出来?

  • 既然有了 softmax 函数,为什么不直接写 torch.log(softmax(x))
  • 原因:因为 softmax 会把某些非常小的概率压成 00(由于精度限制)。如果你对 00 取 log,结果是 -\infty
  • Log-Sum-Exp 技巧:$$\log(\frac{e^{x_i}}{\sum e^{x_j}}) = x_i - \log(\sum e^{x_j})$$
  • 这种写法直接在对数空间操作,避免了先求 exe^x 再求 log\log 的精度损耗。这是大模型训练(如 Llama, DeepSeek)中交叉熵损失的标准写法。
  1. 交叉熵损失 (Cross Entropy Loss)
  • 标准公式:
    衡量真实分布 yy(通常是 one-hot 向量)与预测分布 pp 之间的差异。

L=i=1nyilog(pi)L = -\sum_{i=1}^{n} y_i \log(p_i)

  • 单标签分类简化公式(代码实现版):
    由于 yy 中只有一个位置是 1(正确类别 targettarget),其余都是 0,公式简化为:

Loss=log(ptarget)Loss = -\log(p_{target})

结合 Log Softmax,它直接等于:

Loss=(xtargetlnexj)Loss = -(x_{target} - \ln \sum e^{x_j})

  1. KL 散度 (Kullback-Leibler Divergence)
  • 标准公式:
    衡量概率分布 PP 相对 QQ 的偏离程度。

DKL(PQ)=i=1nP(xi)ln(P(xi)Q(xi))D_{KL}(P \parallel Q) = \sum_{i=1}^{n} P(x_i) \ln \left( \frac{P(x_i)}{Q(x_i)} \right)

  • 计算变形公式(代码实现版):
    为了计算方便,通常拆分为两个 Log Softmax 的差:

DKL(PQ)=i=1nP(xi)[lnP(xi)lnQ(xi)]D_{KL}(P \parallel Q) = \sum_{i=1}^{n} P(x_i) \cdot \left[ \ln P(x_i) - \ln Q(x_i) \right]

函数 输入 输出 关键操作
Softmax 任意实数 (0,1) 概率 归一化,和为 1
LogSoftmax 任意实数 (−∞,0] 解决概率太小时的精度丢失
CrossEntropy Logits + 标签 正数标量 只盯着“正确答案”的概率
KL Divergence 两个分布 正数标量 衡量两个分布有多“像”
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
'''
熵相关损失函数
'''
import torch

def softmax(logits):
'''
softmax(xi) = exp(xi-M) / sum(exp(xj-M))
输入: [n, class_num] 未归一化的得分
输出: [n, class_num] 归一化的概率分布
'''
M, _ = torch.max(logits, dim=-1, keepdim=True) # [batch_size, 1]
exp_shifted = torch.exp(logits - M) # [batch_size, class_num]
sum_exp_shifted = torch.sum(exp_shifted, dim=-1, keepdim=True) # [batch_size, 1]
out = exp_shifted / sum_exp_shifted # [batch_size, class_num]

return out


def log_softmax(logits):
'''
log_softmax(xi) = log[ exp(xi-M) / sum(exp(xi-M)) ]
= xi - M - log(sum(exp(xi-M)))
'''
M, _ = torch.max(logits, dim=-1, keepdim=True) # [batch_size, 1]
exp_shifted = torch.exp(logits - M) # [batch_size, class_num]
sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True) # [batch_size, 1]
log_sum_exp = torch.log(sum_exp) # [batch_size, 1]
out = logits - M - log_sum_exp # [batch_size, class_num]

return out


def cross_entropy_loss(logits, targets):
'''
交叉熵损失
输入: logits [batch_size, class_num] 未对数归一化的模型输出得分
targets [batch_size] 真实类别标签 [0, class_num)

最终公式: Loss = - (x_target - log(sum(exp(xj))))
'''
# 直接把所有的 log_softmax 算出来,选 target 的加和即可
log_prob = log_softmax(logits)
batch_indices = torch.arange(logits.shape[0])
loss = - log_prob[batch_indices, targets]

return loss.mean()


def KL_divergence(p_logits, q_logits):
'''
KL 散度
计算两个概率分布 P, Q 之间的差异(P 相对 Q 的偏离程度)
最终公式: D_KL(P || Q) = sum( P(xi) * ( log P(xi) - log Q(xi) ) )
'''
p_ls = log_softmax(p_logits) # [batch_size, class_num]
q_ls = log_softmax(q_logits) # [batch_size, class_num]
p_s = softmax(p_logits) # [batch_size, class_num]

kl_div = torch.sum(p_s * (p_ls - q_ls), dim=-1) # [batch_size], 是每一组 PQ 的偏离程度

return kl_div.mean() # 返回这一批数据的平均情况, 代表这个模型的性能情况

if __name__ == "__main__":
torch.manual_seed(2026)
# 1. softmax / log_softmax 数值稳定性
# 特意加入一个很大的值 (100.0) 来测试是否会溢出产生 nan
logits = torch.tensor([[10.0, 5.0, 100.0],
[1.0, 2.0, 3.0]])

prob = softmax(logits)
log_prob = log_softmax(logits)
official_prob = torch.nn.functional.softmax(logits, dim=-1)
official_log_prob = torch.nn.functional.log_softmax(logits, dim=-1)

print(prob)
print(official_prob)
print(log_prob)
print(official_log_prob)
# tensor([[8.1940e-40, 5.5211e-42, 1.0000e+00],
# [9.0031e-02, 2.4473e-01, 6.6524e-01]])
# tensor([[8.1940e-40, 5.5211e-42, 1.0000e+00],
# [9.0031e-02, 2.4473e-01, 6.6524e-01]])
# tensor([[-90.0000, -95.0000, 0.0000],
# [ -2.4076, -1.4076, -0.4076]])
# tensor([[-90.0000, -95.0000, 0.0000],
# [ -2.4076, -1.4076, -0.4076]])

# 2. CE 交叉熵损失
batch_size, class_num = 2, 3
logits = torch.randn(batch_size, class_num)
targets = torch.tensor([1, 2]) # 第一个样本真类是1,第二个是2

my_ce = cross_entropy_loss(logits, targets)
official_ce = torch.nn.functional.cross_entropy(logits, targets)

print(f"{my_ce.item():.6f}")
print(f"{official_ce.item():.6f}")
# 0.876983
# 0.876983
assert torch.allclose(my_ce, official_ce), "CE Loss 不匹配!"

# 3. KL 散度 (模拟两个分布,p_logits 是模型 A,q_logits 是模型 B)
batch_size, class_num = 2, 3
p_logits = torch.randn(batch_size, class_num)
q_logits = torch.randn(batch_size, class_num)

my_kl = KL_divergence(p_logits, q_logits)
p_s = torch.nn.functional.softmax(p_logits, dim=-1) # 官方 F.kl_div 默认不包含 log_softmax 步骤
q_ls = torch.nn.functional.log_softmax(q_logits, dim=-1)
official_kl = torch.nn.functional.kl_div(q_ls, p_s, reduction='batchmean')

# 注意:根据 P, Q 定义顺序不同,正负号或数值可能有微差,但逻辑一致即 OK
print(f"{my_kl.item():.6f}")
print(f"{official_kl.item():.6f}")
# 0.557274
# 0.557274

以上部分已经完成学习,以下部分对于推荐算法重要性不是特别大,以后学习,暂时把内容先放在这里。

4.2 SFT Loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
监督微调损失(Supervised Fine-Tuning Loss)

用于大语言模型监督微调的损失函数。
与预训练损失类似,但支持屏蔽 prompt 部分,只计算 response 的损失。
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class SFTLoss(nn.Module):
"""
监督微调损失模块

在 SFT 阶段,我们通常只希望计算 response 部分的损失,
而不计算 prompt 部分的损失。该模块支持通过 prompt_lengths 来屏蔽 prompt。

Args:

"""

def __init__(self):
super().__init__()

def forward(self, logits, labels, prompt_lengths):
"""
前向传播

Args:
logits: 模型输出的未归一化对数概率 [batch_size, seq_len, vocab_size]
labels: 真实词元索引 [batch_size, seq_len]
prompt_lengths: 每个样本的 prompt 长度 [batch_size]

Returns:
loss: 标量损失值
"""
# 步骤1: 构造 masked labels
# 将 prompt 部分的标签设为 -100(ignore_index)
masked_labels = labels.clone()
for batch_idx, prompt_length in enumerate(prompt_lengths):
masked_labels[batch_idx, :prompt_length] = -100 # 设置为 ignore_index

# 步骤2: 移位操作(Shift)
# 预测下一个词:logits 去掉最后一个,labels 去掉第一个
# shifted_logits: [batch_size, seq_len-1, vocab_size]
shifted_logits = logits[:, :-1, :].contiguous()

# shifted_labels: [batch_size, seq_len-1]
shifted_labels = masked_labels[:, 1:].contiguous()

# 步骤3: 展平张量
batch_size, seq_length, vocab_size = shifted_logits.size()

# flattened_logits: [batch_size * (seq_len-1), vocab_size]
flattened_logits = shifted_logits.view(-1, vocab_size)

# flattened_labels: [batch_size * (seq_len-1)]
flattened_labels = shifted_labels.view(-1)

# 步骤4: 计算交叉熵损失
# ignore_index=-100 的位置(prompt 部分)不参与损失计算
loss = F.cross_entropy(flattened_logits, flattened_labels, ignore_index=-100)

return loss

4.3 DPO Loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
直接偏好优化损失(Direct Preference Optimization Loss)

DPO 是一种不使用奖励模型的 RLHF 替代方案。
通过直接优化人类偏好数据来对齐语言模型,避免了训练奖励模型的复杂性。

参考论文: Direct Preference Optimization: Your Language Model is Secretly a Reward Model
"""

import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
ref_chosen_logps, ref_rejected_logps,
beta=0.1, label_smoothing=0.0):
"""
计算 DPO 损失

DPO 的核心思想是:将奖励函数参数化为策略和参考策略的对数比率,
然后直接在偏好数据上优化这个隐式奖励。

公式: L_DPO = -E[log sigmoid(β * (log(π_θ(y_w|x) / π_ref(y_w|x)) - log(π_θ(y_l|x) / π_ref(y_l|x))))]

Args:
policy_chosen_logps: 策略模型对 chosen 回复的对数概率 [batch_size]
policy_rejected_logps: 策略模型对 rejected 回复的对数概率 [batch_size]
ref_chosen_logps: 参考模型对 chosen 回复的对数概率 [batch_size]
ref_rejected_logps: 参考模型对 rejected 回复的对数概率 [batch_size]
beta: KL 散度的缩放因子,控制对参考模型的偏离程度,默认 0.1
label_smoothing: 标签平滑系数,默认 0.0

Returns:
loss: DPO 损失值(标量)
"""
# 步骤1: 计算对数比率(隐式奖励)
# chosen_ratio = log(π_θ(y_w|x) / π_ref(y_w|x))
chosen_ratio = policy_chosen_logps - ref_chosen_logps

# rejected_ratio = log(π_θ(y_l|x) / π_ref(y_l|x))
rejected_ratio = policy_rejected_logps - ref_rejected_logps

# 步骤2: 计算 DPO logits(chosen 和 rejected 的奖励差)
# logits: [batch_size]
logits = chosen_ratio - rejected_ratio

# 步骤3: 计算 DPO 损失
# dpo_loss = -log(sigmoid(β * logits))
dpo_loss = -F.logsigmoid(beta * logits).mean()

# 步骤4: 应用标签平滑(可选)
if label_smoothing > 0.0:
# 标签平滑版本:混合正向和反向损失
inverse_loss = -F.logsigmoid(-beta * logits).mean()
dpo_loss = (1 - label_smoothing) * dpo_loss + label_smoothing * inverse_loss

return dpo_loss

4.4 PPO Loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
近端策略优化损失(Proximal Policy Optimization Loss)

PPO 是一种基于置信域的策略梯度算法,通过截断重要性采样比率
来限制策略更新的幅度,保证训练的稳定性。

参考论文: Proximal Policy Optimization Algorithms
"""

import torch
import numpy as np
import matplotlib.pyplot as plt


def ppo_clip_loss(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
"""
计算 PPO 截断损失

PPO 通过截断重要性采样比率来限制策略更新的幅度:
L_CLIP = E[min(r_t * A_t, clip(r_t, 1-ε, 1+ε) * A_t)]

其中 r_t = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) 是重要性采样比率。

Args:
old_log_probs: 旧策略的对数概率 [batch_size]
new_log_probs: 新策略的对数概率 [batch_size]
advantages: 优势估计值 [batch_size]
clip_epsilon: 截断参数 ε,默认 0.2

Returns:
loss: PPO 截断损失值(标量)
"""
# 步骤1: 计算重要性采样比率
# r_t = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) = exp(log π_θ - log π_θ_old)
ratio = torch.exp(new_log_probs - old_log_probs)

# 步骤2: 计算截断后的比率
# 将比率限制在 [1 - ε, 1 + ε] 范围内
clipped_ratio = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon)

# 步骤3: 计算代理损失
# surrogate1: 未截断的目标 r_t * A_t
surrogate1 = ratio * advantages

# surrogate2: 截断后的目标 clip(r_t) * A_t
surrogate2 = clipped_ratio * advantages

# 步骤4: 取两者中的较小值(保守更新)
# 取最小值使得:当 A > 0 时,限制奖励增长;当 A < 0 时,限制惩罚增长
loss = -torch.mean(torch.min(surrogate1, surrogate2))

return loss


def plot_ppo_clip():
"""
绘制 PPO 截断函数的可视化图表

展示 PPO 在不同优势值下如何限制策略更新的幅度。
"""
# 设定 r 的范围 (0 到 2)
r = np.linspace(0, 2, 200)
epsilon = 0.2

# 创建画布
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# --- 情况 1: Advantage > 0 (这是一个好动作) ---
A_pos = 1.0
# 1. 未截断的收益 (r * A)
obj_unclipped_pos = r * A_pos
# 2. 截断的收益 (clip(r) * A)
obj_clipped_pos = np.clip(r, 1 - epsilon, 1 + epsilon) * A_pos
# 3. PPO 最终收益 (取最小值 min)
obj_ppo_pos = np.minimum(obj_unclipped_pos, obj_clipped_pos)

# 绘图
ax1.plot(r, obj_unclipped_pos, 'g--', label='Unclipped (r*A)', alpha=0.5)
ax1.plot(r, obj_clipped_pos, 'b--', label='Clipped (clip*A)', alpha=0.5)
ax1.plot(r, obj_ppo_pos, 'r-', linewidth=3, label='PPO Reward (Min)')

# 标注区域
ax1.set_title(f'Case 1: Advantage > 0 (Good Action)\nLimit Reward for large change')
ax1.axvline(x=1 + epsilon, color='k', linestyle=':', label='1+epsilon')
ax1.set_xlabel('Probability Ratio r_t')
ax1.set_ylabel('Reward L')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.text(1.3, 1.05, 'Gradient = 0\n(Stop Updating)', color='red', fontweight='bold')

# --- 情况 2: Advantage < 0 (这是一个坏动作) ---
A_neg = -1.0
# 1. 未截断的收益
obj_unclipped_neg = r * A_neg
# 2. 截断的收益
obj_clipped_neg = np.clip(r, 1 - epsilon, 1 + epsilon) * A_neg
# 3. PPO 最终收益 (取最小值 min)
# 注意:因为 A 是负的,min 会发挥"悲观"作用
obj_ppo_neg = np.minimum(obj_unclipped_neg, obj_clipped_neg)

# 绘图
ax2.plot(r, obj_unclipped_neg, 'g--', label='Unclipped (r*A)', alpha=0.5)
ax2.plot(r, obj_clipped_neg, 'b--', label='Clipped (clip*A)', alpha=0.5)
ax2.plot(r, obj_ppo_neg, 'r-', linewidth=3, label='PPO Reward (Min)')

# 标注区域
ax2.set_title(f'Case 2: Advantage < 0 (Bad Action)\nLimit Penalty for large change')
ax2.axvline(x=1 - epsilon, color='k', linestyle=':', label='1-epsilon')
ax2.set_xlabel('Probability Ratio r_t')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.text(0.1, -0.7, 'Gradient = 0\n(Stop Updating)', color='red', fontweight='bold')

plt.tight_layout()
plt.show()


# 运行绘图
plot_ppo_clip()

4.5 GRPO Loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
组相对策略优化损失(Group Relative Policy Optimization Loss)

GRPO 是 DeepSeek 提出的一种 RLHF 算法,是 PPO 的简化变体。
主要特点是去掉了 Critic 网络,使用组内归一化来计算优势函数。

参考论文: DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models
"""

import torch


def compute_grpo_advantages(rewards):
"""
计算 GRPO 中的优势函数

GRPO 去掉了 Critic 网络,使用组内归一化代替传统的优势估计。
对于每个问题生成一组回答,然后计算组内的相对优势。

公式: A = (R - mean(R)) / (std(R) + eps)

Args:
rewards: 每个回答的奖励值 [batch_size, group_size]

Returns:
advantages: 归一化后的优势值 [batch_size, group_size]
"""
# 计算组内均值
# mean: [batch_size, 1]
mean = rewards.mean(dim=-1, keepdim=True)

# 计算组内标准差
# std: [batch_size, 1]
std = rewards.std(dim=-1, keepdim=True)

# 归一化得到优势值
# advantages: [batch_size, group_size]
advantages = (rewards - mean) / (std + 1e-8)

return advantages


def grpo_loss(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2, beta=0.01, ref_kl=None):
"""
计算 GRPO 损失

GRPO 的损失函数与 PPO 类似,但通常包含显式的 KL 散度惩罚项。
公式: L = E[min(r_t * A_t, clip(r_t) * A_t)] + β * KL(π || π_ref)

Args:
old_log_probs: 旧策略的对数概率 [batch_size]
new_log_probs: 新策略的对数概率 [batch_size]
advantages: 优势估计值 [batch_size]
clip_epsilon: 截断参数 ε,默认 0.2
beta: KL 散度惩罚系数,默认 0.01
ref_kl: 当前策略与参考策略的 KL 散度(可选)[batch_size]

Returns:
loss: GRPO 损失值(标量)
"""
# 步骤1: 计算重要性采样比率
# r_t = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
ratio = torch.exp(new_log_probs - old_log_probs)

# 步骤2: 计算截断后的比率
# 将比率限制在 [1 - ε, 1 + ε] 范围内
clipped_ratio = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon)

# 步骤3: 计算代理损失
# surrogate1: 未截断的目标 r_t * A_t
surrogate1 = ratio * advantages

# surrogate2: 截断后的目标 clip(r_t) * A_t
surrogate2 = clipped_ratio * advantages

# 取两者中的较小值
policy_loss = -torch.min(surrogate1, surrogate2)

# 步骤4: GRPO 特有的 KL 正则项(DeepSeek 做法)
# DPO 把 KL 藏在 Loss 里,GRPO 通常显式地加一个 KL 惩罚
# loss = policy_loss + β * KL(π || π_ref)
if ref_kl is not None:
return (policy_loss + beta * ref_kl).mean()

return policy_loss.mean()


def compute_kl_penalty(log_probs, ref_log_probs):
"""
计算 KL 散度惩罚项

使用 Schulman 估计器计算 KL 散度,该方法具有较低的方差。

公式: KL(P || Q) ≈ E_P[exp(log Q - log P) - (log Q - log P) - 1]

推导:
KL(P || Q) = E_P[log P(x) - log Q(x)]
= E_P[exp(log Q(x) - log P(x)) - (log Q(x) - log P(x)) - 1]
(通过 Taylor 展开得到无偏估计)

Args:
log_probs: 当前策略的对数概率 [batch_size]
ref_log_probs: 参考策略的对数概率 [batch_size]

Returns:
kl: KL 散度的估计值(标量)
"""
# Schulman 估计器
# ratio = exp(log Q - log P) = Q / P
ratio = torch.exp(ref_log_probs - log_probs)

# KL = E_P[ratio - log_ratio - 1]
# 其中 log_ratio = log Q - log P
kl = ratio - (ref_log_probs - log_probs) - 1

return kl.mean()

五、其他

5.1 位置编码 Rotary Position Embedding (RoPE)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
旋转位置编码(Rotary Position Embedding, RoPE)

一种将位置信息编码到注意力机制中的方法,通过旋转向量来表示相对位置。
具有良好的外推性和位置感知能力。

参考论文: RoFormer: Enhanced Transformer with Rotary Position Embedding
"""

import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
"""
旋转位置编码模块

通过将查询和键向量按照位置进行旋转,使注意力分数包含相对位置信息。
公式: f(x, m) = x * cos(m*θ) + rotate_half(x) * sin(m*θ)

Args:
head_dim: 旋转编码的维度(通常是 attention head 的维度)
max_seq_len: 支持的最大序列长度
theta: 旋转角度的基数,默认 10000.0
"""

def __init__(self, head_dim, max_seq_len=2048, theta=10000.0):
super().__init__()
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.theta = theta

# 预计算 cos 和 sin 值,避免重复计算
cos, sin = self.precompute_freqs(head_dim, max_seq_len, theta)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)

def precompute_freqs(self, head_dim, max_seq_len, theta):
"""
预计算旋转位置编码的频率

Args:
head_dim: 编码维度
max_seq_len: 最大序列长度
theta: 旋转基数

Returns:
cos: 余弦值 [max_seq_len, head_dim]
sin: 正弦值 [max_seq_len, head_dim]
"""
# 计算逆频率: 1 / (theta^(2i/d))
# inv_freqs: [head_dim/2]
inv_freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

# 位置索引
# t: [max_seq_len]
t = torch.arange(max_seq_len, device=inv_freqs.device, dtype=torch.float32)

# 计算角度矩阵: outer(t, inv_freqs)
# angles: [max_seq_len, head_dim/2]
angles = torch.outer(t, inv_freqs)

# 将角度复制一份以匹配完整维度
# angles: [max_seq_len, head_dim]
angles = torch.cat((angles, angles), dim=-1)

return angles.cos(), angles.sin()

def forward(self, xq, xk):
"""
对查询和键应用旋转位置编码

Args:
xq: 查询张量 [batch_size, seq_len, num_heads, head_dim]
xk: 键张量 [batch_size, seq_len, num_heads, head_dim]

Returns:
xq_rotated: 旋转后的查询 [batch_size, seq_len, num_heads, head_dim]
xk_rotated: 旋转后的键 [batch_size, seq_len, num_heads, head_dim]
"""
seq_len = xq.size(1)

# 获取当前位置的 cos 和 sin 值
# cos, sin: [1, seq_len, 1, head_dim]
cos = self.cos[:seq_len].view(1, seq_len, 1, self.head_dim)
sin = self.sin[:seq_len].view(1, seq_len, 1, self.head_dim)

def rotate_half(x):
"""
将张量分成两半并旋转

Args:
x: 输入张量 [..., head_dim]

Returns:
旋转后的张量 [..., head_dim]
"""
# x1, x2: [..., head_dim/2]
x1, x2 = torch.chunk(x, 2, dim=-1)
# 拼接为 [-x2, x1]: [..., head_dim]
return torch.cat((-x2, x1), dim=-1)

# 应用旋转位置编码
# 公式: x * cos + rotate_half(x) * sin
#
# 详细推导:
# 设 x = [x1, x2],则 rotate_half(x) = [-x2, x1]
# x * cos = [x1*cos, x2*cos]
# rotate_half(x) * sin = [-x2*sin, x1*sin]
# 结果 = [x1*cos - x2*sin, x2*cos + x1*sin]
#
# 等价于旋转矩阵: [cos, sin; -sin, cos] @ [x1; x2] = [x1'; x2']

xq_rotated = (xq * cos) + (rotate_half(xq) * sin)
xk_rotated = (xk * cos) + (rotate_half(xk) * sin)

return xq_rotated, xk_rotated

5.2 参数高效微调 LoRA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
LoRA 线性层(Low-Rank Adaptation Linear)

一种高效的参数高效微调(PEFT)方法,通过低秩分解来近似权重更新。
只训练少量参数即可达到接近全量微调的效果。

参考论文: LoRA: Low-Rank Adaptation of Large Language Models
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class LoRALinear(nn.Module):
"""
LoRA 线性层模块

将权重更新分解为两个低秩矩阵的乘积: ΔW = B @ A
前向传播: y = W @ x + (B @ A) @ x * scaling

优势:
- 大幅减少可训练参数(rank << min(in_features, out_features))
- 原始权重冻结,支持即插即用
- 训练完成后可将 LoRA 权重合并到原始权重中

Args:
in_features: 输入特征维度
out_features: 输出特征维度
rank: LoRA 的秩(低秩矩阵的维度),默认 8
alpha: LoRA 的缩放因子,默认 1.0
dropout: Dropout 概率,默认 0.0
"""

def __init__(self, in_features, out_features, rank=8, alpha=1.0, dropout=0.0):
super().__init__()

# 原始预训练权重(冻结)
self.weight = nn.Linear(in_features, out_features, bias=False)
self.weight.requires_grad = False # 冻结原始权重

# LoRA 低秩适应矩阵
# A: [in_features, rank] - 下投影
# B: [rank, out_features] - 上投影
self.lora_a = nn.Linear(in_features, rank, bias=False)
self.lora_b = nn.Linear(rank, out_features, bias=False)

self.alpha = alpha
self.rank = rank

# 缩放因子:alpha / rank
# 用于平衡 LoRA 输出和原始输出的比例
self.scaling = self.alpha / rank

self.dropout = nn.Dropout(dropout)

# 初始化 LoRA 权重
self.reset_parameters()

def reset_parameters(self):
"""
初始化 LoRA 权重

- A 使用 Kaiming 均匀初始化
- B 初始化为零,确保初始状态时 LoRA 输出为零
"""
# A 使用 Kaiming 均匀初始化
nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))

# B 初始化为零,确保初始输出等于预训练模型输出
nn.init.zeros_(self.lora_b.weight)

def forward(self, x):
"""
前向传播

Args:
x: 输入张量 [batch_size, seq_len, in_features]

Returns:
输出张量 [batch_size, seq_len, out_features]
"""
# 步骤1: 计算原始输出(不计算梯度)
# original_output: [batch_size, seq_len, out_features]
with torch.no_grad():
original_output = self.weight(x)

# 步骤2: 计算 LoRA 增量输出
# x -> dropout -> A -> B -> scaling
# lora_output: [batch_size, seq_len, out_features]
lora_output = self.lora_b(self.lora_a(self.dropout(x))) * self.scaling

# 步骤3: 合并原始输出和 LoRA 增量
return original_output + lora_output


# --- 测试代码 ---
if __name__ == "__main__":
x = torch.randn(2, 5, 10) # batch_size, seq_len, in_features
# 假设原模型 10 -> 20
layer = LoRALinear(10, 20, rank=4)

out = layer(x)
print(f"Output shape: {out.shape}")

# 验证初始状态 LoRA 是否为 0
# 理论上初始输出应该等于 pretrained 输出
diff = (out - layer.weight(x)).abs().sum()
print(f"Diff at init (should be 0): {diff.item()}")