bert2joint_dataloader.py

本文详细介绍了 BERT for Keyphrase Extraction 模型的工作原理及其实现细节,包括对 convert_to_label 和 get_ngram_features 等关键函数的解析,以及数据预处理流程。通过分析,读者可以了解如何利用 BERT 进行关键词抽取。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

2021SC@SDUSC

系列文章目录

(一)面向特定问题的开源算法管理和推荐

(二)论文阅读上

(三)sent2vec

(四)BERT for Keyphrase Extraction

(五)config.py 代码分析

(六)model.py(上)

(七)论文 - 补充理解

(八)数据处理之prepro_utils.py

(九)preprocess.py代码分析

(十)preprocess.py代码分析-下

(十一)split_json.py代码分析

(十二)prepro_utils.py代码分析

(十三)jsonify_multidata.py + Constant.py

(十四)loader_utils.py

(十五)Keyphrase Chunking - bert2chunk_dataloader.py分析

(十六)Encoder-Decoder

(十七)bert2joint_dataloader.py


bert2joint

转化为标签(convert_to_label)

函数

# Convert gold keyphrase positions into per-phrase n-gram labels and flattened
# chunk labels. Returns (ngram_label, chunk_label), or (None, None) when the
# positive/negative check at the end of the function fails.
def convert_to_label(filter_positions, tot_mention_list, differ_phrase_num):

……

return ngram_label, chunk_label
    else:
        return None, None

首先通过 tot_mention_list 查出每个关键词出现位置 (s, e) 对应的短语索引(同一短语的不同出现位置对应同一个索引);
然后将关键词(keyphrase)对应的 n-gram 标记为 +1,其余候选短语保持为 -1,同时在 chunk_label_list 的对应位置标记 1。

# One label per distinct phrase; every candidate starts as negative (-1).
ngram_label = [-1 for _ in range(differ_phrase_num)]
    # chunk_label_list mirrors tot_mention_list: one 0/1 flag per n-gram position,
    # grouped by n-gram length (level 0 = unigrams).
    chunk_label_list = [
        [0] * len(tot_mention_list[i]) for i in range(len(tot_mention_list))
    ]

    # e - s selects the n-gram level and s the start word, so e is presumably the
    # inclusive end word index (a single word has e == s) — confirm with the caller.
    for i, positions in enumerate(filter_positions):
        for s, e in positions:
            chunk_label_list[e - s][s] = 1
            # Mark the distinct phrase this mention maps to as positive (+1).
            key_index = tot_mention_list[e - s][s]
            ngram_label[key_index] = 1

展平(flatten)块标签:将按 n-gram 长度分组的二维列表拉平成一维列表

 # Flatten the per-length chunk flags into one list (shortest n-grams first).
 chunk_label = [_chunk for chunks in chunk_label_list for _chunk in chunks]

返回结果:只有当 n-gram 标签中至少有一个正例(1)和一个负例(-1),且块标签中至少有一个 1 和一个 0 时才返回标签,否则返回 (None, None)

# Keep the example only if both label sets contain at least one positive and
# one negative entry; otherwise signal "no usable labels" with (None, None).
if (
        (1 in ngram_label)
        and (-1 in ngram_label)
        and (1 in chunk_label)
        and (0 in chunk_label)
    ):
        return ngram_label, chunk_label
    else:
        return None, None

获取n-gram特征

函数

# Enumerate every 1..max_gram n-gram of doc_words, map each to a phrase index,
# and return the phrase list plus, per n-gram length, the index of every position.
def get_ngram_features(doc_words, max_gram, stem_flag=False):

…………
  return {"tot_phrase_list": tot_phrase_list, "tot_mention_list": tot_mention_list}

定义用于最后评价和训练的相应 list

phrase2index = {}  # phrase -> index lookup passed to loader_utils (presumably merges repeated phrases)
    tot_phrase_list = []  # distinct phrases; used for final evaluation
    tot_mention_list = []  # per n-gram length: phrase index of each position; used to pool identical phrases in training

函数实现:枚举文档中所有长度为 1 到 max_gram 的 n-gram;对每个 n-gram 调用 loader_utils.py 中的 whether_stem_existing()(stem_flag 为真时按词干判断)或 whether_existing() 方法,判断该短语是否已存在于短语表中,并得到它的索引

 # Count every n-gram position visited; verified against the mention lists below.
 gram_num = 0
    for n in range(1, max_gram + 1):
        # Number of start positions for an n-gram of n words.
        valid_length = len(doc_words) - n + 1

        # The document has fewer than n words, so no n-gram of this length
        # (or longer) exists.
        if valid_length < 1:
            break

        # NOTE(review): _ngram_list is filled but never used in this excerpt.
        _ngram_list = []
        _mention_list = []
        for i in range(valid_length):

            gram_num += 1

            # n-grams are lower-cased and space-joined before lookup.
            n_gram = " ".join(doc_words[i : i + n]).lower()

            # Map the n-gram to a phrase index (stem-based lookup when stem_flag
            # is set); presumably the helper appends unseen phrases to tot_phrase_list.
            if stem_flag:
                index = loader_utils.whether_stem_existing(
                    n_gram, phrase2index, tot_phrase_list
                )
            else:
                index = loader_utils.whether_existing(
                    n_gram, phrase2index, tot_phrase_list
                )

            _mention_list.append(index)
            _ngram_list.append(n_gram)

        # One mention list per n-gram length (index 0 = unigrams).
        tot_mention_list.append(_mention_list)

    assert len(tot_phrase_list) > 0

    # Sanity checks: the largest index in the last (longest) level equals the
    # last phrase index, and every visited position produced one mention entry.
    assert (len(tot_phrase_list) - 1) == max(tot_mention_list[-1])
    assert sum([len(_mention_list) for _mention_list in tot_mention_list]) == gram_num
    return {"tot_phrase_list": tot_phrase_list, "tot_mention_list": tot_mention_list}

获取 n-gram 特征并生成块标签(与 get_ngram_features 类似,但额外接收关键词列表 keyphrases)

定义关键词列表和块list

# Gold keyphrases normalized like the n-grams below (space-joined, lower-cased).
keyphrases_list = [" ".join(kp).lower() for kp in keyphrases]
    # One 0/1 flag per n-gram position, appended in the same loop order as the mentions.
    chunk_label = []
    phrase2index = {}  # phrase -> index lookup passed to loader_utils (presumably merges repeated phrases)
    tot_phrase_list = []  # distinct phrases; used for final evaluation
    tot_mention_list = []  # per n-gram length: phrase index of each position; used to pool identical phrases in training

与上文类似

 # Count every n-gram position visited; verified against the label lengths later.
 gram_num = 0
    for n in range(1, max_gram + 1):
        # Number of start positions for an n-gram of n words.
        valid_length = len(doc_words) - n + 1

        # The document has fewer than n words: stop enumerating longer n-grams.
        if valid_length < 1:
            break

        _ngram_list = []
        _mention_list = []
        for i in range(valid_length):

            gram_num += 1

            # n-grams are lower-cased and space-joined, matching keyphrases_list.
            n_gram = " ".join(doc_words[i : i + n]).lower()

            # Map the n-gram to a phrase index (stem-based lookup when stem_flag is set).
            if stem_flag:
                index = loader_utils.whether_stem_existing(
                    n_gram, phrase2index, tot_phrase_list
                )
            else:
                index = loader_utils.whether_existing(
                    n_gram, phrase2index, tot_phrase_list
                )

与 get_ngram_features 的不同之处:额外生成块标签——若该 n-gram 与某个关键词完全相同则记为 1,否则记为 0

# Chunk label: 1 when this n-gram string exactly equals a gold keyphrase, else 0.
# (keyphrases_list is a list, so each membership test is a linear scan.)
if n_gram in keyphrases_list:
                chunk_label.append(1)
            else:
                chunk_label.append(0)

处理要返回的数据

# Record this position's phrase index (and its string, which is not used later here).
_mention_list.append(index)
            _ngram_list.append(n_gram)

        # One mention list per n-gram length (index 0 = unigrams).
        tot_mention_list.append(_mention_list)

    assert len(tot_phrase_list) > 0

    # Sanity checks: the last level holds the newest phrase index, and mentions,
    # visited positions, and chunk labels all have the same count.
    assert (len(tot_phrase_list) - 1) == max(tot_mention_list[-1])
    assert (
        sum([len(_mention_list) for _mention_list in tot_mention_list])
        == gram_num
        == len(chunk_label)
    )
    return {
        "tot_phrase_list": tot_phrase_list,
        "tot_mention_list": tot_mention_list,
        "chunk_label": chunk_label,
    }

获取 n-gram信息标签

函数名以及返回值:

# Build n-gram features for one document and, when gold start/end positions are
# given, its n-gram and chunk labels plus an over-length flag.
# NOTE(review): `keyphrases` is not referenced in the visible body of this function.
def get_ngram_info_label(
    doc_words, max_phrase_words, stem_flag, keyphrases=None, start_end_pos=None
):

    # Defaults kept when no labels are produced (e.g. start_end_pos is None).
    returns = {"overlen_flag": False, "ngram_label": None, "chunk_label": None}

计算 n-gram 特征,并把短语列表(tot_phrase_list)和提及列表(tot_mention_list)写入返回结果

 # Phrase list and per-length mention indices for this document.
 feature = get_ngram_features(
        doc_words=doc_words, max_gram=max_phrase_words, stem_flag=stem_flag
    )
    returns["tot_phrase_list"] = feature["tot_phrase_list"]
    returns["tot_mention_list"] = feature["tot_mention_list"]

若提供了关键词位置 start_end_pos,先用 limit_scope_length 过滤位置;若过滤后数量减少,则将 overlen_flag 置为 True,再用剩余位置生成标签

 # Labels are only built when gold start/end positions are supplied.
 if start_end_pos is not None:
        # Filter gold positions using the document length and max_phrase_words.
        filter_positions = loader_utils.limit_scope_length(
            start_end_pos, len(doc_words), max_phrase_words
        )

        # check over_length
        # Any position dropped by the filter marks the example as over-length.
        if len(filter_positions) != len(start_end_pos):
            returns["overlen_flag"] = True

        if len(filter_positions) > 0:
            returns["ngram_label"], returns["chunk_label"] = convert_to_label(
                **{
                    "filter_positions": filter_positions,
                    "tot_mention_list": feature["tot_mention_list"],
                    "differ_phrase_num": len(feature["tot_phrase_list"]),
                }
            )
        else:
            # No positions left to label.
            returns["ngram_label"] = None
            returns["chunk_label"] = None

    return returns

进行预处理

函数定义:

# Tokenize each example, truncate it to max_token tokens, attach n-gram features,
# and in train mode attach labels (dropping examples without usable labels).
# Returns the list of processed examples.
def bert2joint_preprocessor(
    examples,
    tokenizer,
    max_token,
    pretrain_model,
    mode,
    max_phrase_words,
    stem_flag=False,
):

分词(tokenize),并用 tqdm 库显示进度条

# Number of examples flagged as containing an over-length keyphrase.
overlen_num = 0
    new_examples = []
    for idx, ex in enumerate(tqdm(examples)):

        # tokenize
        tokenize_output = loader_utils.tokenize_for_bert(
            doc_words=ex["doc_words"], tokenizer=tokenizer
        )

        # Number of words to keep. If all tokens fit, max_token is used as the word
        # cap (presumably every word yields at least one token); otherwise keep the
        # words up to the one that owns the last kept token.
        if len(tokenize_output["tokens"]) < max_token:
            max_word = max_token
        else:
            max_word = tokenize_output["tok_to_orig_index"][max_token - 1] + 1

        new_ex = {}
        new_ex["url"] = ex["url"]
        new_ex["tokens"] = tokenize_output["tokens"][:max_token]
        new_ex["valid_mask"] = tokenize_output["valid_mask"][:max_token]
        new_ex["doc_words"] = ex["doc_words"][:max_word]
        # The number of valid-mask positions must equal the number of kept words.
        assert len(new_ex["tokens"]) == len(new_ex["valid_mask"])
        assert sum(new_ex["valid_mask"]) == len(new_ex["doc_words"])

参数设置

# Keyword arguments for get_ngram_info_label; train mode also adds
# "keyphrases" and "start_end_pos" before the call.
parameter = {
            "doc_words": new_ex["doc_words"],
            "max_phrase_words": max_phrase_words,
            "stem_flag": stem_flag,
        }

获取gram信息和标签

# Compute n-gram features (and labels, when start_end_pos was added to parameter).
info_or_label = get_ngram_info_label(**parameter)

        new_ex["phrase_list"] = info_or_label["tot_phrase_list"]
        new_ex["mention_lists"] = info_or_label["tot_mention_list"]

        # Count examples whose gold keyphrases were partly filtered out.
        if info_or_label["overlen_flag"]:
            overlen_num += 1

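对 mode 的判断分别出现在获取 gram 信息和标签的前后。注意:下面第一段代码在原函数中应位于 get_ngram_info_label(**parameter) 调用之前——训练模式下先把 keyphrases 和 start_end_pos 加入参数,否则不会生成标签。
第二段代码位于调用之后:训练模式下丢弃没有有效标签的样本,并保存关键词与标签。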

# Train mode: pass gold keyphrases and positions so labels get built.
# NOTE: this must run BEFORE get_ngram_info_label(**parameter); placed after the
# call (as this excerpt's layout suggests), it would have no effect.
if mode == "train":
            parameter["keyphrases"] = ex["keyphrases"]
            parameter["start_end_pos"] = ex["start_end_pos"]

 # Train mode: skip examples whose ngram_label is None or empty; otherwise keep labels.
 if mode == "train":
            if not info_or_label["ngram_label"]:
                continue
            new_ex["keyphrases"] = ex["keyphrases"]
            new_ex["ngram_label"] = info_or_label["ngram_label"]
            new_ex["chunk_label"] = info_or_label["chunk_label"]

logger日志记录和返回值

 # Keep the processed example (unlabeled train examples were skipped earlier via `continue`).
 new_examples.append(new_ex)

    # NOTE(review): the message hard-codes "length > 5" although the limit used is
    # max_phrase_words, and an empty `examples` makes this division raise
    # ZeroDivisionError.
    logger.info(
        "Delete Overlen Keyphrase (length > 5): %d (overlen / total = %.2f"
        % (overlen_num, float(overlen_num / len(examples) * 100))
        + "%)"
    )
    return new_examples

将单个样本转换为张量;在 token 序列首尾添加 [CLS] 和 [SEP] 标记(代码中为 BOS_WORD 和 EOS_WORD)

函数定义:

# Convert one preprocessed example into a tuple of tensors/values; train mode
# additionally returns the label tensors.
def bert2joint_converter(index, ex, tokenizer, mode, max_phrase_words):

构造模型输入:在首尾加上 BOS_WORD / EOS_WORD,并转换为 token id 张量和有效掩码张量

 # Wrap the tokens with boundary markers; the markers get valid_mask 0.
 src_tokens = [BOS_WORD] + ex["tokens"] + [EOS_WORD]
    valid_ids = [0] + ex["valid_mask"] + [0]

    src_tensor = torch.LongTensor(tokenizer.convert_tokens_to_ids(src_tokens))
    valid_mask = torch.LongTensor(valid_ids)

    mention_lists = ex["mention_lists"]
    # Count of valid positions, i.e. the (truncated) document length in words.
    orig_doc_len = sum(valid_ids)

对于训练和测试状态进行不同赋值

# Train: append the n-gram labels and chunk labels as LongTensors.
if mode == "train":
        label = torch.LongTensor(ex["ngram_label"])
        chunk_label = torch.LongTensor(ex["chunk_label"])
        return (
            index,
            src_tensor,
            valid_mask,
            mention_lists,
            orig_doc_len,
            max_phrase_words,
            label,
            chunk_label,
        )

    # Other modes: append the number of distinct candidate phrases instead of labels.
    else:
        tot_phrase_len = len(ex["phrase_list"])
        return (
            index,
            src_tensor,
            valid_mask,
            mention_lists,
            orig_doc_len,
            max_phrase_words,
            tot_phrase_len,
        )

训练加载器和评估加载器

函数定义:

# Collate a list of train-mode bert2joint_converter tuples into padded batch tensors.
def batchify_bert2joint_features_for_train(batch):

获取批次中每个样本的 id、token 张量(docs)、有效掩码、提及掩码(mention_mask)、文档单词长度,以及最大短语长度

# Unpack the per-example tuples in the order bert2joint_converter returns them.
ids = [ex[0] for ex in batch]
    docs = [ex[1] for ex in batch]
    valid_mask = [ex[2] for ex in batch]
    mention_mask = [ex[3] for ex in batch]
    doc_word_lens = [ex[4] for ex in batch]
    # Taken from the first example; presumably identical across the batch.
    max_phrase_words = [ex[5] for ex in batch][0]

处理标签
取出 n-gram 标签列表和块标签列表(标记是否为关键词块),并计算批次中文档的最大单词长度

 # Per-example label tensors; their lengths differ across examples.
 label_list = [ex[6] for ex in batch]  # different ngrams numbers
    chunk_list = [ex[7] for ex in batch]  # whether is a chunk phrase

    # NOTE(review): hard-coded hidden size; presumably must match the encoder
    # (768 for BERT-base-sized models) — confirm.
    bert_output_dim = 768
    # Longest document in the batch, measured in words.
    max_word_len = max([word_len for word_len in doc_word_lens])  # word-level

构造输入 token id 张量、注意力掩码(input_mask)和有效掩码(valid_ids)张量,并按批次内最大长度补齐

 # ---------------------------------------------------------------
    # [1] [2] src tokens tensor
    # Right-pad token ids with 0; input_mask is 1 on real tokens, 0 on padding.
    doc_max_length = max([d.size(0) for d in docs])
    input_ids = torch.LongTensor(len(docs), doc_max_length).zero_()
    input_mask = torch.LongTensor(len(docs), doc_max_length).zero_()
    # segment_ids = torch.LongTensor(len(docs), doc_max_length).zero_()

    for i, d in enumerate(docs):
        input_ids[i, : d.size(0)].copy_(d)
        input_mask[i, : d.size(0)].fill_(1)

    # ---------------------------------------------------------------
    # [3] valid mask tensor
    # Right-pad each valid mask with 0 to the longest one in the batch.
    valid_max_length = max([v.size(0) for v in valid_mask])
    valid_ids = torch.LongTensor(len(valid_mask), valid_max_length).zero_()
    for i, v in enumerate(valid_mask):
        valid_ids[i, : v.size(0)].copy_(v)

    # ---------------------------------------------------------------

为 n-gram 构造提及掩码(chunk_mask)和激活掩码(active_mask)

 # [4] active mention mask : for n-gram (original)
 # chunk_mask[b] concatenates, for n = 0..max_phrase_words-1, the phrase index of
 # each (n+1)-gram start position, padded with -1 to max_word_len - n slots.
 # NOTE(review): when max_word_len < max_phrase_words - 1, some terms in this sum
 # are negative and shrink max_ngram_length, while the else-branch below emits no
 # entries for those n (range() of a negative number is empty). batch_mask then
 # ends up longer than chunk_mask[batch_i] and copy_ fails. Clamping each term
 # with max(..., 0) would avoid this — confirm.

    max_ngram_length = sum([max_word_len - n for n in range(max_phrase_words)])
    chunk_mask = torch.LongTensor(len(docs), max_ngram_length).fill_(-1)

    for batch_i, word_len in enumerate(doc_word_lens):
        pad_len = max_word_len - word_len

        batch_mask = []
        for n in range(max_phrase_words):
            ngram_len = word_len - n
            if ngram_len > 0:
                assert len(mention_mask[batch_i][n]) == ngram_len
                gram_list = mention_mask[batch_i][n] + [
                    -1 for _ in range(pad_len)
                ]  # -1 for padding
            else:
                # This document is too short for (n+1)-grams: the whole level is padding.
                gram_list = [-1 for _ in range(max_word_len - n)]
            batch_mask.extend(gram_list)
        chunk_mask[batch_i].copy_(torch.LongTensor(batch_mask))

    # ---------------------------------------------------------------
    # [5] active mask : for n-gram
    # active_mask[b][g][k] is 0 where slot k of example b holds phrase index g,
    # and 1 everywhere else (other phrases and padding).
    max_diff_gram_num = 1 + max(
        [max(_mention_mask[-1]) for _mention_mask in mention_mask]
    )
    active_mask = torch.BoolTensor(
        len(docs), max_diff_gram_num, max_ngram_length
    ).fill_(1)
    #     active_mask = torch.ByteTensor(len(docs), max_diff_gram_num, max_ngram_length).fill_(1) # Pytorch Version = 1.1.0
    for gram_ids in range(max_diff_gram_num):
        tmp = torch.where(
            chunk_mask == gram_ids,
            torch.LongTensor(len(docs), max_ngram_length).fill_(0),
            torch.LongTensor(len(docs), max_ngram_length).fill_(1),
        )  # shape = (batch_size, max_ngram_length) # 1 for pad
        for batch_id in range(len(docs)):
            active_mask[batch_id][gram_ids].copy_(tmp[batch_id])

创建形状为 (批大小, max_word_len, bert_output_dim) 的全零张量 valid_output
并返回结果

# -------------------------------------------------------------------
    # Zero tensor sized at word level: (batch, max_word_len, bert_output_dim).
    valid_output = torch.zeros(len(docs), max_word_len, bert_output_dim)
    # NOTE(review): `phrase_list_lens` is not defined in this train collate function
    # as shown, and label_list / chunk_list are never used, so this return raises
    # NameError as written. The padded label tensors were presumably meant to be
    # returned here — confirm against the original source.
    return (
        input_ids,
        input_mask,
        valid_ids,
        active_mask,
        valid_output,
        phrase_list_lens,
        ids,
    )
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值