任务说明
使用huggingface的开源预训练模型训练下游任务。要求训练模型能够对文本做情感分类任务。"Positive"or"Negative"
示例:
"text": "The movie is excellent",
"Lable": "Positive"
制定计划:
截止时间:8月30日,周五
8月28日:
完成计划,准备资源。
8月29日:
执行计划,记录进展
8月30日:
撰写报告,反馈和优化,学习与总结。
任务要求
- 找到合适的预训练模型
- 下载预训练模型到本地
- 使用加载的BERT模型作为特征提取器,并在其顶部添加一个线性分类器。
- 实现情感二分类任务
实现思路:
- 加载预训练模型:
- 使用
BertModel.from_pretrained
加载预训练的BERT模型。
- 使用
- 构建分类模型:
- 使用加载的BERT模型作为特征提取器,并在其顶部添加一个线性分类器。
BertClassifier
类中的__init__
方法初始化了BERT模型和分类器。forward
方法定义了模型的前向传播过程,使用BERT模型提取文本特征,并通过分类器进行分类。
- 微调过程:
- 在
train_epoch
函数中,我们执行了一个epoch的训练过程,包括前向传播、计算损失、反向传播、优化器更新等。 - 我们使用了
AdamW
优化器和线性学习率调度器来更新BERT模型和分类器的参数。 - 通过这种方式,BERT模型的权重会被微调以更好地适应情感分类任务。
- 在
- 评估:
- 在
eval_model
函数中,我们评估了模型在测试集上的性能,计算了准确率和损失。
- 在
实现代码:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import (BertTokenizer,
BertModel,
AdamW,
get_linear_schedule_with_warmup)
import warnings
warnings.filterwarnings('ignore')
# 自定义数据集类
class TextDataset(Dataset):
"""
定义了一个继承自 `torch.utils.data.Dataset` 的类 `TextDataset`,用于处理文本数据。
`__init__` 方法初始化数据集,接收文本、标签、tokenizer 和最大长度作为参数。
`__len__` 方法返回数据集中文本的数量。
`__getitem__` 方法对单个文本进行编码,并返回编码后的输入ID、注意力掩码和对应的标签。
"""
def __init__(self, texts, labels, tokenizer, max_len):
"""
初始化方法接收 texts、labels、tokenizer 和 max_len 作为参数。
:param texts: list,文本列表。
:param labels: list,与文本相对应的标签列表。
:param tokenizer: `transformers` 中的 tokenizer,用于文本编码。
:param max_len: int,每个文本的最大长度,超过该长度的文本将会被截断。
"""
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
"""
返回数据集中文本的数量,即 `self.texts` 的长度。
:return: int,数据集中文本的数量。
"""
return len(self.texts)
def __getitem__(self, item):
"""
对单个文本进行编码,并返回编码后的输入ID、注意力掩码和对应的标签。
:param item: int,数据集中文本的索引。
:return: dict,包含 'input_ids'、'attention_mask' 和 'labels' 键的字典。
"""
text = str(self.texts[item])
label = self.labels[item]
# 使用 tokenizer 对文本进行编码
encoding = self.tokenizer.encode_plus(
text,
add_special_tokens=True, # 添加特殊令牌 [CLS] 和 [SEP]
max_length=self.max_len, # 指定最大长度
padding='max_length', # 填充到最大长度
truncation=True, # 超过最大长度的部分将被截断
return_attention_mask=True, # 返回注意力掩码
return_tensors='pt' #