BERT的中文问答系统66

完善BERT的中文问答系统65后的代码
简述
异常处理:增加了更多的异常处理,确保程序在遇到错误时能够优雅地退出或提示用户。
多线程:使用 threading 模块处理耗时操作,避免阻塞主界面。
配置文件管理:优化了配置文件的读取和写入,确保配置文件的完整性和安全性。
日志记录:增强了日志记录功能,记录更多详细的日志信息,便于调试和维护。
用户反馈:增加了更多的用户反馈机制,如进度条、提示信息等,提高用户体验。
模型保存与加载:优化了模型的保存和加载过程,确保模型的完整性和可恢复性。
数据加载:优化了数据加载过程,确保数据的完整性和正确性。
界面美化:优化了界面布局和样式,提高了界面的美观度和可用性。
新增功能:在GUI中增加了多个按钮,分别对应不同的模型类型,点击按钮会自动完成数据集收集及模型训练保存等操作。

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk, simpledialog
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup
import tkcalendar
import locale
import threading
import configparser

# 设置本地化为中文
locale.setlocale(locale.LC_ALL, 'zh_CN.UTF-8')

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        try:
            if file_path.endswith('.jsonl'):
                with jsonlines.open(file_path) as reader:
                    return [item for item in reader]
            elif file_path.endswith('.json'):
                with open(file_path, 'r') as f:
                    return json.load(f)
        except (json.JSONDecodeError, jsonlines.jsonlines.InvalidLineError) as e:
            logging.warning(f"加载数据失败: {
     
     e}")
            return []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item.get('question', '')
        human_answer = item.get('human_answers', [''])[0]
        chatgpt_answer = item.get('chatgpt_answers', [''])[0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {
     
     idx}: {
     
     e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
   
   
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {
     
     e}")

    return total_loss / len(data_loader)

# 模型评估函数
def evaluate_model(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            human_correct = (torch.sigmoid(human_logits) > 0.5).float() == human_labels
            chatgpt_correct = (torch.sigmoid(chatgpt_logits) > 0.5)
BERT中文问答系统是一种基于BERT预训练模型的自然语言处理技术,用于回答中文问答系统中的问题。其主要思想是将问题和文本段落作为输入,然后使用BERT模型对其进行编码,最后使用softmax函数计算每个文本段落与问题的匹配程度,从而选择最佳答案。 以下是BERT中文问答系统的实现步骤: 1.准备数据集:将问题和文本段落组成的数据集进行预处理,将其转换为BERT模型可以接受的格式。 2.加载预训练模型:使用huggingface transformers库加载预训练的BERT模型。 3.对输入进行编码:使用BERT模型对问题和文本段落进行编码,得到它们的向量表示。 4.计算匹配程度:使用softmax函数计算每个文本段落与问题的匹配程度,从而选择最佳答案。 5.输出答案:输出匹配程度最高的文本段落作为答案。 以下是一个使用BERT中文问答系统回答问题的例子: ```python from transformers import BertTokenizer, BertForQuestionAnswering import torch # 加载预训练模型和分词器 model = BertForQuestionAnswering.from_pretrained('bert-base-chinese') tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') # 输入问题和文本段落 question = "什么是BERT中文问答系统?" text = "BERT中文问答系统是一种基于BERT预训练模型的自然语言处理技术,用于回答中文问答系统中的问题。" # 对输入进行编码 input_ids = tokenizer.encode(question, text) tokens = tokenizer.convert_ids_to_tokens(input_ids) # 获取答案 start_scores, end_scores = model(torch.tensor([input_ids])) start_index = torch.argmax(start_scores) end_index = torch.argmax(end_scores) answer = ''.join(tokens[start_index:end_index+1]).replace('##', '') # 输出答案 print(answer) # 输出:一种基于BERT预训练模型的自然语言处理技术,用于回答中文问答系统中的问题。 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yehaiwz

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值