import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from torch.optim import AdamW
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
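
# Load the training data, build a single combined text field from question, explanation, and
# answer, and label-encode both targets (missing misconceptions become their own 'NA' class).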
df = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/train.csv')
df['combined_text'] = df['QuestionText'].fillna('') + ' ' + df['StudentExplanation'].fillna('')+ ' ' + df['MC_Answer'].fillna('')
le_cat = LabelEncoder()
le_mis = LabelEncoder()
df['Category_enc'] = le_cat.fit_transform(df['Category'])
df['Misconception'] = df['Misconception'].fillna('NA')
df['Misconception_enc'] = le_mis.fit_transform(df['Misconception'])
train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)
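
# Dataset that tokenizes the combined text and returns it together with both label sets.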
class MultiTaskDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=256):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        encoding = self.tokenizer(row['combined_text'], truncation=True, padding='max_length',
                                  max_length=self.max_len, return_tensors="pt")
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label_category': torch.tensor(row['Category_enc'], dtype=torch.long),
            'label_misconception': torch.tensor(row['Misconception_enc'], dtype=torch.long)
        }
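
# Shared BERT encoder with two linear heads: one for Category, one for Misconception.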
class MultiTaskClassifier(nn.Module):
    def __init__(self, num_cat_classes, num_mis_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.cat_head = nn.Linear(768, num_cat_classes)
        self.mis_head = nn.Linear(768, num_mis_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        cat_logits = self.cat_head(pooled)
        mis_logits = self.mis_head(pooled)
        return cat_logits, mis_logits
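
# Tokenizer, data loaders, model, optimizer, and the loss shared by both heads.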
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = MultiTaskDataset(train_df, tokenizer)
val_dataset = MultiTaskDataset(val_df, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)
model = MultiTaskClassifier(num_cat_classes=len(le_cat.classes_), num_mis_classes=len(le_mis.classes_)).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss(ignore_index=-1)  # -1 labels (if any) are ignored; all encoded labels here are >= 0
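
# One pass over the training data; the two head losses are summed so both tasks are optimized jointly.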
def train_one_epoch():
    model.train()
    total_loss = 0
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        label_cat = batch['label_category'].to(device)
        label_mis = batch['label_misconception'].to(device)
        optimizer.zero_grad()
        cat_logits, mis_logits = model(input_ids, attention_mask)
        loss_cat = loss_fn(cat_logits, label_cat)
        loss_mis = loss_fn(mis_logits, label_mis)
        loss = loss_cat + loss_mis
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Train loss: {total_loss / len(train_loader):.4f}")
def evaluate():
    model.eval()
    preds_cat, preds_mis = [], []
    labels_cat, labels_mis = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            label_cat = batch['label_category'].to(device)
            label_mis = batch['label_misconception'].to(device)
            cat_logits, mis_logits = model(input_ids, attention_mask)
            pred_cat = torch.argmax(cat_logits, dim=1)
            pred_mis = torch.argmax(mis_logits, dim=1)
            preds_cat.extend(pred_cat.cpu().numpy())
            preds_mis.extend(pred_mis.cpu().numpy())
            labels_cat.extend(label_cat.cpu().numpy())
            labels_mis.extend(label_mis.cpu().numpy())
    # Pass explicit label indices so the reports still work when some classes are
    # missing from the validation split.
    print("Category classification report:\n",
          classification_report(labels_cat, preds_cat,
                                labels=list(range(len(le_cat.classes_))),
                                target_names=le_cat.classes_, zero_division=0))
    print("Misconception classification report:\n",
          classification_report(labels_mis, preds_mis,
                                labels=list(range(len(le_mis.classes_))),
                                target_names=le_mis.classes_, zero_division=0))
for epoch in range(3):
print(f"Epoch {epoch+1}")
train_one_epoch()
evaluate()
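
# A minimal sketch (not part of the original script) of persisting the fine-tuned weights and
# decoding a prediction for one new example; the file name, helper name, and sample inputs are
# illustrative assumptions.
torch.save(model.state_dict(), 'multitask_bert.pt')

def predict(question, explanation, answer):
    model.eval()
    text = f"{question} {explanation} {answer}"
    encoding = tokenizer(text, truncation=True, padding='max_length',
                         max_length=256, return_tensors="pt").to(device)
    with torch.no_grad():
        cat_logits, mis_logits = model(encoding['input_ids'], encoding['attention_mask'])
    # Map predicted indices back to the original string labels.
    category = le_cat.inverse_transform([cat_logits.argmax(dim=1).item()])[0]
    misconception = le_mis.inverse_transform([mis_logits.argmax(dim=1).item()])[0]
    return category, misconception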