简单实现BERT模型文本情感分类

任务说明

使用huggingface的开源预训练模型训练下游任务。要求训练模型能够对文本做情感分类任务。"Positive"or"Negative"

示例:

"text": "The movie is excellent",

"Lable": "Positive"

制定计划:

截止时间:8月30日,周五

8月28日:

完成计划,准备资源。

8月29日:

执行计划,记录进展

8月30日:

撰写报告,反馈和优化,学习与总结。

任务要求

  • 找到合适的预训练模型
  • 下载预训练模型到本地
  • 使用加载的BERT模型作为特征提取器,并在其顶部添加一个线性分类器。
  • 实现情感二分类任务

实现思路:

  1. 加载预训练模型:
    • 使用BertModel.from_pretrained加载预训练的BERT模型。
  2. 构建分类模型:
    • 使用加载的BERT模型作为特征提取器,并在其顶部添加一个线性分类器。
    • BertClassifier类中的__init__方法初始化了BERT模型和分类器。
    • forward方法定义了模型的前向传播过程,使用BERT模型提取文本特征,并通过分类器进行分类。
  3. 微调过程:
    • train_epoch函数中,我们执行了一个epoch的训练过程,包括前向传播、计算损失、反向传播、优化器更新等。
    • 我们使用了AdamW优化器和线性学习率调度器来更新BERT模型和分类器的参数。
    • 通过这种方式,BERT模型的权重会被微调以更好地适应情感分类任务。
  4. 评估:
    • eval_model函数中,我们评估了模型在测试集上的性能,计算了准确率和损失。

实现代码:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import (BertTokenizer,
                          BertModel,
                          AdamW,
                          get_linear_schedule_with_warmup)
import warnings


warnings.filterwarnings('ignore')


# 自定义数据集类
class TextDataset(Dataset):
    """
    定义了一个继承自 `torch.utils.data.Dataset` 的类 `TextDataset`,用于处理文本数据。

    `__init__` 方法初始化数据集,接收文本、标签、tokenizer 和最大长度作为参数。
    `__len__` 方法返回数据集中文本的数量。
    `__getitem__` 方法对单个文本进行编码,并返回编码后的输入ID、注意力掩码和对应的标签。
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        """
        初始化方法接收 texts、labels、tokenizer 和 max_len 作为参数。

        :param texts: list,文本列表。
        :param labels: list,与文本相对应的标签列表。
        :param tokenizer: `transformers` 中的 tokenizer,用于文本编码。
        :param max_len: int,每个文本的最大长度,超过该长度的文本将会被截断。
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """
        返回数据集中文本的数量,即 `self.texts` 的长度。

        :return: int,数据集中文本的数量。
        """
        return len(self.texts)

    def __getitem__(self, item):
        """
        对单个文本进行编码,并返回编码后的输入ID、注意力掩码和对应的标签。

        :param item: int,数据集中文本的索引。
        :return: dict,包含 'input_ids'、'attention_mask' 和 'labels' 键的字典。
        """
        text = str(self.texts[item])
        label = self.labels[item]

        # 使用 tokenizer 对文本进行编码
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,  # 添加特殊令牌 [CLS] 和 [SEP]
            max_length=self.max_len,  # 指定最大长度
            padding='max_length',  # 填充到最大长度
            truncation=True,  # 超过最大长度的部分将被截断
            return_attention_mask=True,  # 返回注意力掩码
            return_tensors='pt'  #
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值