代码
import torch
import torch.nn as nn
class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DeepseekV3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
代码解释
类定义与初始化
class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DeepseekV3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
- 继承自
nn.Module
,是 PyTorch 的标准模块基类 - 初始化方法接收两个参数:
hidden_size
:隐藏层的大小,决定了权重参数的维度eps
:一个小常数,用于数值稳定性,默认为 1e-6
- 创建了一个可学习的权重参数
weight
,初始化为全 1 向量 - 保存 epsilon 值用于后续计算
前向传播方法
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
前向传播方法实现了 RMSNorm 的核心计算:
- 保存输入张量的原始数据类型
- 将输入转换为 float32 类型以确保计算精度
- 计算每个特征向量的均方值(沿最后一个维度):
pow(2)
对每个元素求平方mean(-1, keepdim=True)
沿最后一个维度计算均值,保持维度
- 使用
torch.rsqrt
计算均方根的倒数(1/sqrt(x)) - 将输入张量乘以这个倒数进行归一化
- 应用可学习的权重参数进行缩放
- 将结果转换回原始数据类型并返回
额外表示方法
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
这个方法用于自定义模块的字符串表示,当打印模型时会显示权重形状和 epsilon 值。
与 LayerNorm 的区别
RMSNorm 与标准 LayerNorm 的主要区别在于:
- LayerNorm 会减去均值并除以标准差
- RMSNorm 只除以均方根,不减去均值
这种设计使 RMSNorm 计算更高效,同时保持了良好的归一化效果,特别适合大型语言模型。
代码运行结果
import torch
import torch.nn as nn
class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DeepseekV3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
print(f"输入 hidden_states 的维度: {hidden_states.shape}")
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
print(f"转换为float32后的维度: {hidden_states.shape}")
variance = hidden_states.pow(2).mean(-1, keepdim=True)
print(f"variance的维度: {variance.shape}")
rsqrt_result = torch.rsqrt(variance + self.variance_epsilon)
print(f"rsqrt结果的维度: {rsqrt_result.shape}")
hidden_states = hidden_states * rsqrt_result
print(f"归一化后的维度: {hidden_states.shape}")
weighted_result = self.weight * hidden_states.to(input_dtype)
print(f"加权后的最终输出维度: {weighted_result.shape}")
return weighted_result
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
def main():
# 设置随机种子以确保结果可重现
torch.manual_seed(42)
# 创建一个RMSNorm实例,隐藏层大小为8
hidden_size = 8
rms_norm = DeepseekV3RMSNorm(hidden_size=hidden_size)
# 创建一个随机输入张量
# 形状为 [2, 4, 8],表示 batch_size=2, seq_length=4, hidden_size=8
x = torch.randn(2, 4, hidden_size)
# 应用RMSNorm
normalized_x = rms_norm(x)
main()
输入 hidden_states 的维度: torch.Size([2, 4, 8])
转换为float32后的维度: torch.Size([2, 4, 8])
variance的维度: torch.Size([2, 4, 1])
rsqrt结果的维度: torch.Size([2, 4, 1])
归一化后的维度: torch.Size([2, 4, 8])
加权后的最终输出维度: torch.Size([2, 4, 8])