# coding:utf-8
import os
import pickle
import random
import logging
import warnings
from argparse import ArgumentParser
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from transformers import (
    BertTokenizer,
    TrainingArguments,
    Trainer
)
from src.util.modeling.modeling_nezha.modeling import NeZhaConfig, NeZhaForMaskedLM
warnings.filterwarnings('ignore')
logging.basicConfig()
logger = logging.getLogger('')
logger.setLevel(logging.INFO)

def save_pickle(dic, save_path):
    with open(save_path, 'wb') as f:
        pickle.dump(dic, f)


def load_pickle(load_path):
    with open(load_path, 'rb') as f:
        message_dict = pickle.load(f)
    return message_dict

def seed_everything(seed):
    """Fix all random seeds (Python, NumPy, PyTorch, CUDA) for reproducible runs."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def read_data(args, tokenizer: BertTokenizer) -> dict:
    """Tokenize the raw pretraining corpus and cache the result as a pickle."""
    pretrain_df = pd.read_csv(args.pretrain_data_path, header=None, sep='\t')
    inputs = defaultdict(list)
    for _, row in tqdm(pretrain_df.iterrows(), desc='tokenizing pretrain data', total=len(pretrain_df)):
        sentence = row[0].strip()
        inputs_dict = tokenizer.encode_plus(sentence, add_special_tokens=True,
                                            return_token_type_ids=True,
                                            return_attention_mask=True)
        inputs['input_ids'].append(inputs_dict['input_ids'])
        inputs['token_type_ids'].append(inputs_dict['token_type_ids'])
        inputs['attention_mask'].append(inputs_dict['attention_mask'])
    os.makedirs(os.path.dirname(args.data_cache), exist_ok=True)
    save_pickle(inputs, args.data_cache)
    return inputs
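

# `load_pickle` and `args.data_cache` above imply a tokenize-once workflow:
# build the cache on the first run and reuse it afterwards. A minimal sketch
# of that wrapper (the helper name `load_data` is an assumption, not part of
# the original script):
def load_data(args, tokenizer: BertTokenizer) -> dict:
    # Reuse the cached tokenization if it exists, otherwise build it.
    if os.path.exists(args.data_cache):
        return load_pickle(args.data_cache)
    return read_data(args, tokenizer)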

class PretrainDataset(Dataset):
    def __init__(self, data_dict: dict):
        super().__init__()
        self.data_dict = data_dict

    def __getitem__(self, index: int) -> tuple:
        data = (self.data_dict['input_ids'][index],
                self.data_dict['token_type_ids'][index],
                self.data_dict['attention_mask'][index])
        return data

    def __len__(self) -> int:
        return len(self.data_dict['input_ids'])
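

# ---------------------------------------------------------------------------
# The original file is truncated here, before the training entry point. What
# follows is a minimal sketch of how the pieces above are usually wired
# together for MLM pretraining; the collate function, argument names, and
# hyperparameter values are illustrative assumptions, not recovered code.
# Note the masking below always replaces chosen tokens with [MASK], whereas
# HuggingFace's DataCollatorForLanguageModeling uses an 80/10/10
# mask/random/keep split.
# ---------------------------------------------------------------------------
def mlm_collate_fn(batch: List[tuple], tokenizer: BertTokenizer,
                   mlm_probability: float = 0.15) -> dict:
    # Pad each (input_ids, token_type_ids, attention_mask) tuple to the
    # longest sequence in the batch.
    max_len = max(len(ids) for ids, _, _ in batch)
    input_ids, token_type_ids, attention_mask = [], [], []
    for ids, types, mask in batch:
        pad_len = max_len - len(ids)
        input_ids.append(ids + [tokenizer.pad_token_id] * pad_len)
        token_type_ids.append(types + [0] * pad_len)
        attention_mask.append(mask + [0] * pad_len)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    labels = input_ids.clone()
    # Dynamic masking: sample mlm_probability of the non-special, non-padding
    # tokens on every pass over the data.
    probability_matrix = torch.full(input_ids.shape, mlm_probability)
    special_tokens_mask = torch.tensor(
        [tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
         for ids in input_ids.tolist()],
        dtype=torch.bool)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # loss is only computed on masked tokens
    input_ids[masked_indices] = tokenizer.mask_token_id
    return {
        'input_ids': input_ids,
        'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
        'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
        'labels': labels,
    }


def main():
    parser = ArgumentParser()
    parser.add_argument('--pretrain_data_path', type=str, required=True)
    parser.add_argument('--data_cache', type=str, default='data/cache/pretrain.pkl')
    parser.add_argument('--model_path', type=str, required=True,
                        help='directory of the NeZha checkpoint (assumed layout)')
    parser.add_argument('--output_dir', type=str, default='output/pretrain')
    parser.add_argument('--seed', type=int, default=2021)
    args = parser.parse_args()

    seed_everything(args.seed)
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    dataset = PretrainDataset(load_data(args, tokenizer))

    config = NeZhaConfig.from_pretrained(args.model_path)
    model = NeZhaForMaskedLM.from_pretrained(args.model_path, config=config)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=32,
        logging_steps=100,
        save_steps=1000,
        seed=args.seed,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=lambda batch: mlm_collate_fn(batch, tokenizer),
    )
    trainer.train()
    trainer.save_model(args.output_dir)


if __name__ == '__main__':
    main()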