import tensorflow as tf
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
PAD_TOKEN = '[PAD]'
UNKNOWN_TOKEN = '[UNK]'
START_DECODING = '[START]'
STOP_DECODING = '[STOP]'
# vocab_file => vocab.txt, max_size = 30000
# Vocabulary class; unlike reading vocab.txt directly, it adds handling for the special tokens defined above.
class Vocab:
def __init__(self, vocab_file, max_size):
self.word2id = {UNKNOWN_TOKEN: 0, PAD_TOKEN: 1, START_DECODING: 2, STOP_DECODING: 3}
self.id2word = {0: UNKNOWN_TOKEN, 1: PAD_TOKEN, 2: START_DECODING, 3: STOP_DECODING}
self.count = 4
with open(vocab_file, 'r', encoding='utf-8') as f:
            # Read the vocab file line by line; each line is expected to look like "说 0"
for line in f:
pieces = line.split()
if len(pieces) != 2:
                    # Skip malformed lines
print('Warning : incorrectly formatted line in vocabulary file : %s\n' % line)
continue
                w = pieces[0]  # take the word; raise if it is one of the reserved special tokens
if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
                    raise Exception("<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn't be in the vocab file, "
                                    "but %s is" % w)
                # Raise on duplicate words
if w in self.word2id:
raise Exception('Duplicated word in vocabulary file: %s' % w)
                # Build the word->id and id->word mappings as dicts
self.word2id[w] = self.count
self.id2word[self.count] = w
self.count += 1
                # Stop reading once max_size words have been loaded
if max_size != 0 and self.count >= max_size:
print("max_size of vocab was specified as %i; we now have %i words. Stopping reading."
% (max_size, self.count))
break
print("Finished constructing vocabulary of %i total words. Last word added: %s" %
(self.count, self.id2word[self.count - 1]))
    # Look up the id for a word; OOV words map to [UNK] (id 0)
def word_to_id(self, word):
if word not in self.word2id:
return self.word2id[UNKNOWN_TOKEN]
return self.word2id[word]
    # Look up the word for an id; raise if the id is not in the vocab
def id_to_word(self, word_id):
if word_id not in self.id2word:
raise ValueError('Id not found in vocab: %d' % word_id)
return self.id2word[word_id]
    # Actual vocabulary size (including the 4 special tokens)
def size(self):
return self.count
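# A minimal usage sketch (hypothetical file name and values; assumes a vocab.txt whose lines look like "说 125"):
#   vocab = Vocab('vocab.txt', max_size=30000)
#   vocab.word_to_id('[PAD]')          # -> 1
#   vocab.word_to_id('不在词表里的词')   # -> 0, i.e. the [UNK] id
#   vocab.id_to_word(2)                # -> '[START]'
#   vocab.size()                       # -> number of words actually loaded, at most max_size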
# Map article words to ids; in-article OOVs get temporary ids vocab.size()+0, vocab.size()+1, ...
# Returns the id list and the list of article OOV words (see the sketch after the function).
def article_to_ids(article_words, vocab):
ids = []
oovs = []
unk_id = vocab.word_to_id(UNKNOWN_TOKEN)
for w in article_words:
i = vocab.word_to_id(w)
if i == unk_id: # If w is OOV
if w not in oovs: # Add to list of OOVs
oovs.append(w)
oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...
ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
else:
ids.append(i)
return ids, oovs
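# Worked sketch (illustrative assumptions: vocab.size() == 30000, '新车' has id 57, '贴膜' is OOV):
#   article_to_ids(['新车', '贴膜', '贴膜'], vocab)
#   -> ids  = [57, 30000, 30000]   # the first article OOV gets temporary id vocab.size() + 0
#   -> oovs = ['贴膜']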
# Map abstract words to ids; an OOV that also appears in article_oovs reuses its temporary id,
# while an OOV absent from the article maps to [UNK] (see the sketch after the function).
def abstract_to_ids(abstract_words, vocab, article_oovs):
ids = []
unk_id = vocab.word_to_id(UNKNOWN_TOKEN)
for w in abstract_words:
i = vocab.word_to_id(w)
if i == unk_id: # If w is an OOV word
if w in article_oovs: # If w is an in-article OOV
vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
ids.append(vocab_idx)
else: # If w is an out-of-article OOV
ids.append(unk_id) # Map to the UNK token id
else:
ids.append(i)
return ids
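# Worked sketch (same illustrative assumptions as above, plus '保修' being an OOV not present in the article):
#   abstract_to_ids(['新车', '贴膜', '保修'], vocab, ['贴膜'])
#   -> [57, 30000, 0]   # in-article OOV reuses its temporary id, out-of-article OOV falls back to [UNK]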
# Not reviewed in detail yet; currently unreferenced
def output_to_words(id_list, vocab, article_oovs):
words = []
for i in id_list:
try:
w = vocab.id_to_word(i) # might be [UNK]
except ValueError as e: # w is OOV
assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. " \
"This should not happen in baseline (no pointer-generator) mode"
article_oov_idx = i - vocab.size()
try:
w = article_oovs[article_oov_idx]
            except IndexError:  # i doesn't correspond to an article oov
raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this '
'example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
words.append(w)
return words
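# Worked sketch (the inverse direction, same illustrative assumptions):
#   output_to_words([57, 30000, 0], vocab, ['贴膜'])
#   -> ['新车', '贴膜', '[UNK]']   # 30000 == vocab.size() + 0 resolves to the first article OOV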
# Currently unreferenced
def abstract_to_sents(abstract):
"""
Splits abstract text from datafile into list of sentences.
Args:
abstract: string containing <s> and </s> tags for starts and ends of sentences
Returns:
sents: List of sentence strings (no tags)
"""
cur = 0
sents = []
while True:
try:
start_p = abstract.index(SENTENCE_START, cur)
end_p = abstract.index(SENTENCE_END, start_p + 1)
cur = end_p + len(SENTENCE_END)
sents.append(abstract[start_p + len(SENTENCE_START): end_p])
except ValueError as e: # no more sentences
return sents
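# Worked sketch (illustrative input):
#   abstract_to_sents('<s> 你好 车主 </s> <s> 建议 检查 尾灯 </s>')
#   -> [' 你好 车主 ', ' 建议 检查 尾灯 ']   # tags removed; surrounding spaces are kept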
def get_dec_inp_targ_seqs(sequence, max_len, start_id, stop_id):
"""
Given the reference summary as a sequence of tokens, return the input sequence for the decoder,
and the target sequence which we will use to calculate loss. The sequence will be truncated if it is longer
than max_len. The input sequence must start with the start_id and the target sequence must end with the stop_id
(but not if it's been truncated).
Args:
sequence: List of ids (integers)
max_len: integer
start_id: integer
stop_id: integer
Returns:
inp: sequence length <=max_len starting with start_id
target: sequence same length as input, ending with stop_id only if there was no truncation
"""
inp = [start_id] + sequence[:]
target = sequence[:]
    # If inp exceeds max_len, keep start_id plus the first (max_len - 1) tokens of sequence,
    # and the first max_len tokens as target; otherwise append stop_id to target.
if len(inp) > max_len: # truncate
inp = inp[:max_len]
target = target[:max_len] # no end_token
else: # no truncation
target.append(stop_id) # end token
    # Sanity check: inp and target must end up the same length,
    # either both max_len (truncated) or both len(sequence) + 1 (one gained start_id, the other stop_id)
    assert len(inp) == len(target)
return inp, target
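# Worked sketch (assuming start_id=2 and stop_id=3, i.e. the [START]/[STOP] ids defined above):
#   get_dec_inp_targ_seqs([10, 11, 12], max_len=5, start_id=2, stop_id=3)
#   -> inp = [2, 10, 11, 12], target = [10, 11, 12, 3]   # no truncation: target ends with stop_id
#   get_dec_inp_targ_seqs([10, 11, 12], max_len=3, start_id=2, stop_id=3)
#   -> inp = [2, 10, 11],     target = [10, 11, 12]      # truncated: no stop_id appended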
def example_generator(vocab, train_x_path, train_y_path, test_x_path, max_enc_len, max_dec_len, mode, batch_size):
    # Training-data branch
if mode == "train":
        # TextLineDataset builds a line-by-line dataset from the given file path
dataset_train_x = tf.data.TextLineDataset(train_x_path)
dataset_train_y = tf.data.TextLineDataset(train_y_path)
        # Zip the two datasets into (x, y) pairs, e.g. [(x1, y1), (x2, y2), (x3, y3)]
train_dataset = tf.data.Dataset.zip((dataset_train_x, dataset_train_y))
# train_dataset = train_dataset.shuffle(1000, reshuffle_each_iteration=True).repeat()
# i = 0
#print("gen",train_dataset)
for raw_record in train_dataset:
            # Decode the raw record bytes into UTF-8 strings
article = raw_record[0].numpy().decode("utf-8")
#print("article",article)
#article 新车 , 全款 , 买 了 半个 月 , 去 4S店 贴膜 时才 发现 右侧 尾灯 下 (...)。 车主 说 : 恩
abstract = raw_record[1].numpy().decode("utf-8")
#print("abstract",abstract)
            #abstract 你好 , 像 这