Gemma模型论文详解(附源码)

原文链接:Gemma模型论文详解(附源码)

1. 背景介绍

Gemma模型是在2024.2.21号Google新发布的大语言模型, Gemma复用了Gemini相同的技术(Gemini也是Google发布的多模态模型),Gemma这次发布了了2B和7B两个版本的参数,不仅提供了预训练的checkpoints,还提供了用于对话、指令跟随等fine-tune的checkpoints。在QA问答、常识。在11

在这里插入图片描述

2. 模型介绍

2.1 模型结构

Gemma模型使用了transformer decoder结构进行训练,训练的上下文大小为8192个token,模型参数如下:
在这里插入图片描述

相比原始transformer结构的区别:

  • Multi-Query Attention:7B模型使用了multi-head attention,2B模型使用了multi-query attention (with 𝑛𝑢𝑚_𝑘𝑣_ℎ𝑒𝑎𝑑𝑠 = 1)。对比llama2中用了group-query attention
    在这里插入图片描述

  • RoPE Embeddings: 不使用绝对位置编码,在每一层前加下RoPE Embedding,同时共享输入与输出层的embedding权重。

  • GeGLU Activations: ReLU的激活替换为GeGLU的激活。对比llama中用了swiglu。

  • Normalizer Location: 在transformer的每一层layer的前后都进行规一化,这里使用RMSNorm做为规一化层。

2.2 训练搭建

Gemma使用TPUv5e进行训练;一个pod中有256块TPUv5e芯片,256块芯片被设计为16X16的2D拓扑;Gemma-7B使用16个pods(4096块卡)进行训练,Gemma-2B使用2个pods(512块卡)。7B模型在一个pod内使用16路模型并行和16路数据并行,2B模型在一个pod内使用256路数据并行。优化器状态使用ZeRO-3进行切分,减少显存占用。在pod外使用类似Pathways的方式减少数据复制的成本。

和Gemini模型训练一样,综合了Jax和Pathways的单控制器single controller编程范式,使用单个python进程编排整个训练; 使用GSPMD partitioner用于训练step的计算,使用XLA compiler减少中间结果的大小。

2.3 训练数据

Gemma 2B和7B分别基于2T和6T个token进行训练,token来源于纯英文的文本,内容包括网页、数学、代码等。使用SentencePiece的tokenizer,字典大小有256K个token。数据过滤使用基于模型的分类器去除有害的、低质量的内容。最后采用类似Gemini的方式进行训练数据的混合,提升高质量数据的占比。

2.4 指令微调(Instruction Tuning)

2B和7B进行有监督微调(SFT)训练中使用混合生成数据和人工标注的prompt文本对,同时进行RLHF训练。在SFT阶段,基于给定的一个prompt,通过测试模型生成多个响应的回答结果,通过一个更大更好的模型进行结果的好坏判断。基于不同的侧重方向(指令跟随/事实/创造性/安全等)构建不同的prompt。使用多种基于LM的自动判断方法,比如chain-of-thought prompting

训练和推理过程中使用相同的数据格式,格式的设计重点在于两点,一个是确定多轮对话中的角色,一个是确定一轮对话的开始结束。对应格式标记和示例的训练数据如下:

在这里插入图片描述
在这里插入图片描述

3. 源码

@dataclasses.dataclass
class GemmaConfig:
    # The number of tokens in the vocabulary.
    vocab_size
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MLTalks

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值