【极简】如何估算大模型inference所需的内存量

1字节=8bit
16float=2字节
模型后面的xxb的单位是字节。
1b 字节≈ 0.93G,这个是以8bit运行,4bit减半,16bit(float)加倍,32bit(double)炒鸡加倍。

剩下的是小头,需要参数计算:

  • s:最大序列长度(输入中的令牌数量)
  • b:批大小
  • h:模型的隐藏维度
  • a:注意头的数量

对于整个层
总内存需求总计为11sbh + 5as²b(来自注意力块)+ 19sbh(来自MLP块)+ 4sbh(来自LN)
每层激活内存消耗= 34 sbh + 5as²b

小头一般远小于10G。

所以比如llama7b,只需要7*0.93≈9G,再加10,内存19G就可以(实际会更少,因为小头远低于10G),注意这个是以8bit运行,4bit减半,16bit(float)加倍,32bit(double)炒鸡加倍。

感谢博客:https://developer.aliyun.com/article/1496103
感谢github:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值