3 完整的Transformer
基于之前所实现过的组件,我们实现完整的 Transformer 模型:
3.1 Transformer 类详解
class Transformer(nn.Module):
'''整体模型'''
def __init__(self, args):
super().__init__()
# 必须输入词表大小和 block size
assert args.vocab_size is not None
assert args.block_size is not None
self.args = args
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(args.vocab_size, args.n_embd),
wpe = PositionalEncoding(args),
drop = nn.Dropout(args.dropout),
encoder = Encoder(args),
decoder = Decoder(args),
))
# 最后的线性层,输入是 n_embd,输出是词表大小
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
# 初始化所有的权重
self.apply(self._init_weights)
# 查看所有参数的数量
print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
'''统计所有参数的数量'''
def get_num_params(self, non_embedding=False):
# non_embedding: 是否统计 embedding 的参数
n_params = sum(p.numel() for p in self.parameters())
# 如果不统计 embedding 的参数,就减去
if non_embedding:
n_params -= self.transformer.wpe.weight.numel(