# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
import abc
import os
import sys
import re
import json
from types import SimpleNamespace
import logging as logger
from pathlib import Path
from collections import OrderedDict
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType
from megatron.core import mpu
from megatron.training.arguments import validate_args
from megatron.legacy.model import module
from megatron.core.enums import ModelType
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_args
from megatron.training.checkpointing import load_checkpoint
from megatron.core import tensor_parallel
from mindspeed_llm.training.utils import parse_args
from mindspeed_llm.training import model_provider_func_wrapper
from mindspeed_llm.training.checkpointing import load_checkpoint_wrapper
logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)
load_checkpoint = load_checkpoint_wrapper(load_checkpoint)
class ModelBase(abc.ABC):
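    """Base class for converting checkpoints between model formats.

    Subclasses provide ``get_module_mapping`` and ``get_model_item``; each
    mapping entry is a dotted path into the model tree (an illustrative entry
    might map "embedding_word_embeddings" to something like
    "embedding.word_embeddings"), and matching get_/set_/has_ accessors are
    generated from it at construction time.
    """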
def __init__(self, args_cmd=None):
self.args_cmd = args_cmd
self.args = None
self.args_megatron_checkpoint = None
self.module = None
self.module_mapping = None
self.model_cfg = self.read_model_cfg(args_cmd)
if self.args_cmd.save_lora_to_hf:
self.lora_layer_mappings = self.read_model_cfg(self.args_cmd, True)
self.__register_functions()
self.kwargs_idx = OrderedDict({
"vp_rank": 0,
"ep_rank": 0,
"tp_rank": 0,
"layer_idx": 0,
"expert_idx": 0
})
def update_kwargs_idx(self, **kwargs):
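        """Refresh the cached rank/layer indices from ``kwargs``; keys that are
        not passed are reset to 0."""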
for key in self.kwargs_idx:
if key in kwargs:
self.kwargs_idx[key] = kwargs[key]
else:
self.kwargs_idx[key] = 0
def __register_functions(self):
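        """Dynamically attach get_/set_/has_ accessors for every module_mapping entry."""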
self.get_module_mapping()
def _get_obj(self, value, **kwargs):
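            # A mapping value is a dotted path such as "a.b[layer_idx].c" (illustrative);
            # split it into (attribute, optional index key) pairs and walk the module tree.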
pattern = r'(\w+)(?:\[(\w+)\])?'
matches = re.findall(pattern, value)
self.update_kwargs_idx(**kwargs)
obj = self.get_model_item(**kwargs)
for attr, attr_ident in matches:
if hasattr(obj, attr):
obj = getattr(obj, attr)
else:
return None
if attr_ident:
if attr_ident in self.kwargs_idx:
attr_idx = self.kwargs_idx[attr_ident]
obj = obj[attr_idx]
else:
                        raise AssertionError(f"unknown index key '{attr_ident}' in {self.__class__.__name__}.module_mapping.")
return obj
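        # The dst/src variants fall back to dst_layer_idx / src_layer_idx when the
        # caller does not pass an explicit layer_idx.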
def _get_dst_obj(self, value, **kwargs):
if kwargs.get("layer_idx") is None:
kwargs["layer_idx"] = kwargs.get("dst_layer_idx")
return _get_obj(self, value, **kwargs)
def _get_src_obj(self, value, **kwargs):
if kwargs.get("layer_idx") is None:
kwargs["layer_idx"] = kwargs.get("src_layer_idx")
return _get_obj(self, value, **kwargs)
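        # Each generator below closes over one mapping value and returns an accessor
        # that is bound to the instance in the setattr loop at the end of this method.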
def _func_generator_get_module(value):
def func(self, **kwargs):
return _get_src_obj(self, value, **kwargs)
return func
def _func_generator_get_weight(value):
def func(self, **kwargs):
return _get_src_obj(self, value, **kwargs).weight.data
return func
def _func_generator_get_bias(value):
def func(self, **kwargs):
return _get_src_obj(self, value, **kwargs).bias.data
return func
def _func_generator_set_weight(value):
def func(self, **kwargs):
return _get_dst_obj(self, value, **kwargs).weight.data.copy_(kwargs.get('data'))
return func
def _func_generator_set_module(value):
def func(self, **kwargs):
return _get_dst_obj(self, value, **kwargs).data.copy_(kwargs.get('data'))
return func
def _func_generator_set_bias(value):
def func(self, **kwargs):
return _get_dst_obj(self, value, **kwargs).bias.data.copy_(kwargs.get('data'))
return func
def _func_generator_has_module(value):
def func(self, **kwargs):
obj = _get_src_obj(self, value, **kwargs)
                return obj is not None
return func
def _func_generator_has_bias(value):
def func(self, **kwargs):
bias = getattr(_get_src_obj(self, value, **kwargs), 'bias', None)
return bias is not None
return func
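        # Bind one accessor family per mapping entry, e.g. the key
        # "embedding_word_embeddings" yields get_embedding_word_embeddings_weight(),
        # set_embedding_word_embeddings_weight(), has_embedding_word_embeddings_bias(), etc.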
if self.module_mapping:
for key, value in self.module_mapping.items():
setattr(self, "get_" + key + "_module", _func_generator_get_module(value).__get__(self, ModelBase))
setattr(self, "set_" + key + "_module", _func_generator_set_module(value).__get__(self, ModelBase))
setattr(self, "get_" + key + "_weight", _func_generator_get_weight(value).__get__(self, ModelBase))
setattr(self, "get_" + key + "_bias", _func_generator_get_bias(value).__get__(self, ModelBase))
setattr(self, "set_" + key + "_weight", _func_generator_set_weight(value).__get__(self, ModelBase))
setattr(self, "set_" + key + "_bias", _func_generator_set_bias(value).__get__(self, ModelBase))
setattr(self, "has_" + key + "_module", _func_generator_has_module(value).__get__(self, ModelBase))
setattr(self, "has_" + key + "_bias", _func_generator_has_bias(value).__get__(self, ModelBase))
def update_module(self, src_model):
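        """Copy parameters from ``src_model`` into this model, remapping layer
        indices when noop layers are configured."""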
if not self.args_cmd.save_lora_to_hf:
self.set_preprocess_state(src_model)
self.set_postprocess_state(src_model)
if not (hasattr(self.args, "noop_layers") and self.args.noop_layers):
for layer_idx in tqdm(range(self.args.num_layers), "set layer states"):
self.set_layer_state(src_model, layer_idx)
return
        # Checkpoint conversion when noop layers are configured.
        # Example: hf layers [0, 1] plus noop layers [1, 3] give
        # mg layers [0 -> hf 0, 1 -> noop, 2 -> hf 1, 3 -> noop].
        hf_num_layers = self.args.num_layers - len(self.args.noop_layers)
        mg_layer_list = list(range(hf_num_layers))
        for i in self.args.noop_layers:
            # Insert a -1 placeholder at every noop position.
            mg_layer_list.insert(i, -1)
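        # enumerate() pairs each megatron layer position with its hugging face
        # layer id; -1 marks a noop layer that has no hf counterpart.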
        for mg_layer_idx, hf_layer_idx in enumerate(mg_layer_list):
            if self.is_noop_layer(hf_layer_idx):
                continue
            if self.args_cmd.save_model_type == "hf":
                # mg -> hf: the megatron position is the source, the hf layer id the destination.
                self.set_layer_state_base(src_model, src_layer_idx=mg_layer_idx, dst_layer_idx=hf_layer_idx)
            else:
                # hf -> mg: the hf layer id is the source, the megatron position the destination.
                self.set_layer_state_base(src_model, src_layer_idx=hf_layer_idx, dst_layer_idx=mg_layer_idx)
def set_preprocess_state(self, src_model):
"""Set embedding params."""
embeddings_weight = src_model.get_embedding_word_embeddings_weight()
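        # Truncate the source word embeddings if they have more vocab rows than
        # the destination embedding.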
if embeddings_weight.size(0) > self.get_embedding_word_embeddings_weight().size(0):
logger.info(f"Source embedding size: {embeddings_weight.size()} "
f"Target embedding size: {self.get_embedding_word_embeddings_weight().size()}")
embeddings_weight = embeddings_weight[:self.get_embedding_word_embeddings_weight().size(0), :]
self.set_embedding_word_embeddings_weight(data=embeddings_weight)
if src_model.has_embedding_word_embeddings_norm_module():
embd_norm_weight = src_model.get_embedding_word_embeddings_norm_weight()
            embd_norm_bias = src_model.get_embedding_word_embeddings_norm_bias()