本文是LLM系列文章,针对《JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented
Fine
摘要
用于基于检索的任务的大型语言模型(LLM)的扩展,特别是在检索增强生成(RAG)中,面临着巨大的内存限制,尤其是在微调大量提示序列时。当前的开源库支持跨多个GPU的全模型推理和微调,但无法适应检索上下文所需的高效参数分布。为了弥补这一差距,我们引入了一种新的框架,利用分布式训练对Llama-2模型进行PEFT兼容的微调。我们的框架独特地利用了JAX的实时(JIT)编译和张量分片来实现高效的资源管理,从而加速了微调,降低了内存需求。这一进步显著提高了为复杂RAG应用程序微调LLM的可扩展性和可行性,即使在GPU资源有限的系统上也是如此。我们的实验表明,与使用四个GPU的Hugging Face/DeepSpeed实现相比,运行时间提高了12倍以上,而每个GPU消耗的VRAM不到一半。
1 引言
2 背景
3 JORA框架
4 实验
5 使用场景示例
6 结论
本文介绍了JORA,这是一个基于JAX的库,用于Llama-2模型的检索增强微调。JORA为数据操作和训练提供了方便的功能。此外,它还实现了内存效率和性能训练的最佳实践。通过使用LoRA、张量并行和jit的