项目概述
大家好!作为一名和大家一样在编程路上慢慢摸索的开发者,我深知从零开始搭建一个聊天机器人时的迷茫与挑战。上面这些代码,是我一点点调试、修改、完善出来的成果 —— 它可能不够完美,也没有用到最前沿的技术,但每一行都凝结着实践中的思考。今天把它分享出来,就是希望能帮更多新手少走弯路,快速拥有一个属于自己的聊天机器人。本文所选用数据集是xiaohuangji50w_nofenci.conv,可前往https://github.com/candlewill/Dialog_Corpus进行下载。
核心功能解析
1. 数据处理模块
数据处理是聊天机器人实现的基础,下面分别用两个部分代码进行数据处理:
- 对话数据读取:
read_dialog_file
函数负责从特定格式的文件中读取对话数据,识别用户(user)和机器(machine)的对话轮次,并以结构化方式存储
def read_dialog_file(file_path):
dialogs = []
current_dialog = []
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
line = line.strip()
if not line:
continue
prefix = line[0]
content = line[2:]
if prefix == 'E': # 对话结束标记
if current_dialog:
dialogs.append(current_dialog)
current_dialog = []
elif prefix == 'M': # 对话内容标记
if not current_dialog:
current_dialog.append(('user', content))
else:
current_dialog.append(('machine', content))
if current_dialog:
dialogs.append(current_dialog)
return dialogs
词汇表构建:build_vocab
函数从对话数据中提取所有出现的字符,构建词汇表并为每个字符分配唯一索引,同时加入特殊标记<pad>
(填充)、<sos>
(序列开始)和<eos>
(序列结束)
def read_dialog_file(file_path):
dialogs = []
current_dialog = []
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
line = line.strip()
if not line:
continue
# 分割每行的前缀和内容
prefix = line[0]
content = line[2:]
if prefix == 'E':
# 如果当前对话不为空,保存并重置
if current_dialog:
dialogs.append(current_dialog)
current_dialog = []
elif prefix == 'M':
if not current_dialog:
current_dialog.append(('user', content))
else:
current_dialog.append(('machine', content))
# 保存最后一个对话
if current_dialog:
dialogs.append(current_dialog)
return dialogs
def preprocess_dialogs(dialogs):
preprocessed_dialogs = []
for dialog in dialogs:
preprocessed_dialog = []
for turn in dialog:
# 去除斜杠并合并分词后的词
text = turn[1].replace('/', '')
preprocessed_dialog.append((turn[0], text))
preprocessed_dialogs.append(preprocessed_dialog)
return preprocessed_dialogs
def build_vocab(dialogs):
vocab = set()
for dialog in dialogs:
for turn in dialog:
for char in turn[1]:
vocab.add(char)
return sorted(vocab)
def save_vocab(vocab, file_path):
with open(file_path, 'w', encoding='utf-8') as file:
for word in vocab:
file.write(word + '\n')
# 读取文件
file_path = '../data/xiaohuangji50w_nofenci.conv' # 替换为你的文件路径
dialogs = read_dialog_file(file_path)
# 预处理对话
preprocessed_dialogs = preprocess_dialogs(dialogs)
# 构建词汇表
vocab = build_vocab(preprocessed_dialogs)
# 保存词汇表
vocab_file_path = '../data/xhj_vocab.txt' # 替换为你希望保存的词汇表文件路径
save_vocab(vocab, vocab_file_path)
print(f"词汇表已保存到 {vocab_file_path}")
文本与向量转换:convert
函数实现文本与整数向量之间的双向转换,是自然语言与模型输入输出之间的桥梁。在之后代码中会用到。
# 字符向量的转换
def convert(char_list, mode, vocab):
con = []
if mode == "word2vec":
conver = dict((x, y) for x, y in vocab.items())
if isinstance(char_list, str):
char_list = list(char_list)
con = [vocab["<sos>"]] + [conver.get(char, vocab["<pad>"]) for char in char_list] + [vocab["<eos>"]]
elif mode == "vec2word":
conver = dict((y, x) for x, y in vocab.items())
for char in char_list:
con.append(conver[char])
return con
2. 模型架构
实现聊天机器人的核心模型,采用了带注意力机制的 Encoder-Decoder 架构:
- 编码器(Encoder):
接收输入序列(用户话语),通过嵌入层将词索引转换为词向量,再经双向 GRU 处理,输出所有时间步的隐藏状态和最终的上下文向量
class Encoder(nn.Module):
def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
super().__init__()
self.embedding = nn.Embedding(input_dim, emb_dim)
self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True, num_layers=2)
self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, src):
embedded = self.dropout(self.embedding(src))
outputs, hidden = self.rnn(embedded)
# 融合双向GRU的隐藏状态
hidden = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))
return outputs, hidden
- 注意力机制(Attention):
解决传统 Encoder-Decoder 架构中上下文向量难以处理长序列的问题,使解码器在生成每个词时能关注输入序列的不同部分
class Attention(nn.Module):
def __init__(self, enc_hid_dim, dec_hid_dim):
super().__init__()
self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
self.v = nn.Parameter(torch.rand(dec_hid_dim))
def forward(self, hidden, encoder_outputs):
# 计算注意力权重并返回注意力分布
batch_size = encoder_outputs.shape[1]
src_len = encoder_outputs.shape[0]
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
encoder_outputs = encoder_outputs.permute(1, 0, 2)
energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
energy = energy.permute(0, 2, 1)
v = self.v.repeat(batch_size, 1).unsqueeze(1)
attention = torch.bmm(v, energy).squeeze(1)
return F.softmax(attention, dim=1)
-
解码器(Decoder):
以编码器输出和上一时间步的输出为输入,通过 GRU 和注意力机制生成当前时间步的输出 -
Seq2Seq 模型:
整合编码器和解码器,实现端到端的序列转换功能,支持教师强制(teacher forcing)机制加速训练
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder, device="CPU"):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.device = device
def forward(self, src, trg, teacher_forcing_ratio=0.5,max_length=30):
batch_size = src.shape[1]
trg_len = trg.shape[0]
trg_vocab_size = self.decoder.output_dim
outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
encoder_outputs, hidden = self.encoder(src)
input = trg[0, :]
for t in range(1, trg_len):
output, hidden = self.decoder(input, hidden, encoder_outputs) # output:torch.Size([batch_size, vocab_size])
outputs[t] = output
teacher_force = random.random() < teacher_forcing_ratio
top1 = output.argmax(1)
input = trg[t] if teacher_force else top1
if all(top1.item() == 2 for top1 in top1):
break
return outputs
3. 模型训练
实现模型的训练流程:
- 数据加载:
GetDATA
函数按批次加载并预处理数据,将文本转换为模型可接受的张量形式 - 训练循环:
train_model
函数定义了完整的训练过程,包括前向传播、损失计算、反向传播和参数更新 - 优化策略:使用 Adam 优化器和学习率调度器,采用交叉熵损失函数,通过保存验证集表现最好的模型来防止过拟合
def train_model(model, batch_size, vocab, dialogs, optimizer, scheduler, criterion, device, epochs=10, teacher_forcing_ratio=0.5):
best_loss = float('inf')
for epoch in range(epochs):
model.train()
total_train_loss = 0
for batch in tqdm(range(int(train_sample / batch_size))):
src, trg = GetDATA(batch_size, vocab, dialogs, count=batch)
src = src.transpose(0, 1).to(device)
trg = trg.transpose(0, 1).to(device)
optimizer.zero_grad()
output = model(src, trg, teacher_forcing_ratio)
loss = criterion(output[1:].reshape(-1, output.shape[-1]), trg[1:].reshape(-1))
loss.backward()
optimizer.step()
total_train_loss += loss.item()
# 保存最佳模型
if avg_train_loss < best_loss:
best_loss = avg_train_loss
torch.save(model.state_dict(), 'best_model.pth')
4. 模型推理
ED代码流程.py展示了如何使用训练好的模型进行对话生成:
- 加载训练好的模型参数
- 将用户输入转换为模型可接受的向量形式
- 利用解码器逐步生成回复,直到生成
<eos>
标记或达到最大长度 - 将生成的向量转换回文本形式,得到最终回复
# 模型推理过程
batch_size = src.shape[1]
trg = torch.zeros((1, batch_size), dtype=torch.long).fill_(vocab["<sos>"]).to(device)
src = src.to(device)
outputs = []
max_length = 50 # 最大生成长度
with torch.no_grad():
encoder_outputs, hidden = model.encoder(src)
for t in range(1, max_length):
output, hidden = model.decoder(trg[-1], hidden, encoder_outputs)
top1 = output.argmax(1) # 贪婪解码
outputs.append(top1)
trg = torch.cat((trg, top1.unsqueeze(0)), dim=0)
# 如果所有序列都生成了<eos>标记,则停止生成
if all(top1.item() == vocab["<eos>"] for top1 in top1):
break
最后效果展示:
(因为我自己使用的笔记本电脑用CPU训练出来的,只使用了前12800条数据集,所以效果只能说还行......)
功能扩展与优化思路
下面提供一些可扩展的功能思路:
- 对话历史记忆:通过
GetDATA_with_history
函数可以将多轮对话历史作为上下文输入,增强模型的上下文理解能力 - 多样化回复生成:通过温度参数(Temperature)控制输出概率分布的随机性,结合随机采样替代贪婪选择,生成更多样化的回复
- 回复质量控制:通过去重、过滤无效回复等策略提高生成回复的质量
写在最后
技术的进步从来不是一蹴而就的,我现在回头看最初的版本,也会觉得笨拙。但正是这些不完美的尝试,让我慢慢理解了聊天机器人的工作原理。
如果你在运行代码时遇到报错,别着急 —— 这太正常了!看看错误提示指向哪一行,想想这一步是在做什么(比如数据格式不对?模型参数不匹配?),试着改改看。解决问题的过程,就是进步最快的时候。
希望这份代码能成为你探索自然语言处理的起点。如果它能帮你少走一些弯路,或者让你感受到搭建机器人的乐趣,那我就很开心了。
最后想说:编程的乐趣不在于写出完美的代码,而在于亲手创造出能工作的东西。开始动手吧,你的第一个聊天机器人,可能比你想象的离你更近!