2021SC@SDUSC
系列文章目录
(四)BERT for Keyphrase Extraction
(十三)jsonify_multidata.py + Constant.py
(十五)Keyphrase Chunking - bert2chunk_dataloader.py分析
文章目录
bert2joint
转化为标签(convert_to_label)
函数
def convert_to_label(filter_positions, tot_mention_list, differ_phrase_num):
    """Build n-gram and chunk supervision labels for one document.

    Args:
        filter_positions: per-keyphrase lists of (start, end) word positions
            (end inclusive), already filtered to fit the document scope.
        tot_mention_list: tot_mention_list[n][s] is the phrase index of the
            (n+1)-gram starting at word s (output of get_ngram_features).
        differ_phrase_num: number of distinct candidate phrases.

    Returns:
        (ngram_label, chunk_label): keyphrase n-grams are marked +1 and all
        other candidate phrases -1; chunk_label flags every mention slot
        (1 = keyphrase mention, 0 = other). Returns (None, None) when the
        labels lack at least one positive and one negative entry, so the
        caller can discard the example.
    """
    # Default every distinct candidate phrase to the negative class.
    ngram_label = [-1 for _ in range(differ_phrase_num)]
    # One 0-initialised row per n-gram length, mirroring tot_mention_list.
    chunk_label_list = [
        [0] * len(tot_mention_list[i]) for i in range(len(tot_mention_list))
    ]
    for i, positions in enumerate(filter_positions):
        for s, e in positions:
            # (e - s) selects the n-gram-length bucket, s the start offset.
            chunk_label_list[e - s][s] = 1
            key_index = tot_mention_list[e - s][s]
            ngram_label[key_index] = 1
    # Flatten the per-length chunk rows into a single label list.
    chunk_label = [_chunk for chunks in chunk_label_list for _chunk in chunks]
    # Keep the example only if both label sets contain at least one positive
    # AND one negative item; otherwise it is useless for training.
    if (
        (1 in ngram_label)
        and (-1 in ngram_label)
        and (1 in chunk_label)
        and (0 in chunk_label)
    ):
        return ngram_label, chunk_label
    else:
        return None, None
获取n-gram特征
函数
def get_ngram_features(doc_words, max_gram, stem_flag=False):
    """Enumerate all 1..max_gram n-grams of a document and deduplicate them.

    Args:
        doc_words: list of word strings for one document.
        max_gram: maximum n-gram length to enumerate.
        stem_flag: if True, deduplicate phrases by their stemmed form.

    Returns:
        dict with:
            "tot_phrase_list": distinct lowercased phrases (final evaluation);
            "tot_mention_list": one sub-list per n-gram length n, where
                entry s is the phrase index of the n-gram starting at word s
                (used at train time to pool identical phrases).
    """
    phrase2index = {}  # phrase -> index; used to merge duplicate phrases
    tot_phrase_list = []  # distinct phrases, used for final evaluation
    tot_mention_list = []  # per-length mention indices, used for pooling

    gram_num = 0
    for n in range(1, max_gram + 1):
        valid_length = len(doc_words) - n + 1
        # No n-gram of this length fits the document; longer ones won't either.
        if valid_length < 1:
            break
        _ngram_list = []
        _mention_list = []
        for i in range(valid_length):
            gram_num += 1
            n_gram = " ".join(doc_words[i : i + n]).lower()
            # whether_stem_existing / whether_existing (loader_utils) return
            # the phrase's index, registering it in phrase2index and
            # tot_phrase_list if it has not been seen yet.
            if stem_flag:
                index = loader_utils.whether_stem_existing(
                    n_gram, phrase2index, tot_phrase_list
                )
            else:
                index = loader_utils.whether_existing(
                    n_gram, phrase2index, tot_phrase_list
                )
            _mention_list.append(index)
            _ngram_list.append(n_gram)
        tot_mention_list.append(_mention_list)

    # Sanity checks: at least one phrase was collected, the newest phrase
    # index appears in the last mention row, and every enumerated n-gram got
    # exactly one mention slot.
    assert len(tot_phrase_list) > 0
    assert (len(tot_phrase_list) - 1) == max(tot_mention_list[-1])
    assert sum([len(_mention_list) for _mention_list in tot_mention_list]) == gram_num
    return {"tot_phrase_list": tot_phrase_list, "tot_mention_list": tot_mention_list}
获取块对应的n-gram特征
定义关键词列表和块list
# NOTE(review): the original excerpt omits this function's header; the name
# and signature below are inferred from the names the body uses — confirm
# against the full bert2joint_dataloader.py source.
def get_ngram_chunk_features(doc_words, max_gram, keyphrases, stem_flag=False):
    """Enumerate n-gram candidates plus a per-mention chunk label.

    Same candidate enumeration as get_ngram_features, but additionally marks
    each mention slot with 1 when its lowercased text exactly matches one of
    the gold keyphrases, else 0.

    Args:
        doc_words: list of word strings for one document.
        max_gram: maximum n-gram length to enumerate.
        keyphrases: gold keyphrases, each a list of word strings.
        stem_flag: if True, deduplicate phrases by their stemmed form.

    Returns:
        dict with "tot_phrase_list", "tot_mention_list" (see
        get_ngram_features) and "chunk_label" (one 0/1 flag per mention).
    """
    # Lowercased gold keyphrases for exact string matching below.
    keyphrases_list = [" ".join(kp).lower() for kp in keyphrases]
    chunk_label = []

    phrase2index = {}  # phrase -> index; used to merge duplicate phrases
    tot_phrase_list = []  # distinct phrases, used for final evaluation
    tot_mention_list = []  # per-length mention indices, used for pooling

    # Same enumeration loop as get_ngram_features.
    gram_num = 0
    for n in range(1, max_gram + 1):
        valid_length = len(doc_words) - n + 1
        if valid_length < 1:
            break
        _ngram_list = []
        _mention_list = []
        for i in range(valid_length):
            gram_num += 1
            n_gram = " ".join(doc_words[i : i + n]).lower()
            if stem_flag:
                index = loader_utils.whether_stem_existing(
                    n_gram, phrase2index, tot_phrase_list
                )
            else:
                index = loader_utils.whether_existing(
                    n_gram, phrase2index, tot_phrase_list
                )
            # Extra step vs. get_ngram_features: flag gold-keyphrase mentions.
            if n_gram in keyphrases_list:
                chunk_label.append(1)
            else:
                chunk_label.append(0)
            _mention_list.append(index)
            _ngram_list.append(n_gram)
        tot_mention_list.append(_mention_list)

    # Sanity checks, including one chunk label per enumerated n-gram.
    assert len(tot_phrase_list) > 0
    assert (len(tot_phrase_list) - 1) == max(tot_mention_list[-1])
    assert (
        sum([len(_mention_list) for _mention_list in tot_mention_list])
        == gram_num
        == len(chunk_label)
    )
    return {
        "tot_phrase_list": tot_phrase_list,
        "tot_mention_list": tot_mention_list,
        "chunk_label": chunk_label,
    }
获取 n-gram信息标签
函数名以及返回值:
def get_ngram_info_label(
    doc_words, max_phrase_words, stem_flag, keyphrases=None, start_end_pos=None
):
    """Gather n-gram candidate features and, when positions are given, labels.

    Args:
        doc_words: list of word strings for one document.
        max_phrase_words: maximum keyphrase length in words (max n-gram).
        stem_flag: passed through to get_ngram_features.
        keyphrases: gold keyphrases (unused here; accepted for a uniform
            preprocessor interface).
        start_end_pos: gold keyphrase (start, end) word spans, or None at
            eval time.

    Returns:
        dict with "overlen_flag" (True when at least one gold span was
        dropped for exceeding the scope), "ngram_label" / "chunk_label"
        (from convert_to_label, or None), plus "tot_phrase_list" and
        "tot_mention_list" from get_ngram_features.
    """
    returns = {"overlen_flag": False, "ngram_label": None, "chunk_label": None}

    # Candidate enumeration is always needed, train or eval.
    feature = get_ngram_features(
        doc_words=doc_words, max_gram=max_phrase_words, stem_flag=stem_flag
    )
    returns["tot_phrase_list"] = feature["tot_phrase_list"]
    returns["tot_mention_list"] = feature["tot_mention_list"]

    # Labels are only built when gold spans are supplied (train mode).
    if start_end_pos is not None:
        # Drop spans that fall outside the (truncated) document or exceed
        # the maximum phrase length.
        filter_positions = loader_utils.limit_scope_length(
            start_end_pos, len(doc_words), max_phrase_words
        )
        # Flag examples that lost at least one over-length keyphrase.
        if len(filter_positions) != len(start_end_pos):
            returns["overlen_flag"] = True

        if len(filter_positions) > 0:
            returns["ngram_label"], returns["chunk_label"] = convert_to_label(
                filter_positions=filter_positions,
                tot_mention_list=feature["tot_mention_list"],
                differ_phrase_num=len(feature["tot_phrase_list"]),
            )
        else:
            returns["ngram_label"] = None
            returns["chunk_label"] = None
    return returns
进行预处理
函数定义:
def bert2joint_preprocessor(
    examples,
    tokenizer,
    max_token,
    pretrain_model,
    mode,
    max_phrase_words,
    stem_flag=False,
):
    """Tokenize examples for BERT and attach n-gram features (and labels).

    Args:
        examples: list of dicts with at least "url" and "doc_words"; train
            mode additionally requires "keyphrases" and "start_end_pos".
        tokenizer: BERT tokenizer used by loader_utils.tokenize_for_bert.
        max_token: maximum number of word-piece tokens to keep.
        pretrain_model: pretrained model identifier (kept for interface
            compatibility; unused in this function).
        mode: "train" to build supervision labels, anything else for eval.
        max_phrase_words: maximum keyphrase length in words.
        stem_flag: whether candidate phrases are deduplicated by stem.

    Returns:
        list of preprocessed example dicts; in train mode, examples without
        usable labels are dropped.
    """
    overlen_num = 0
    new_examples = []
    # tqdm wraps the loop with a progress bar.
    for idx, ex in enumerate(tqdm(examples)):
        # Word-piece tokenization with word alignment info.
        tokenize_output = loader_utils.tokenize_for_bert(
            doc_words=ex["doc_words"], tokenizer=tokenizer
        )
        # How many whole words survive the max_token truncation.
        if len(tokenize_output["tokens"]) < max_token:
            max_word = max_token
        else:
            max_word = tokenize_output["tok_to_orig_index"][max_token - 1] + 1

        new_ex = {}
        new_ex["url"] = ex["url"]
        new_ex["tokens"] = tokenize_output["tokens"][:max_token]
        new_ex["valid_mask"] = tokenize_output["valid_mask"][:max_token]
        new_ex["doc_words"] = ex["doc_words"][:max_word]
        assert len(new_ex["tokens"]) == len(new_ex["valid_mask"])
        assert sum(new_ex["valid_mask"]) == len(new_ex["doc_words"])

        parameter = {
            "doc_words": new_ex["doc_words"],
            "max_phrase_words": max_phrase_words,
            "stem_flag": stem_flag,
        }
        # BUGFIX: the gold supervision must be placed in `parameter` BEFORE
        # get_ngram_info_label is called; the original excerpt added it
        # afterwards, so no training labels were ever produced and every
        # train example was skipped.
        if mode == "train":
            parameter["keyphrases"] = ex["keyphrases"]
            parameter["start_end_pos"] = ex["start_end_pos"]

        # Obtain n-gram candidate info and (in train mode) labels.
        info_or_label = get_ngram_info_label(**parameter)
        new_ex["phrase_list"] = info_or_label["tot_phrase_list"]
        new_ex["mention_lists"] = info_or_label["tot_mention_list"]
        if info_or_label["overlen_flag"]:
            overlen_num += 1

        if mode == "train":
            # Drop examples whose labels are missing or all-negative.
            if not info_or_label["ngram_label"]:
                continue
            new_ex["keyphrases"] = ex["keyphrases"]
            new_ex["ngram_label"] = info_or_label["ngram_label"]
            new_ex["chunk_label"] = info_or_label["chunk_label"]

        new_examples.append(new_ex)

    logger.info(
        "Delete Overlen Keyphrase (length > 5): %d (overlen / total = %.2f"
        % (overlen_num, float(overlen_num / len(examples) * 100))
        + "%)"
    )
    return new_examples
将每批数据转换为张量;添加 [CLS] [SEP]标记
函数定义:
def bert2joint_converter(index, ex, tokenizer, mode, max_phrase_words):
    """Convert one preprocessed example to tensors, adding [CLS] / [SEP].

    Args:
        index: example index within the dataset.
        ex: preprocessed example dict from bert2joint_preprocessor.
        tokenizer: BERT tokenizer (for convert_tokens_to_ids).
        mode: "train" returns label tensors; anything else returns the
            candidate-phrase count instead.
        max_phrase_words: maximum keyphrase length in words (passed through).

    Returns:
        train: (index, src_tensor, valid_mask, mention_lists, orig_doc_len,
                max_phrase_words, label, chunk_label)
        eval:  (index, src_tensor, valid_mask, mention_lists, orig_doc_len,
                max_phrase_words, tot_phrase_len)
    """
    # Wrap the word pieces with the BOS/EOS ([CLS]/[SEP]) markers; the added
    # positions get valid flag 0 so they are excluded from word counting.
    src_tokens = [BOS_WORD] + ex["tokens"] + [EOS_WORD]
    valid_ids = [0] + ex["valid_mask"] + [0]

    src_tensor = torch.LongTensor(tokenizer.convert_tokens_to_ids(src_tokens))
    valid_mask = torch.LongTensor(valid_ids)

    mention_lists = ex["mention_lists"]
    # Number of whole words (sum of valid flags).
    orig_doc_len = sum(valid_ids)

    if mode == "train":
        label = torch.LongTensor(ex["ngram_label"])
        chunk_label = torch.LongTensor(ex["chunk_label"])
        return (
            index,
            src_tensor,
            valid_mask,
            mention_lists,
            orig_doc_len,
            max_phrase_words,
            label,
            chunk_label,
        )
    else:
        tot_phrase_len = len(ex["phrase_list"])
        return (
            index,
            src_tensor,
            valid_mask,
            mention_lists,
            orig_doc_len,
            max_phrase_words,
            tot_phrase_len,
        )
训练加载器和评估加载器
函数定义:
def batchify_bert2joint_features_for_train(batch):
    """Collate converter tuples into padded tensors for one training batch.

    Args:
        batch: list of train-mode tuples from bert2joint_converter:
            (index, src_tensor, valid_mask, mention_lists, orig_doc_len,
             max_phrase_words, ngram_label, chunk_label).

    Returns:
        (input_ids, input_mask, valid_ids, active_mask, valid_output,
         ngram_label, chunk_label, chunk_mask, ids)

    NOTE(review): the original excerpt returned an undefined name
    `phrase_list_lens` and never used `label_list` / `chunk_list`; the
    padded ngram_label / chunk_label tensors built below restore the
    intended behavior — confirm against the full dataloader source.
    """
    ids = [ex[0] for ex in batch]
    docs = [ex[1] for ex in batch]
    valid_mask = [ex[2] for ex in batch]
    mention_mask = [ex[3] for ex in batch]
    doc_word_lens = [ex[4] for ex in batch]
    max_phrase_words = [ex[5] for ex in batch][0]
    label_list = [ex[6] for ex in batch]  # varying n-gram counts per example
    chunk_list = [ex[7] for ex in batch]  # whether each mention is a chunk

    bert_output_dim = 768
    max_word_len = max([word_len for word_len in doc_word_lens])  # word-level

    # ---------------------------------------------------------------
    # [1] [2] src token tensors, zero-padded to the longest document.
    doc_max_length = max([d.size(0) for d in docs])
    input_ids = torch.LongTensor(len(docs), doc_max_length).zero_()
    input_mask = torch.LongTensor(len(docs), doc_max_length).zero_()
    for i, d in enumerate(docs):
        input_ids[i, : d.size(0)].copy_(d)
        input_mask[i, : d.size(0)].fill_(1)

    # ---------------------------------------------------------------
    # [3] valid mask tensor (1 marks the first word-piece of each word).
    valid_max_length = max([v.size(0) for v in valid_mask])
    valid_ids = torch.LongTensor(len(valid_mask), valid_max_length).zero_()
    for i, v in enumerate(valid_mask):
        valid_ids[i, : v.size(0)].copy_(v)

    # ---------------------------------------------------------------
    # [4] mention mask for all n-grams, flattened length-by-length;
    # -1 marks padding slots.
    max_ngram_length = sum([max_word_len - n for n in range(max_phrase_words)])
    chunk_mask = torch.LongTensor(len(docs), max_ngram_length).fill_(-1)
    for batch_i, word_len in enumerate(doc_word_lens):
        pad_len = max_word_len - word_len
        batch_mask = []
        for n in range(max_phrase_words):
            ngram_len = word_len - n
            if ngram_len > 0:
                assert len(mention_mask[batch_i][n]) == ngram_len
                gram_list = mention_mask[batch_i][n] + [
                    -1 for _ in range(pad_len)
                ]  # -1 for padding
            else:
                # Document too short for this n-gram length: all padding.
                gram_list = [-1 for _ in range(max_word_len - n)]
            batch_mask.extend(gram_list)
        chunk_mask[batch_i].copy_(torch.LongTensor(batch_mask))

    # ---------------------------------------------------------------
    # [5] active mask: for each distinct phrase index, 0 where one of its
    # mentions sits, 1 (pad) everywhere else — used to pool identical
    # phrases.
    max_diff_gram_num = 1 + max(
        [max(_mention_mask[-1]) for _mention_mask in mention_mask]
    )
    active_mask = torch.BoolTensor(
        len(docs), max_diff_gram_num, max_ngram_length
    ).fill_(1)
    for gram_ids in range(max_diff_gram_num):
        tmp = torch.where(
            chunk_mask == gram_ids,
            torch.LongTensor(len(docs), max_ngram_length).fill_(0),
            torch.LongTensor(len(docs), max_ngram_length).fill_(1),
        )  # shape = (batch_size, max_ngram_length); 1 means pad
        for batch_id in range(len(docs)):
            active_mask[batch_id][gram_ids].copy_(tmp[batch_id])

    # ---------------------------------------------------------------
    # [6] ngram label tensor, one row per example, -1 pads.
    ngram_label = torch.LongTensor(len(label_list), max_diff_gram_num).fill_(-1)
    for batch_i, label in enumerate(label_list):
        ngram_label[batch_i, : len(label)].copy_(torch.LongTensor(label))

    # [7] chunk label tensor, -1 pads.
    chunk_label = torch.LongTensor(len(chunk_list), max_ngram_length).fill_(-1)
    for batch_i, chunk in enumerate(chunk_list):
        chunk_label[batch_i, : len(chunk)].copy_(torch.LongTensor(chunk))

    # ---------------------------------------------------------------
    # Zero buffer the model scatters valid BERT outputs into
    # (batch, words, hidden).
    valid_output = torch.zeros(len(docs), max_word_len, bert_output_dim)
    return (
        input_ids,
        input_mask,
        valid_ids,
        active_mask,
        valid_output,
        ngram_label,
        chunk_label,
        chunk_mask,
        ids,
    )