Saving host memory with MixtralForCausalLM DeepSpeed Inference (latest approach)
This post demonstrates how to save host memory when running MixtralForCausalLM under DeepSpeed Inference.
Approach: each rank saves its own shard, and the model is built with accelerate's init_empty_weights.
Added features:
- chunked safetensors save and load
- a fix for parameter initialization of buffers registered with register_buffer(..., persistent=False)
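As a minimal sketch of this load path, a toy nn.Linear below stands in for the real model (the actual code uses accelerate.init_empty_weights and per-rank safetensors shards; here plain torch.save and the built-in meta device illustrate the same idea):

```python
import torch
import torch.nn as nn

# 1) Each rank saves only its own shard (toy stand-in for a TP slice).
model = nn.Linear(1024, 1024)
torch.save(model.state_dict(), "rank0.pt")

# 2) Build the module on the meta device: no host RAM is allocated
#    for the weights (this is what init_empty_weights achieves).
with torch.device("meta"):
    empty = nn.Linear(1024, 1024)

# 3) Materialize storage directly on the target device, then load the
#    shard straight onto it (map_location avoids an extra host-side copy).
device = "cuda" if torch.cuda.is_available() else "cpu"
empty.to_empty(device=device)
empty.load_state_dict(torch.load("rank0.pt", map_location=device))
```

With four ranks each loading only its own shard this way, peak host memory stays near zero instead of holding a full copy of the model per process.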
I. Results
Run mode | Host memory usage | Notes
---|---|---
Single-GPU inference | 13198 MB | |
DeepSpeed 4-way TP | 13246 MB/GPU | |
DeepSpeed 4-way TP, memory-optimized | 369 MB/GPU | weights load directly onto the device, saving host memory |
II. Notes
- 1. MixtralRotaryEmbedding calls self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False).
Because persistent is False, the buffer is not saved into the state_dict, and module.to_empty(device) does not preserve its value either.
The only option is to save it right after model initialization, then copy it back into the buffer after engine.module has finished loading its weights.
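A minimal repro of this buffer issue and the save/restore workaround (a toy module here, not the real MixtralRotaryEmbedding):

```python
import torch
import torch.nn as nn

class RotaryCache(nn.Module):
    def __init__(self):
        super().__init__()
        emb = torch.arange(8, dtype=torch.float32)
        # persistent=False: excluded from state_dict(), so checkpoints
        # do not carry this buffer and to_empty() discards its value.
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

m = RotaryCache()
assert "sin_cached" not in m.state_dict()

# 1) Snapshot non-persistent buffers right after initialization.
saved = {k: v.clone() for k, v in m.named_buffers()}

# 2) to_empty() reallocates uninitialized storage (what happens when the
#    module is materialized from the meta device before weight loading).
m.to_empty(device="cpu")

# 3) After the engine has loaded its weights, copy the buffers back in.
with torch.no_grad():
    for k, v in m.named_buffers():
        v.copy_(saved[k])
```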
III. Test steps
1. Create a Mixtral-8x7B config file (simplified)
mkdir skip_init_demo
cd skip_init_demo
tee ./config.json <<-'EOF'
{
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 1024,
  "model_type": "mixtral",
  "num_attention_heads": 32,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "num_local_experts": 8,
  "output_router_logits": false,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "router_aux_loss_coef": 0.02,
  "sliding_window": 128,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.36.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}
EOF
2. Generate a random model, then run CPU float32 inference and print the output
rm -rf Mixtral-8x7B
tee gen_model.py <<-'EOF'
import torch

def main():
    torch.manual_seed(1)
    from transformers import MixtralForCausalLM, MixtralConfig
    config = MixtralConfig.from_pretrained("./config.json")
    model = MixtralForCausalLM(config).half()
    model.eval()
    model.save_pretrained("./Mixtral-8x7B", safe_serialization=True)
    torch.manual_seed(2)
    input_tokens = torch.randint(0, 32000, (1, 128))
    model = model.float()
    output = model(input_tokens)
    output = output.logits.detach().reshape(-1).cpu().numpy()[:8]
    print(output)

if __name__ == "__main__":
    main()
EOF
python gen_model.py
du Mixtral-8x7B -lh
Output:
6.3G Mixtral-8x7B
[-0.9623295 -0.36580455 0.767425 1.7021806 -0.17950581 0.36059803
-0.49157432 -0.58618194]
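The 6.3G on disk is consistent with a back-of-the-envelope parameter count from the config above (my own arithmetic, assuming the standard Mixtral layout of GQA attention plus 8 SwiGLU experts per layer; not read from the checkpoint):

```python
# Rough fp16 size estimate from the config.json values.
hidden, inter, layers = 1024, 4096, 32
heads, kv_heads, experts, vocab = 32, 8, 8, 32000
head_dim = hidden // heads

attn = hidden * (heads * head_dim)           # q_proj
attn += 2 * hidden * (kv_heads * head_dim)   # k_proj, v_proj (GQA)
attn += (heads * head_dim) * hidden          # o_proj

moe = experts * 3 * hidden * inter + hidden * experts  # w1/w2/w3 + router

per_layer = attn + moe + 2 * hidden          # + two RMSNorm weights
total = layers * per_layer + 2 * vocab * hidden + hidden  # + embed, lm_head, final norm

print(f"{total/1e9:.2f}B params, {total*2/2**30:.1f} GiB at fp16")
# → 3.37B params, 6.3 GiB at fp16
```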
3. Load the model and run single-GPU CUDA inference