Skip to content

zhenhanbai/THUCNews-text-classification

Repository files navigation

THUCNews-text-classification

1、功能描述

  1. 数据集的选择:选择THUCNews数据集用于中文文本分类任务的训练和测试。
  2. 数据预处理:对中文文本进行分词和向量化等预处理操作,一边模型的理解和处理。
  3. 模型选择:
    • 首先,初步尝试使用 fastTextTextCNNLSTM 这三种神经模型。
    • 接着,进行实验,采用 BERTRoBERTaERNIE 等预训练模型。
    • 最后,采用 ERNIE 结合神经网络的方法,进一步提高准确率。
  4. 模型训练:使用训练集对选定的模型进行训练。在训练过程中,采用交叉熵损失函数和Adam优化算法,调整模型参数以最小化损失。
  5. 模型评估:使用测试集对已训练的模型进行评估,包括计算准确率和混淆矩阵,以全面评估模型性能。

2、数据集的选择

选用THUCNews数据集,该数据集是清华大学提供的一个中文新闻文本分类数据集,主要用于中文文本分类任务。该数据集包含了来自新浪新闻网的14个不同类别的新闻文本。

数据集的类别包括但不限于体育、娱乐、家居、教育、时政、游戏、社会、科技、股票和财经等。每个类别有数万条新闻文本样本,是一个相对较大的中文文本分类数据集。

本实验从THUCNews数据集中随机抽取了24万条数据,按照10:1:1的比例分配为训练集、验证集和测试集。以训练集为例,训练集共包含20万条数据,涵盖十个类别,每个类别均匀分布,各有2万条数据。

3、模型设计

  • 神经网络模型

    1. fastText

      fastText是Facebook于2016年开源的一个词向量计算和文本分类工具,在效果上往往可以媲美深度网络,却在训练时间上比深度网络快许多数量级。fastText的核心思想就是:将整篇文档的词及n-gram向量叠加平均得到文档向量,然后使用文档向量做softmax多分类。

    2. TextCNN 

      TextCNN是一种经典的卷积神经网络(CNN)架构,用于文本分类和文本表示。它通过卷积层和池化层来捕获文本中的局部特征,并在整个文本上提取关键信息。TextCNN使用多个卷积核对文本进行卷积操作,每个卷积核负责捕获不同大小的特征。这些卷积核在文本的不同窗口大小上进行滑动,每个窗口大小对应一种局部模式的捕获,如单个词、词组或短语。 经过卷积层后,通常使用池化层对每个卷积特征图进行下采样。一种常见的池化操作是最大池化(Max Pooling),它提取特征图中每个通道的最大值,以保留最显著的特征。池化操作有助于减少特征的维度并保留最重要的信息。

    3. RNN

      RNN在解决文本分类问题的时候具有先天性的优势,通过循环结构处理序列数据,能够记忆先前的信息,并将其用于后续的预测或分类。RNN按序列逐步处理输入文本,每个时间步接收一个词嵌入向量,并考虑之前的隐藏状态,计算当前时间步的隐藏状态。RNN模型的最后一个隐藏状态或整个序列的隐藏状态被用于文本分类任务。通常将最后一个隐藏状态或整个序列的表示传递给全连接层进行分类。

  • 预训练模型

    1. BERT

      BERT使用了Transformer模型结构,能够同时考虑文本序列中任意位置的上下文信息。相比传统的从左到右或从右到左的单向语言模型,BERT能够双向地预测单词或token,从而更好地理解上下文信息。BERT通过大规模无监督的预训练来学习语言表示。在预训练阶段,使用了大量的文本数据来训练模型,学习到了通用的语言表示。之后,可以通过微调(fine-tuning)方式,将预训练好的模型用于文本分类。

    2. RoBERTa

      RoBERTa在BERT的基础上通过一系列优化,如更严格的掩码方式、更多数据和更长序列的训练,以及更大的训练批次等,进一步提高了模型的泛化能力和性能,在各种NLP任务上取得了更好的效果。

    3. XLNet

      XLNet是由谷歌提出的一个自然语言处理模型,它结合了Transformer结构和自回归语言模型的概念,并利用了预训练和微调的方法来解决NLP任务。XLNet以自回归方式训练,通过考虑所有可能的文本片段排列来最大化预训练数据的似然性。这种方式相比BERT的单向和GPT的双向训练方法,更全面地捕捉文本序列的相关性。XLNet通过将上下文中的词考虑在内,利用了更多的信息来预测下一个词。

    4. ERNIE

      ERNIE是百度提出的一种语言表示模型。它结合了语言模型预训练和知识增强的方式,以更好地理解语言和建模语言表达。ERNIE采用多任务学习和知识融合的方式,通过多种预训练任务和数据源来提高模型的语义理解能力。在解决中文文本分类时的效果往往会更好。

  • 基于预训练模型的改进模型

    1. ERNIE-TextCNN

      ERNIE-TextCNN在处理文本任务时能够充分利用ERNIE的语义理解和知识融合能力,同时结合TextCNN对局部特征的敏感性,综合提升了模型的表现。

    2. ERNIE-RNN

      ERNIE模型在预训练阶段通过多任务学习和知识融合,能够更好地理解语义,提取文本特征。RNN能够有效地捕捉序列信息和长期依赖关系。ERNIE-RNN结合了ERNIE对语义理解和知识融合的能力,以及RNN对于序列数据的敏感性,可以在多个层次上对文本信息进行建模和理解。

    3. ERNIE-RCNN

      ERNIE-RCNN将ERNIE和RNN学习到的信息进行了拼接,使得模型能够同时充分利用ERNIE对语义的理解和RNN对序列信息的处理能力。这种方式保留了ERNIE和RNN学习到的双重信息,并让模型更全面地理解和处理文本信息。相较于ERNIE-RNN舍去ERNIE学习到的内容,ERNIE-RCNN的信息融合方式更综合,有望在一些任务中表现更出色。

4、实验结果分析

3090上进行训练的本次实验,下表显示了模型在验证集上的最佳表现以及一个epoch所需的时间。通过采用early stop策略,当模型在验证集上连续1000个步骤未见准确率提升时,即进行早停,以防止过拟合。实验设定了epoch大小为5,本次实验采用的RNN模型类型是双向LSTM。

image-20231120212234569

image-20231120212212385

Model_Name Dev_Acc(%) Time(s)
fastText 91.60 8
TextCNN 94.21 22
RNN 94.37 39
BERT 95.86 381
RoBERTa 95.92 105
XLNet 95.99 470
ERNIE 96.10 387
ERNIE-TextCNN 96.11 403
ERNIE-RNN 96.07 398
ERNIE-RCNN 96.32 412

分析上述结果:

  • 从准确率上看: ERNIE-RCNN模型在开发集上取得了96.32%的准确率,相比其他模型表现更好。不过,它的训练时间较长,达到了412秒。BERT和RoBERTa模型在准确率上相当接近,并且相对较快地训练完成。fastText虽然训练速度快,但准确率相对较低。总体来说,ERNIE-RCNN在准确率和训练效率上都表现出色。
  • 从时间上看: fastText和神经模型占有优势,尤其是fastText在保持91%以上的准确率的情况下,8s完成了对20万条数据的处理。

总的来说,使用预训练模型及其与后续神经网络的组合能够提高本任务的准确率。然而,这种提升是伴随着更高的计算成本和更长的训练时间。在实际应用中,需要根据具体情况进行权衡和选择。

对使用上面效果最好的ERNIE-RCNN模型对测试集进行测试,Test Acc: 96.08%。下表给出了每一类的precision、 recall、f1-score。总体准确率为96.08%,每个类别的准确率、召回率和 F1-score 都很高,大部分都在 95% 到 98% 之间。这表示模型在多个类别上都有很好的表现,能够有效地对各个类别进行分类。有一些类别(如“时政”)的召回率略低于其他类别。

              precision    recall  f1-score   support

          体育     0.9949    0.9840    0.9894      2000
          娱乐     0.9685    0.9840    0.9762      2000
          家居     0.9721    0.9590    0.9655      2000
          教育     0.9625    0.9630    0.9628      2000
          时政     0.9254    0.9430    0.9341      2000
          游戏     0.9784    0.9755    0.9770      2000
          社会     0.9514    0.9590    0.9552      2000
          科技     0.9696    0.9580    0.9638      2000
          股票     0.9465    0.9375    0.9420      2000
          财经     0.9403    0.9455    0.9429      2000

    accuracy                          0.9608      20000
   macro avg      0.9610    0.9608    0.9609      20000
weighted avg      0.9610    0.9608    0.9609      20000
[[1968   14    1    2    9    1    4    0    0    1]
 [   2 1968    7    3    7    2    6    2    1    2]
 [   1   21 1918   13   11    6    7    9    6    8]
 [   0    5    1 1926   21    1   37    8    1    0]
 [   0    5    6   15 1886    2   19   17   40   10]
 [   1    7   11    5    1 1951    8    8    6    2]
 [   1    5    3   18   42    0 1918   11    0    2]
 [   1    0    5    8   27   28    9 1916    3    3]
 [   3    1    7    6   13    1    1    1 1875   92]
 [   1    6   14    5   21    2    7    4   49 1891]]

分析混淆矩阵,可以观察到一些有用的信息。例如,社会类别被错误地分类为时政,而时政又被误分类为股票。此外,财经类别也经常被错误地分类为股票。在检查误分类的测试集时,发现一些文本可能存在多个标签,这导致某些内容在不同类别之间存在交叉。为了解决这种情况,增加训练数据可能会有所帮助,因为这有助于模型更好地理解不同类别之间的差异,最终减少这种交叉分类的情况。

data	label 	pre_label
分析称卫星相撞事件可能影响美太空政策及预算	时政	股票
台湾商人疑遭他人毒手在南非惨死家中	时政	社会
德国智库称德不需要美国式银行拯救计划	时政	股票
美科学家证实爱情可以维持永恒	时政	科技
我国重型巨型运载火箭研制列入发展规划(图)	时政	科技
连战次子连胜武周日完婚婚纱照提前曝光(组图)	时政	娱乐
扔鞋砸布什游戏风靡网络	时政	游戏
瑞典使用直升机营救濒危甲虫(图)	时政	科技
欧盟正式批准瑞士下月加入申根区	时政	股票
奥巴马反对只购买美国货条款	时政	股票
美能源部长朱棣文清华演讲杨振宁夫妇旁听	时政	教育
3名越南男子锯断旧炮弹时被炸死	时政	社会
台湾密宗大师林云过世弟子争遗产	时政	社会
英国1岁女童接受手术移植牛血管	时政	社会

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages