用Attention和微调BERT进行自然语言推断-PyTorch

一、自然语言推断与数据集

当需要决定一个句子是否可以从另一个句子推断出来,或者需要通过识别语义等价的句子来消除句子间冗余时,知道如何对一个文本序列进行分类是不够的。相反,我们需要能够对成对的文本序列进行推断。

1.自然语言推断

自然语言推断(natural language inference)主要研究假设(hypothesis)是否可以从前提(premise)中推断出来,其中两者都是文本序列。换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:

  • 蕴涵(entailment): 假设可以从前提中推断出来。
  • 矛盾(contradiction): 假设的否定可以从前提中推断出来。
  • 中性(neutral): 所有其他情况。

自然语言推断也被称为识别文本蕴涵任务。

2.斯坦福自然语言推断数据集

斯坦福自然语言推断语料库(Stanford Natural Language Inference,SNLI是由500000多个带标签的英语句子对组成的集合。训练集约有550000对,测试集约有10000对,训练集和测试集中的三个标签“蕴涵”、“矛盾”和“中性”是平衡的。

import os
import re
import torch
from torch import nn
from d2l import torch as d2l

#@save
d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

data_dir = d2l.download_extract('SNLI')

def read_snli(data_dir, is_train):
    """将SNLI数据集解析为前提、假设和标签"""
    def extract_text(s):
        # 删除括号
        s = re.sub('\\(', '', s)
        s = re.sub('\\)', '', s)
        # 两个或多个连续的空格只保留一个空格
        s = re.sub('\\s{2,}', ' ', s)
        return s.strip()
    # 蕴涵:0,矛盾:1,中性:2
    label_set = {
   
   'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    with open(file_name, 'r') as f:
        rows = [row.split('\t') for row in f.readlines()[1:]]
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set]
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels
  • 加载数据集
class SNLIDataset(torch.utils.data.Dataset):
    """用于加载SNLI数据集的自定义数据集"""
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = d2l.tokenize(dataset[0])
        all_hypothesis_tokens = d2l.tokenize(dataset[1])
        if vocab is None:
            self.vocab = d2l.Vocab(all_premise_tokens + \
                all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print('read ' + str(len(self.premises)) + ' examples')

    def _pad(self, lines):
        return torch.tensor([d2l.truncate_pad(
            self.vocab[line], self.num_steps, self.vocab['<pad>'])
                         for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)

调用read_snli函数和SNLIDataset类来下载SNLI数据集,并返回训练集和测试集的DataLoader实例,以及训练集的词表。注意,必须使用从训练集构造的词表作为测试集的词表。因此,在训练集中训练的模型将不知道来自测试集的任何新词元。

def load_data_snli(batch_size, num_steps=50):
    """下载SNLI数据集并返回数据迭代器和词表"""
    data_dir = d2l.download_extract('SNLI')
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            shuffle=False)
    return train_iter, test_iter, train_set.vocab

二、利用注意力进行自然语言推断

1.模型

与保留前提和假设中词元的顺序相比,我们可以将一个文本序列中的词元与另一个文本序列中的每个词元对齐,然后比较和聚合这些信息,以预测前提和假设之间的逻辑关系。与机器翻译中源句和目标句之间的词元对齐类似,前提和假设之间的词元对齐可以通过注意力机制灵活地完成。

请添加图片描述


  • 注意

第一步是将一个文本序列中的词元与另一个序列中的每个词元对齐。对齐是使用加权平均的“软”对齐,其中理想情况下较大的权重与要对齐的词元相关联。

使用注意力机制的软对齐 A = ( a 1 , … , a m ) \mathbf{A} = (\mathbf{a}_1, \ldots, \mathbf{a}_m) A=(a1,,am) B = ( b 1 , … , b n ) \mathbf{B} = (\mathbf{b}_1, \ldots, \mathbf{b}_n) B=(b1,,bn)表示前提和假设,词元数量分别为 m m m n n n,其中 a i , b j ∈ R d \mathbf{a}_i, \mathbf{b}_j \in \mathbb{R}^{d} ai,bjRd。对于软对齐,注意力权重 e i j ∈ R e_{ij} \in \mathbb{R} e

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

葫芦娃啊啊啊啊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值