error: #929: incorrect use of va_arg

本文探讨了FATFS在Keil环境下编译时遇到的关于stdarg.h头文件冲突导致的一系列错误,包括va_start和va_arg使用不当等问题,并提供了检查系统环境变量和更新FATFS版本的解决方案。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

..\FATFS\src\ff.c(3995): error:  #928: incorrect use of va_start
..\FATFS\src\ff.c(3995): error:  #253: expected a ","
..\FATFS\src\ff.c(3995): error:  #29: expected an expression
..\FATFS\src\ff.c(4026): error:  #929: incorrect use of va_arg
..\FATFS\src\ff.c(4037): error:  #929: incorrect use of va_arg
..\FATFS\src\ff.c(4052): error:  #929: incorrect use of va_arg
..\FATFS\src\ff.c(4052): error:  #929: incorrect use of va_arg
..\FATFS\src\ff.c(4052): error:  #929: incorrect use of va_arg



你这是头文件包含有问题吧。我用4.70a,很好用啊。



原子哥和各位大神帮忙看下是什么原因吧,错误如下:            
..\FATFS\src\ff.c(3995): error:  #928: incorrect use of va_start 
..\FATFS\src\ff.c(3995): error:  #253: expected "," 
..\FATFS\src\ff.c(3995): error:  #29: expected an expression 
..\FATFS\src\ff.c(4026): error:  #929: incorrect use of va_arg 
..\FATFS\src\ff.c(4037): error:  #929: incorrect use of va_arg 
..\F 
...... 
--------------------------------- 
把fprintf屏蔽掉吧,



我也遇到了。。不知道怎么解决,只能屏蔽fprintf...


用4.72A编译正常。


这个主要是因为keil与其他编译环境冲突造成;你可以看一下stdarg.h的路径是否是在keil的安装目录下;我的系统就是在安装了ADS1.2后才出现的这个问题;查看系统环境变量就可以看到ARMINC变量被注册到ADS的安装目录下;卸载ADS与KEIL后,重启计算机,安装软件,问题解决;
# NOTE: this was originally a notebook cell.  Shell magics are not valid Python
# in a .py file; from a notebook install dependencies with:
#   !pip install transformers datasets torch rouge-score matplotlib
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from rouge_score import rouge_scorer
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizerFast

# Device configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")


class SummaryDataset(Dataset):
    """CNN/DailyMail (article, highlights) pairs tokenized for seq2seq training.

    Each item provides:
      * ``input_ids`` / ``attention_mask`` -- encoder inputs,
      * ``decoder_input_ids`` -- summary ids padded with the PAD id
        (always valid embedding indices),
      * ``labels`` -- summary ids with padding replaced by -100
        (consumed ONLY by CrossEntropyLoss(ignore_index=-100)).

    BUGFIX (root cause of "CUDA error: device-side assert triggered"): the
    original returned a single 'labels' tensor containing -100 and fed it to
    the decoder, so nn.Embedding was indexed with -100.  -100 is only valid
    as a loss ignore-index, never as an embedding index.
    """

    def __init__(self, dataset_split, tokenizer, max_article_len=384,
                 max_summary_len=96, subset_size=0.01):
        self.tokenizer = tokenizer
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len
        self.subset = dataset_split.select(
            range(int(len(dataset_split) * subset_size)))

        # Keep only samples with non-trivial article and summary text.
        # (The original additionally re-tokenized every sample and checked each
        # token against the vocab -- an extremely slow no-op, since WordPiece
        # output is by construction always in the vocab.  Dropped.)
        self.articles = []
        self.summaries = []
        for item in self.subset:
            article = item['article'].strip()
            summary = item['highlights'].strip()
            if len(article) > 20 and len(summary) > 10:
                self.articles.append(article)
                self.summaries.append(summary)

        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        src = self.tokenizer(
            self.articles[idx], max_length=self.max_article_len,
            truncation=True, padding='max_length', return_tensors='pt',
            add_special_tokens=True)
        tgt = self.tokenizer(
            self.summaries[idx], max_length=self.max_summary_len,
            truncation=True, padding='max_length', return_tensors='pt',
            add_special_tokens=True)
        decoder_input_ids = tgt['input_ids'].squeeze(0)
        labels = decoder_input_ids.clone()
        labels[labels == self.pad_token_id] = -100  # ignored by the loss only
        return {
            'input_ids': src['input_ids'].squeeze(0),
            'attention_mask': src['attention_mask'].squeeze(0),
            'decoder_input_ids': decoder_input_ids,
            'labels': labels,
        }


def _suppress_early_punctuation(next_token, step, tokenizer):
    """Replace punctuation predicted in the first 5 steps with 'the'.

    Applied only for batch size 1 (the original unconditionally called
    ``.item()``, which raises for larger batches) and only when a tokenizer
    is available.
    """
    if tokenizer is None or step >= 5 or next_token.size(0) != 1:
        return next_token
    punctuation = [',', '.', ';', ':', '!', '?', "'", '"', '`', '~']
    punct_ids = {tokenizer.convert_tokens_to_ids(p) for p in punctuation}
    if next_token.item() in punct_ids:
        return torch.tensor([[tokenizer.convert_tokens_to_ids('the')]],
                            device=next_token.device)
    return next_token


class BasicEncoder(nn.Module):
    """Two-layer bidirectional GRU encoder -> (1, batch, hidden_dim) state."""

    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim, hidden_dim, num_layers=2,
                          batch_first=True, bidirectional=True)
        self.fc_hidden = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, src):
        embedded = self.embedding(src)
        outputs, hidden = self.gru(embedded)
        forward_hidden = hidden[-2, :, :]   # top layer, forward direction
        backward_hidden = hidden[-1, :, :]  # top layer, backward direction
        hidden = torch.cat([forward_hidden, backward_hidden], dim=1)
        return self.fc_hidden(hidden).unsqueeze(0)  # (1, batch, hidden_dim)


class BasicDecoder(nn.Module):
    """Single-step GRU decoder conditioned on a fixed context vector."""

    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim + hidden_dim, hidden_dim, num_layers=1,
                          batch_first=True)
        self.fc = nn.Linear(hidden_dim * 2 + emb_dim, vocab_size)

    def forward(self, input_ids, hidden, context):
        input_embedded = self.embedding(input_ids.unsqueeze(1))      # (B,1,E)
        input_combined = torch.cat([input_embedded,
                                    context.unsqueeze(1)], dim=2)    # (B,1,E+H)
        output, hidden = self.gru(input_combined, hidden)            # (B,1,H)
        output = output.squeeze(1)
        combined = torch.cat([output, context,
                              input_embedded.squeeze(1)], dim=1)     # (B,2H+E)
        return self.fc(combined), hidden


class BasicSeq2Seq(nn.Module):
    """GRU encoder/decoder without attention.

    BUGFIX: generate() referenced ``self.tokenizer`` which was never assigned
    (AttributeError at inference time); it is now an optional ctor argument.
    """

    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256, tokenizer=None):
        super().__init__()
        self.encoder = BasicEncoder(vocab_size, emb_dim, hidden_dim)
        self.decoder = BasicDecoder(vocab_size, emb_dim, hidden_dim)
        self.device = device
        self.tokenizer = tokenizer
        self.sos_token_id = 101  # [CLS]
        self.eos_token_id = 102  # [SEP]
        self.unk_token_id = 100  # [UNK]

    def forward(self, src, tgt):
        """Teacher-forced pass: outputs[:, t] predicts the token after tgt[:, t].

        BUGFIX: the original left outputs[:, 0] all-zero and shifted the rest,
        so the loss compared logits against misaligned targets.
        """
        hidden = self.encoder(src)
        context = hidden.squeeze(0)
        batch_size, tgt_len = tgt.size()
        outputs = torch.zeros(batch_size, tgt_len,
                              self.decoder.fc.out_features, device=src.device)
        for t in range(tgt_len):
            logits, hidden = self.decoder(tgt[:, t], hidden, context)
            outputs[:, t] = logits
        return outputs

    def generate(self, src, max_length=80):
        """Greedy decoding from [CLS] until [SEP] or ``max_length`` tokens."""
        src = src.to(device)
        hidden = self.encoder(src)
        context = hidden.squeeze(0)
        generated = torch.full((src.size(0), 1), self.sos_token_id,
                               dtype=torch.long, device=device)
        for _ in range(max_length - 1):
            logits, hidden = self.decoder(generated[:, -1], hidden, context)
            next_token = torch.argmax(logits, dim=1, keepdim=True)
            next_token = _suppress_early_punctuation(
                next_token, generated.size(1), self.tokenizer)
            generated = torch.cat([generated, next_token], dim=1)
            if (next_token == self.eos_token_id).all():
                break
        return generated


class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over bidirectional encoder outputs."""

    def __init__(self, hidden_dim):
        super().__init__()
        # BUGFIX: input is [decoder hidden (H) ; encoder output (2H)] = 3H
        # wide; the original declared Linear(2*H, H) and crashed at runtime.
        self.W = nn.Linear(3 * hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.size(1)
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)   # (B, S, H)
        combined = torch.cat([hidden, encoder_outputs], dim=2)  # (B, S, 3H)
        energy = self.v(torch.tanh(self.W(combined))).squeeze(2)  # (B, S)
        return torch.softmax(energy, dim=1)


class AttnEncoder(nn.Module):
    """Two-layer bidirectional LSTM encoder; returns outputs and (h, c)."""

    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=2,
                            batch_first=True, bidirectional=True, dropout=0.1)
        self.fc_hidden = nn.Linear(hidden_dim * 2, hidden_dim)  # merge directions
        self.fc_cell = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, src):
        embedded = self.embedding(src)
        outputs, (hidden, cell) = self.lstm(embedded)  # outputs: (B, S, 2H)
        hidden = torch.cat([hidden[-2, :, :], hidden[-1, :, :]], dim=1)
        cell = torch.cat([cell[-2, :, :], cell[-1, :, :]], dim=1)
        hidden = self.fc_hidden(hidden).unsqueeze(0)   # (1, B, H)
        cell = self.fc_cell(cell).unsqueeze(0)
        return outputs, (hidden, cell)


class AttnDecoder(nn.Module):
    """Single-step LSTM decoder with additive attention over encoder outputs."""

    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.attention = Attention(hidden_dim)
        self.lstm = nn.LSTM(emb_dim + 2 * hidden_dim, hidden_dim,
                            num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_dim + emb_dim, vocab_size)

    def forward(self, input_ids, hidden, cell, encoder_outputs):
        input_embedded = self.embedding(input_ids.unsqueeze(1))        # (B,1,E)
        attn_weights = self.attention(hidden.squeeze(0), encoder_outputs)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # (B,1,2H)
        lstm_input = torch.cat([input_embedded, context], dim=2)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        logits = self.fc(torch.cat([output.squeeze(1),
                                    input_embedded.squeeze(1)], dim=1))
        return logits, hidden, cell


class AttnSeq2Seq(nn.Module):
    """LSTM encoder/decoder with attention; same fixes as BasicSeq2Seq."""

    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256, tokenizer=None):
        super().__init__()
        self.encoder = AttnEncoder(vocab_size, emb_dim, hidden_dim)
        self.decoder = AttnDecoder(vocab_size, emb_dim, hidden_dim)
        self.device = device
        self.tokenizer = tokenizer
        self.sos_token_id = 101  # [CLS]
        self.eos_token_id = 102  # [SEP]
        self.unk_token_id = 100  # [UNK]

    def forward(self, src, tgt):
        """Teacher-forced pass, aligned so outputs[:, t] predicts tgt[:, t+1]."""
        encoder_outputs, (hidden, cell) = self.encoder(src)
        batch_size, tgt_len = tgt.size()
        outputs = torch.zeros(batch_size, tgt_len,
                              self.decoder.fc.out_features, device=src.device)
        for t in range(tgt_len):
            logits, hidden, cell = self.decoder(tgt[:, t], hidden, cell,
                                                encoder_outputs)
            outputs[:, t] = logits
        return outputs

    def generate(self, src, max_length=80):
        """Greedy decoding from [CLS] until [SEP] or ``max_length`` tokens."""
        src = src.to(device)
        encoder_outputs, (hidden, cell) = self.encoder(src)
        generated = torch.full((src.size(0), 1), self.sos_token_id,
                               dtype=torch.long, device=device)
        for _ in range(max_length - 1):
            logits, hidden, cell = self.decoder(generated[:, -1], hidden,
                                                cell, encoder_outputs)
            next_token = torch.argmax(logits, dim=1, keepdim=True)
            next_token = _suppress_early_punctuation(
                next_token, generated.size(1), self.tokenizer)
            generated = torch.cat([generated, next_token], dim=1)
            if (next_token == self.eos_token_id).all():
                break
        return generated


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class TransformerModel(nn.Module):
    """Standard encoder-decoder Transformer for summarization (seq-first API)."""

    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=3,
                 dim_feedforward=512, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead,
                                                   dim_feedforward, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead,
                                                   dim_feedforward, dropout=0.1)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)
        self.d_model = d_model
        self.sos_token_id = 101  # [CLS]
        self.eos_token_id = 102  # [SEP]

    def _generate_square_subsequent_mask(self, sz):
        """Causal mask: position i may attend to positions <= i."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return mask.float().masked_fill(mask == 0, float('-inf')) \
                          .masked_fill(mask == 1, float(0.0))

    def forward(self, src, tgt):
        src_mask = None
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(1)).to(src.device)
        src_key_padding_mask = (src == 0)
        tgt_key_padding_mask = (tgt == 0)
        src = self.pos_encoder(self.embedding(src) * np.sqrt(self.d_model))
        tgt = self.pos_encoder(self.embedding(tgt) * np.sqrt(self.d_model))
        memory = self.transformer_encoder(src.transpose(0, 1), src_mask,
                                          src_key_padding_mask)
        output = self.transformer_decoder(
            tgt.transpose(0, 1), memory, tgt_mask, None,
            tgt_key_padding_mask, src_key_padding_mask)
        return self.fc(output.transpose(0, 1))

    def generate(self, src, max_length=80):
        """Greedy autoregressive decoding from [CLS]."""
        src = src.to(device)
        src_mask = None
        src_key_padding_mask = (src == 0)
        batch_size = src.size(0)  # read before src is replaced by embeddings
        src = self.pos_encoder(self.embedding(src) * np.sqrt(self.d_model))
        memory = self.transformer_encoder(src.transpose(0, 1), src_mask,
                                          src_key_padding_mask)
        generated = torch.full((batch_size, 1), self.sos_token_id,
                               dtype=torch.long, device=device)
        for _ in range(max_length - 1):
            tgt_mask = self._generate_square_subsequent_mask(
                generated.size(1)).to(device)
            tgt_key_padding_mask = (generated == 0)
            tgt = self.pos_encoder(self.embedding(generated)
                                   * np.sqrt(self.d_model))
            output = self.transformer_decoder(
                tgt.transpose(0, 1), memory, tgt_mask, None,
                tgt_key_padding_mask, src_key_padding_mask)
            output = self.fc(output.transpose(0, 1)[:, -1, :])
            next_token = torch.argmax(output, dim=1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            if (next_token == self.eos_token_id).all():
                break
        return generated


def train_model(model, train_loader, optimizer, criterion, epochs=3):
    """Train ``model``; returns (model, elapsed_seconds).

    BUGFIX: the original ignored the caller's ``optimizer`` and always built
    its own Adam; the argument is now honored (None -> Adam(lr=1e-4)).
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                     patience=1, factor=0.5)
    start_time = time.time()
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for i, batch in enumerate(train_loader):
            src = batch['input_ids'].to(device)
            dec_in = batch['decoder_input_ids'].to(device)  # pad ids, embeddable
            labels = batch['labels'].to(device)             # -100 for loss only
            optimizer.zero_grad()
            outputs = model(src, dec_in[:, :-1])
            if torch.isnan(outputs).any():
                print("警告:模型输出包含NaN,跳过此批次")
                continue
            loss = criterion(outputs.reshape(-1, outputs.size(-1)),
                             labels[:, 1:].reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # clip grads
            optimizer.step()
            total_loss += loss.item()
            if (i + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} | Batch {i+1}/{len(train_loader)}"
                      f" | Loss: {loss.item():.4f}")
        avg_loss = total_loss / max(len(train_loader), 1)
        scheduler.step(avg_loss)
        print(f"Epoch {epoch+1} | 平均损失: {avg_loss:.4f}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    total_time = time.time() - start_time
    print(f"训练完成!总耗时: {total_time:.2f}s ({total_time/60:.2f}分钟)")
    return model, total_time


def evaluate_model(model, val_loader, tokenizer, num_examples=2):
    """Greedy-decode the validation set and report average ROUGE F1 scores."""
    model.eval()
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'],
                                      use_stemmer=True)
    rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    valid_count = 0
    with torch.no_grad():
        for batch in val_loader:
            src = batch['input_ids'].to(device)
            tgt = batch['labels'].to(device)
            generated = model.generate(src)
            for s, p, t in zip(src, generated, tgt):
                src_txt = tokenizer.decode(s, skip_special_tokens=True)
                pred_txt = tokenizer.decode(p, skip_special_tokens=True)
                true_txt = tokenizer.decode(t[t != -100],
                                            skip_special_tokens=True)
                if len(pred_txt.split()) > 3 and len(true_txt.split()) > 3:
                    valid_count += 1
                    if valid_count <= num_examples:
                        print(f"\n原文: {src_txt[:100]}...")
                        print(f"生成: {pred_txt}")
                        print(f"参考: {true_txt[:80]}...")
                        print("-" * 60)
                    if true_txt and pred_txt:
                        scores = scorer.score(true_txt, pred_txt)
                        for key in rouge_scores:
                            rouge_scores[key].append(scores[key].fmeasure)
    if valid_count > 0:
        avg_scores = {key: sum(vals) / len(vals)
                      for key, vals in rouge_scores.items()}
        print(f"\n评估结果 (基于{valid_count}个样本):")
        print(f"ROUGE-1: {avg_scores['rouge1']*100:.2f}%")
        print(f"ROUGE-2: {avg_scores['rouge2']*100:.2f}%")
        print(f"ROUGE-L: {avg_scores['rougeL']*100:.2f}%")
    else:
        print("警告:未生成有效摘要")
        avg_scores = {key: 0.0 for key in rouge_scores}
    return avg_scores


def visualize_model_performance(model_names, train_times, rouge_scores):
    """Bar-chart training time and ROUGE F1 per model; saves a PNG."""
    plt.figure(figsize=(15, 6))
    plt.subplot(1, 2, 1)
    bars = plt.bar(model_names, train_times)
    plt.title('模型训练时间对比')
    plt.ylabel('时间 (分钟)')
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.1f} min', ha='center', va='bottom')
    plt.subplot(1, 2, 2)
    x = np.arange(len(model_names))
    width = 0.25
    plt.bar(x - width, [s['rouge1'] for s in rouge_scores], width,
            label='ROUGE-1')
    plt.bar(x, [s['rouge2'] for s in rouge_scores], width, label='ROUGE-2')
    plt.bar(x + width, [s['rougeL'] for s in rouge_scores], width,
            label='ROUGE-L')
    plt.title('模型ROUGE分数对比')
    plt.ylabel('F1分数')
    plt.xticks(x, model_names)
    plt.legend()
    plt.tight_layout()
    plt.savefig('performance_comparison.png')
    plt.show()
    print("性能对比图已保存为 performance_comparison.png")


def interactive_summarization(models, tokenizer, model_names, max_length=80):
    """REPL: read a text, print each model's greedy summary; 'q' quits."""
    while True:
        print("\n" + "=" * 60)
        print("文本摘要交互式测试 (输入 'q' 退出)")
        print("=" * 60)
        input_text = input("请输入要摘要的文本:\n")
        if input_text.lower() == 'q':
            break
        if len(input_text) < 50:
            print("请输入更长的文本(至少50个字符)")
            continue
        inputs = tokenizer(input_text, max_length=384, truncation=True,
                           padding='max_length', return_tensors='pt').to(device)
        print("\n生成摘要中...")
        all_summaries = []
        for i, model in enumerate(models):
            model.eval()
            with torch.no_grad():
                generated = model.generate(inputs["input_ids"],
                                           max_length=max_length)
            summary = tokenizer.decode(generated[0], skip_special_tokens=True)
            all_summaries.append(summary)
            print(f"\n{model_names[i]} 摘要:")
            print("-" * 50)
            print(summary)
            print("-" * 50)
        print("\n所有模型摘要对比:")
        for i, (name, summary) in enumerate(zip(model_names, all_summaries)):
            print(f"{i+1}. {name}: {summary}")


# ----------------------------- main program ---------------------------------
print("加载数据集...")
dataset = load_dataset("cnn_dailymail", "3.0.0")
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
vocab_size = len(tokenizer.vocab)

print("准备训练数据...")
train_ds = SummaryDataset(dataset['train'], tokenizer, subset_size=0.01)
val_ds = SummaryDataset(dataset['validation'], tokenizer, subset_size=0.01)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=0)

# -100 in labels marks padding; it is never fed to an embedding layer.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

print("\n" + "=" * 60)
print("训练基础Seq2Seq模型")
print("=" * 60)
basic_model = BasicSeq2Seq(vocab_size, tokenizer=tokenizer).to(device)
trained_basic, basic_time = train_model(basic_model, train_loader, None,
                                        criterion, epochs=3)
basic_rouge = evaluate_model(trained_basic, val_loader, tokenizer)

print("\n" + "=" * 60)
print("训练注意力Seq2Seq模型")
print("=" * 60)
attn_model = AttnSeq2Seq(vocab_size, tokenizer=tokenizer).to(device)
trained_attn, attn_time = train_model(attn_model, train_loader, None,
                                      criterion, epochs=3)
attn_rouge = evaluate_model(trained_attn, val_loader, tokenizer)

print("\n" + "=" * 60)
print("训练Transformer模型")
print("=" * 60)
transformer_model = TransformerModel(vocab_size).to(device)
trained_transformer, transformer_time = train_model(
    transformer_model, train_loader, None, criterion, epochs=3)
transformer_rouge = evaluate_model(trained_transformer, val_loader, tokenizer)

print("\n" + "=" * 60)
print("模型性能对比")
print("=" * 60)
model_names = ['基础Seq2Seq', '注意力Seq2Seq', 'Transformer']
train_times = [basic_time / 60, attn_time / 60, transformer_time / 60]
rouge_scores = [basic_rouge, attn_rouge, transformer_rouge]
visualize_model_performance(model_names, train_times, rouge_scores)

print("\n" + "=" * 60)
print("交互式文本摘要测试")
print("=" * 60)
print("提示:输入一段文本,将同时生成三个模型的摘要结果")
interactive_summarization([trained_basic, trained_attn, trained_transformer],
                          tokenizer, model_names)
# (poster's note, translated: "after fixing the errors, send me the complete code")
06-09
/************************************************************************************************** \Copyright (c) Novatek Microelectronics Corp., Ltd. All Rights Reserved. \file client.c \brief client of full duplex chat tools \project Training \chip NA \Date 2025.08.15 **************************************************************************************************/ #include "common.h" #define USER_FILE "users.dat" #define MAX_USERS 100 typedef struct { char sUserName[MAX_NAME_LEN]; char sPassword[MAX_PASSWORD_LEN]; } User; ClientInfo g_stClients[MAX_CLIENTS] = { [0 ... MAX_CLIENTS-1] = { .sockfd = -1, .sUserName = "\0", .readThread = 0, .bActive = false } }; pthread_mutex_t g_clientsMutex = PTHREAD_MUTEX_INITIALIZER; pthread_mutex_t g_usersMutex = PTHREAD_MUTEX_INITIALIZER; pthread_mutex_t g_logMutex = PTHREAD_MUTEX_INITIALIZER; FILE *g_logFile = NULL; User g_users[MAX_USERS] = {0}; int g_iUserCount = 0; atomic_bool g_bServerShutdown = ATOMIC_VAR_INIT(false); int g_serverFd = 0; /*! \fn void load_users() \brief Save user information \return success : true; fail : false */ bool load_users() { TEST_PRINT(EN_TRACE, "in"); bool bRes = true; FILE *file = fopen(USER_FILE, "rb"); if (!file) { TEST_PRINT(EN_ERROR, "user file is NULL"); bRes = false; goto end; } pthread_mutex_lock(&g_usersMutex); g_iUserCount = fread(g_users, sizeof(User), MAX_USERS, file); pthread_mutex_unlock(&g_usersMutex); fclose(file); file = NULL; end: TEST_PRINT(EN_TRACE, "out"); return bRes; } /*! 
\fn void save_users() \brief Save user information \return success : true; fail : false */ bool save_users() { TEST_PRINT(EN_TRACE, "in"); bool bRes = true; FILE *file = fopen(USER_FILE, "wb"); if (!file) { TEST_PRINT(EN_ERROR, "user file is NULL"); bRes = false; goto end; } pthread_mutex_lock(&g_usersMutex); if (g_iUserCount != (int)fwrite(g_users, sizeof(User), g_iUserCount, file)) { TEST_PRINT(EN_ERROR, "File write quantity error"); bRes = false; } pthread_mutex_unlock(&g_usersMutex); fclose(file); file = NULL; end: TEST_PRINT(EN_TRACE, "out"); return bRes; } /*! \fn bool register_user(const char *sUserName, const char *sPassword) \brief User register \param sUserName: sUserName of user \param sPassword: sPassword of user \return success : true; fail : false */ bool register_user(const char *sUserName, const char *sPassword) { TEST_PRINT(EN_TRACE, "in"); if (g_iUserCount >= MAX_USERS) { TEST_PRINT(EN_ERROR, "The number of g_users exceeds the limit"); return false; } pthread_mutex_lock(&g_usersMutex); for (int index = 0; index < g_iUserCount; index++) { if (strncmp(g_users[index].sUserName, sUserName, MAX_NAME_LEN) == 0) { TEST_PRINT(EN_ERROR, "Same username exists"); pthread_mutex_unlock(&g_usersMutex); return false; } } strncpy(g_users[g_iUserCount].sUserName, sUserName, MAX_NAME_LEN); g_users[g_iUserCount].sUserName[MAX_NAME_LEN-1] = &#39;\0&#39;; strncpy(g_users[g_iUserCount].sPassword, sPassword, MAX_PASSWORD_LEN); g_users[g_iUserCount].sPassword[MAX_PASSWORD_LEN-1] = &#39;\0&#39;; g_iUserCount++; pthread_mutex_unlock(&g_usersMutex); save_users(); TEST_PRINT(EN_TRACE, "out"); return true; } /*! 
\fn int login_user(const char *sUserName, const char *sPassword) \brief User login \param sUserName: sUserName of user \param sPassword: sPassword of user \return success : true; fail : false */ bool login_user(const char *sUserName, const char *sPassword) { TEST_PRINT(EN_TRACE, "in"); int index = 0; bool bRes = false; pthread_mutex_lock(&g_usersMutex); for (; index < g_iUserCount; index++) { if (strncmp(g_users[index].sUserName, sUserName, strlen(sUserName)) == 0 && strncmp(g_users[index].sPassword, sPassword, strlen(sPassword)) == 0) { bRes = true; break; } } pthread_mutex_unlock(&g_usersMutex); TEST_PRINT(EN_TRACE, "out"); return bRes; } /*! \fn void cleanup_resources() \brief Clean up server resources */ void cleanup_resources() { TEST_PRINT(EN_TRACE, "in"); int index = 0; TEST_PRINT(EN_MSG, "Cleaning up server resources"); pthread_mutex_lock(&g_clientsMutex); for (index = 0; index < MAX_CLIENTS; index++) { if (g_stClients[index].bActive) { shutdown(g_stClients[index].sockfd, SHUT_RDWR); } } pthread_mutex_unlock(&g_clientsMutex); pthread_mutex_lock(&g_clientsMutex); for (index = 0; index < MAX_CLIENTS; index++) { if (g_stClients[index].bActive) { struct timespec ts; clock_gettime(CLOCK_REALTIME, &ts); ts.tv_sec += 5; if (pthread_timedjoin_np(g_stClients[index].readThread, NULL, &ts) != 0) { TEST_PRINT(EN_ERROR, "Tread %ld failed to exit normally, forced cancellation", g_stClients[index].readThread); pthread_cancel(g_stClients[index].readThread); } close(g_stClients[index].sockfd); g_stClients[index].sockfd = -1; g_stClients[index].bActive = false; } } pthread_mutex_unlock(&g_clientsMutex); pthread_mutex_destroy(&g_clientsMutex); pthread_mutex_destroy(&g_usersMutex); pthread_mutex_destroy(&g_logMutex); if (g_serverFd != -1) { close(g_serverFd); g_serverFd = -1; } if (g_logFile) { fclose(g_logFile); g_logFile = NULL; } atomic_store_explicit(&g_bServerShutdown, false, memory_order_release); TEST_PRINT(EN_TRACE, "out"); return; } /*! 
\fn void handle_signal(int sig) \brief signal processing function \param sig: signal value */ void handle_signal(int sig) { TEST_PRINT(EN_TRACE, "in"); TEST_PRINT(EN_MSG, "Received signal %d, shutting down", sig); cleanup_resources(); TEST_PRINT(EN_TRACE, "out"); exit(EXIT_SUCCESS); } /*! \fn void log_message(ChatMessage msg) \brief Record chat messages to a file \param msg: Message sContent */ void log_message(ChatMessage msg) { TEST_PRINT(EN_TRACE, "in"); time_t now = time(NULL); struct tm *stTime = localtime(&now); if (!g_logFile) { TEST_PRINT(EN_ERROR, "g_logFile is NULL"); goto end; } pthread_mutex_lock(&g_logMutex); fprintf(g_logFile, "[%04d-%02d-%02d %02d:%02d:%02d] [%s]:%s\n", stTime->tm_year + 1900, stTime->tm_mon + 1, stTime->tm_mday, stTime->tm_hour, stTime->tm_min, stTime->tm_sec, msg.sSender, msg.sContent); fflush(g_logFile); pthread_mutex_unlock(&g_logMutex); end: TEST_PRINT(EN_TRACE, "out"); return; } /*! \fn void broadcast_message(int sender_sock, const char* msg) \brief Broadcast messages to all g_stClients \param sender_sock: Sender socket \param msg: Message sContent */ void broadcast_message(int sender_sock, ChatMessage stChatMsg) { TEST_PRINT(EN_TRACE, "in"); int index = 0; TEST_PRINT(EN_DEBUG, "sender_sock is %d, msg is %s", sender_sock, stChatMsg.sContent); pthread_mutex_lock(&g_clientsMutex); for (; index < MAX_CLIENTS; index++) { if (g_stClients[index].bActive && g_stClients[index].sockfd != sender_sock) { if ((unsigned int)sizeof(ChatMessage) > send(g_stClients[index].sockfd, &stChatMsg, sizeof(ChatMessage), 0)) { TEST_PRINT(EN_ERROR, "Send failed to client %d", index); } } } pthread_mutex_unlock(&g_clientsMutex); log_message(stChatMsg); TEST_PRINT(EN_TRACE, "out"); return; } /*! 
\fn void* client_read_thread(void* arg) \brief Client Read Thread Function \param arg: Client side readThread \return NULL */ void* server_read_thread(void* arg) { TEST_PRINT(EN_TRACE, "in"); int idx = 0; ChatMessage stChatMsg = {0}; int index = 0; int iRecvLen = 0; if (!arg) { TEST_PRINT(EN_ERROR, "server_read_thread input is NULL"); goto end; } idx = *(int*)arg; free(arg); arg = NULL; while (g_stClients[idx].bActive && !atomic_load_explicit(&g_bServerShutdown, memory_order_acquire)) { pthread_testcancel(); iRecvLen = recv(g_stClients[idx].sockfd, &stChatMsg, sizeof(ChatMessage), 0); if (0 >= iRecvLen) { if (0 == iRecvLen) { TEST_PRINT(EN_MSG, "Client %d disconnected", idx); } else if (errno != EAGAIN && errno != EWOULDBLOCK) { TEST_PRINT(EN_ERROR, "Recv error: %s", strerror(errno)); } else { continue; } pthread_mutex_lock(&g_clientsMutex); g_stClients[idx].bActive = false; pthread_mutex_unlock(&g_clientsMutex); goto end; } TEST_PRINT(EN_DEBUG, "Recv msg is %s", stChatMsg.sContent); switch (stChatMsg.type) { case MSG_LOGIN: { if (login_user(stChatMsg.sSender, stChatMsg.sContent)) { strncpy(g_stClients[idx].sUserName, stChatMsg.sSender, MAX_NAME_LEN); strncpy(stChatMsg.sReceiver, stChatMsg.sSender, MAX_NAME_LEN); strncpy(stChatMsg.sSender, "Server", MAX_NAME_LEN); strncpy(stChatMsg.sContent, "Login success", MAX_MSG_LEN); } else { strncpy(stChatMsg.sReceiver, stChatMsg.sSender, MAX_NAME_LEN); strncpy(stChatMsg.sSender, "Server", MAX_NAME_LEN); strncpy(stChatMsg.sContent, "Incorrect Password", MAX_MSG_LEN); } goto sendlable; } case MSG_TEXT: { int targetIdx = -1; pthread_mutex_lock(&g_clientsMutex); for (index = 0; index < MAX_CLIENTS; index++) { if (g_stClients[index].sockfd != -1 && strncmp(g_stClients[index].sUserName, stChatMsg.sReceiver, MAX_NAME_LEN) == 0) { targetIdx = index; log_message(stChatMsg); if (0 > send(g_stClients[targetIdx].sockfd, &stChatMsg, sizeof(ChatMessage), 0)) { TEST_PRINT(EN_ERROR, "Send failed to client %d", index); } break; } } 
pthread_mutex_unlock(&g_clientsMutex); if (-1 == targetIdx) { stChatMsg.type = MSG_TEXT; strncpy(stChatMsg.sReceiver, stChatMsg.sSender, MAX_NAME_LEN); strncpy(stChatMsg.sSender, "Server", MAX_NAME_LEN); strncpy(stChatMsg.sContent, "Cannot find this user\n", MAX_MSG_LEN); goto sendlable; } break; } case MSG_GROUP: { TEST_PRINT(EN_DEBUG, "case msg is %s", stChatMsg.sContent); broadcast_message(g_stClients[idx].sockfd, stChatMsg); break; } case MSG_LIST: { TEST_PRINT(EN_ERROR, "Send list"); stChatMsg.type = MSG_TEXT; strncpy(stChatMsg.sReceiver, stChatMsg.sSender, MAX_NAME_LEN); strncpy(stChatMsg.sSender, "Server", MAX_NAME_LEN); stChatMsg.sContent[0] = &#39;\0&#39;; pthread_mutex_lock(&g_clientsMutex); for (index = 0; index < MAX_CLIENTS; index++) { if (g_stClients[index].sockfd != -1 && g_stClients[index].bActive) { strncat(stChatMsg.sContent, g_stClients[index].sUserName, MAX_NAME_LEN); strncat(stChatMsg.sContent, "\n", 1); } } pthread_mutex_unlock(&g_clientsMutex); goto sendlable; } case MSG_REGIST: { if (!register_user(stChatMsg.sSender, stChatMsg.sContent)) { strncpy(stChatMsg.sContent, "Register failed\n", MAX_MSG_LEN); } else { strncpy(stChatMsg.sContent, "Register success\n", MAX_MSG_LEN); } stChatMsg.type = MSG_TEXT; strncpy(stChatMsg.sReceiver, stChatMsg.sSender, MAX_NAME_LEN); strncpy(stChatMsg.sSender, "Server", MAX_NAME_LEN); goto sendlable; } default: break; } continue; sendlable: if (0 > send(g_stClients[idx].sockfd, &stChatMsg, sizeof(ChatMessage), 0)) { TEST_PRINT(EN_ERROR, "Send failed to client %d", idx); } } end: pthread_mutex_lock(&g_clientsMutex); shutdown(g_stClients[idx].sockfd, SHUT_RDWR); close(g_stClients[idx].sockfd); g_stClients[idx].sockfd = -1; g_stClients[idx].bActive = false; pthread_mutex_unlock(&g_clientsMutex); TEST_PRINT(EN_TRACE, "out"); return NULL; } /*! 
\fn int setup_server(int port)
 \brief Set TCP server socket
 \param port: listen port
 \return Successfully returned socket fd, failed returned -1
*/
int setup_server(int port)
{
    TEST_PRINT(EN_TRACE, "in");
    struct sockaddr_in address = {0};
    int iOpt = 1;

    g_serverFd = socket(AF_INET, SOCK_STREAM, 0);
    if (0 > g_serverFd) {
        TEST_PRINT(EN_ERROR, "Socket creation failed");
        g_serverFd = -1;
        goto end;
    }
    // Set SO-REUSEADDR so the port can be rebound quickly after a restart
    if (setsockopt(g_serverFd, SOL_SOCKET, SO_REUSEADDR, &iOpt, sizeof(iOpt))) {
        TEST_PRINT(EN_ERROR, "setsockopt(SO_REUSEADDR) failed: %s", strerror(errno));
        close(g_serverFd);
        g_serverFd = -1;
        goto end;
    }
    address.sin_family = AF_INET;
    address.sin_addr.s_addr = INADDR_ANY;   // listen on every local interface
    address.sin_port = htons(port);
    if (bind(g_serverFd, (struct sockaddr*)&address, sizeof(address)) < 0) {
        TEST_PRINT(EN_ERROR, "Bind failed: %s", strerror(errno));
        close(g_serverFd);
        g_serverFd = -1;
        goto end;
    }
    if (0 > listen(g_serverFd, 5)) {        // backlog of 5 pending connections
        TEST_PRINT(EN_ERROR, "Listen failed");
        close(g_serverFd);
        g_serverFd = -1;
        goto end;
    }
    // Mark the server as "running" only once the listening socket is fully set up
    atomic_store_explicit(&g_bServerShutdown, false, memory_order_release);
end:
    TEST_PRINT(EN_TRACE, "out");
    return g_serverFd;
}

/*!
 \brief Server entry point: opens the log file, sets up the listening socket,
        then accepts clients in a loop and spawns one server_read_thread per
        client slot (slots live in the g_stClients table guarded by g_clientsMutex).
 \return EXIT_SUCCESS on clean shutdown, EXIT_FAILURE on setup errors
*/
int main()
{
    TEST_PRINT(EN_TRACE, "in");
    int *arg = NULL;                    // heap-allocated slot index handed to each thread
    int idx = -1;
    struct sockaddr_in clientAddr = {0};
    int clientFd = -1;
    struct timeval stTimeVal = {0};
    int index = 0;

    signal(SIGINT, handle_signal);
    signal(SIGTSTP, handle_signal);
    load_users();                       // defined elsewhere; presumably loads registered users — TODO confirm
    g_logFile = fopen("chat_server.log", "a");
    if (!g_logFile) {
        TEST_PRINT(EN_ERROR, "Failed to open log file");
        return EXIT_FAILURE;
    }
    g_serverFd = setup_server(SERVER_PORT);
    if (0 > g_serverFd) {
        TEST_PRINT(EN_ERROR, "Server setup failed");
        return EXIT_FAILURE;
    }
    TEST_PRINT(EN_MSG, "Server started on port %d", SERVER_PORT);
    // The server needs to continuously monitor client connections
    while (!atomic_load_explicit(&g_bServerShutdown, memory_order_acquire)) {
        idx = -1;
        socklen_t addr_len = sizeof(clientAddr);
        clientFd = accept(g_serverFd, (struct sockaddr*)&clientAddr, &addr_len);
        if (0 > clientFd) {
            TEST_PRINT(EN_ERROR, "Accept failed: %s", strerror(errno));
            continue;
        }
        // Give each client socket recv/send timeouts so threads cannot block forever
        stTimeVal.tv_sec = TIMEOUT_SEC;
        stTimeVal.tv_usec = TIMEOUT_USEC;
        setsockopt(clientFd, SOL_SOCKET, SO_RCVTIMEO, &stTimeVal, sizeof(stTimeVal));
        setsockopt(clientFd, SOL_SOCKET, SO_SNDTIMEO, &stTimeVal, sizeof(stTimeVal));
        // Claim the first free slot in the client table under the mutex
        pthread_mutex_lock(&g_clientsMutex);
        for (index = 0; index < MAX_CLIENTS; ++index) {
            if (!g_stClients[index].bActive) {
                idx = index;
                g_stClients[index].sockfd = clientFd;
                g_stClients[index].bActive = true;
                break;
            }
        }
        pthread_mutex_unlock(&g_clientsMutex);
        if (-1 == idx) {
            TEST_PRINT(EN_ERROR, "Max g_stClients reached");
            close(clientFd);
            continue;
        }
        TEST_PRINT(EN_MSG, "New client connected: (ID: %d)", idx);
        arg = malloc(sizeof(int));
        // NOTE(review): malloc result is not checked before *arg = idx — dereferences NULL on OOM
        *arg = idx;
        if (0 != pthread_create(&g_stClients[idx].readThread, NULL, server_read_thread, arg)) {
            TEST_PRINT(EN_ERROR, "Failed to create read thread");
            // Roll back the slot claim so it can be reused
            pthread_mutex_lock(&g_clientsMutex);
            g_stClients[idx].bActive = false;
            g_stClients[idx].sockfd = -1;
            pthread_mutex_unlock(&g_clientsMutex);
            close(clientFd);
            free(arg);
            arg = NULL;
            continue;
        }
    }
    cleanup_resources();                // defined elsewhere — presumably closes sockets/log; TODO confirm
    TEST_PRINT(EN_TRACE, "out");
    return EXIT_SUCCESS;
}

/**************************************************************************************************
\Copyright (c) Novatek Microelectronics Corp., Ltd. All Rights Reserved.
\file commom.h
\brief Universal library of full duplex chat tools
\project Training
\chip NA
\Date 2025.08.15
 NOTE(review): filename in this header is spelled "commom.h" — presumably "common.h"; verify.
**************************************************************************************************/
#ifndef COMMON_H
#define COMMON_H
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <pthread.h>
#include <signal.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <time.h>
#include <errno.h>
#include <stdatomic.h>
#include <stdbool.h>

// Low-level print helper: prefixes every message with function name and line number.
#define _test_print(fmt, ...) \
    printf("[%s():%d] " fmt "\n", __func__, __LINE__, ##__VA_ARGS__)
// Level-gated logging macro.
// NOTE(review): the unbraced `if` is a dangling-else / multi-statement hazard when the
// macro is used inside another `if`; the usual `do { ... } while (0)` wrapper would avoid it.
#define TEST_PRINT(level, fmt, ...) \
    if ((level) <= enDebugLevel) _test_print(fmt, ##__VA_ARGS__)

#define MAX_CLIENTS 10
#define MAX_MSG_LEN 1024
#define MAX_NAME_LEN 30
#define MAX_PASSWORD_LEN 10
#define SERVER_PORT 8080
#define TIMEOUT_SEC 5
#define TIMEOUT_USEC 0

/*! \brief Printing level (lower value = higher severity) */
typedef enum {
    EN_ERROR = 0,
    EN_MSG = 1,
    EN_DEBUG = 2,
    EN_TRACE = 3
} enLOG_LEVELS;

// Current log threshold; messages with level <= this value are printed.
// NOTE(review): `static` data in a shared header gives every translation unit its own
// independent copy of enDebugLevel — changing it in one .c file does not affect the others.
static enLOG_LEVELS enDebugLevel = EN_MSG;

/*! \brief Client Information Structure (one slot per connected client) */
typedef struct {
    int sockfd;                       // client socket, -1 when the slot is free
    char sUserName[MAX_NAME_LEN];     // logged-in user name
    pthread_t readThread;             // per-client reader thread handle
    bool bActive;                     // slot-in-use flag
} ClientInfo;

/*! \brief Enumeration of message types exchanged between client and server */
typedef enum {
    MSG_LOGIN = 0,    // login request (sContent carries the password)
    MSG_TEXT = 1,     // private text message
    MSG_LIST = 2,     // request the list of online users
    MSG_GROUP = 3,    // broadcast message to everyone
    MSG_REGIST = 4    // registration request
} msg_type_t;

/*! \brief Message structure sent as a fixed-size binary blob over the socket */
typedef struct {
    msg_type_t type;
    char sSender[MAX_NAME_LEN];
    char sReceiver[MAX_NAME_LEN];
    char sContent[MAX_MSG_LEN];
} ChatMessage;
#endif

/**************************************************************************************************
\Copyright (c) Novatek Microelectronics Corp., Ltd. All Rights Reserved.
\file server.c
\brief server of full duplex chat tools
\project Training
\chip NA
\Date 2025.08.15
 NOTE(review): despite the "\file server.c" tag this unit defines the CLIENT (it connects
 with connect() and reads user menus); the header comment appears copy-pasted — verify.
**************************************************************************************************/
#include "common.h"

int g_sockfd = -1;                                        // socket to the server
atomic_bool g_bRunning = ATOMIC_VAR_INIT(true);           // cleared to stop both worker threads
atomic_bool g_bClientShutdown = ATOMIC_VAR_INIT(false);   // set once cleanup has started
pthread_t g_readThread, g_writeThread;
char g_sClientName[MAX_NAME_LEN] = {'\0'};                // user name after successful login

/*!
 \fn void cleanup_client()
 \brief Clean up client resources: stop both threads, shut down and close the socket.
        Safe to call from the signal handler path (it is what handle_signal invokes).
*/
void cleanup_client()
{
    TEST_PRINT(EN_TRACE, "in");
    TEST_PRINT(EN_MSG, "Cleaning up client resources");
    atomic_store_explicit(&g_bRunning, false, memory_order_release);
    atomic_store_explicit(&g_bClientShutdown, true, memory_order_release);
    // shutdown() first so any thread blocked in recv()/send() returns immediately
    shutdown(g_sockfd, SHUT_RDWR);
    pthread_cancel(g_readThread);
    pthread_cancel(g_writeThread);
    if (0 <= g_sockfd) {
        close(g_sockfd);
        g_sockfd = -1;
    }
    TEST_PRINT(EN_TRACE, "out");
    return;
}

/*!
 \fn void handle_signal(int sig)
 \brief signal processing function (installed for SIGINT and SIGTSTP)
 \param sig: Signal value
*/
void handle_signal(int sig)
{
    TEST_PRINT(EN_TRACE, "in");
    TEST_PRINT(EN_MSG, "Received signal %d, exiting...", sig);
    cleanup_client();
    TEST_PRINT(EN_TRACE, "out");
}

/*!
 \fn void* read_thread(void* arg)
 \brief Receive server messages in a loop and print them until the connection
        drops or a shutdown flag is raised. Timeouts (EAGAIN/EWOULDBLOCK) are
        ignored so the loop can re-check the shutdown flags.
 \param arg: unused
 \return NULL
*/
void* read_thread(void* arg)
{
    TEST_PRINT(EN_TRACE, "in");
    ChatMessage chatMsg = {0};
    int recvLen = 0;
    (void)arg;
    while (atomic_load_explicit(&g_bRunning, memory_order_acquire) &&
           !atomic_load_explicit(&g_bClientShutdown, memory_order_acquire)) {
        pthread_testcancel();   // cancellation point for cleanup_client's pthread_cancel
        recvLen = recv(g_sockfd, &chatMsg, sizeof(ChatMessage), 0);
        if (0 < recvLen) {
            TEST_PRINT(EN_MSG, "\n[%s]:\n%s", chatMsg.sSender, chatMsg.sContent);
            fflush(stdout);
        } else if (0 == recvLen) {
            // Orderly close by the peer
            TEST_PRINT(EN_MSG, "Server disconnected");
            atomic_store_explicit(&g_bRunning, false, memory_order_release);
            goto end;
        } else {
            // recv timeout is expected (SO_RCVTIMEO); anything else is fatal
            if (errno != EAGAIN && errno != EWOULDBLOCK) {
                TEST_PRINT(EN_ERROR, "Recv error: %s", strerror(errno));
                atomic_store_explicit(&g_bRunning, false, memory_order_release);
                goto end;
            }
        }
    }
end:
    TEST_PRINT(EN_TRACE, "out");
    return NULL;
}

/*!
 \fn void* write_thread(void* arg)
 \brief Read user input: shows the chat menu, builds a ChatMessage according to
        the chosen action and sends it to the server. Choice 4 stops the client.
 \param arg: unused
 \return NULL
*/
void* write_thread(void* arg)
{
    TEST_PRINT(EN_TRACE, "in");
    int iChoice = 0;
    ChatMessage chatMsg = {0};
    (void)arg;
    while (atomic_load_explicit(&g_bRunning, memory_order_acquire) &&
           !atomic_load_explicit(&g_bClientShutdown, memory_order_acquire)) {
        pthread_testcancel();
        TEST_PRINT(EN_MSG, "\n==== Chat menu ====\n");
        TEST_PRINT(EN_MSG, "\n1. Private chat\n2. Group chat\n3. Online users\n4. Exit");
        fflush(stdout);
        scanf("%d", &iChoice);
        getchar();   // consume the newline left behind by scanf
        switch (iChoice) {
        case 1:
            chatMsg.type = MSG_TEXT;
            strncpy(chatMsg.sSender, g_sClientName, MAX_NAME_LEN);
            chatMsg.sSender[MAX_NAME_LEN - 1] = '\0';
            TEST_PRINT(EN_MSG, "Please enter the recipient's username: ");
            fgets(chatMsg.sReceiver, MAX_NAME_LEN, stdin);
            chatMsg.sReceiver[strcspn(chatMsg.sReceiver, "\n")] = '\0';  // strip trailing newline
            TEST_PRINT(EN_MSG, "Input message: ");
            fgets(chatMsg.sContent, MAX_MSG_LEN, stdin);
            chatMsg.sContent[strcspn(chatMsg.sContent, "\n")] = '\0';
            chatMsg.type = MSG_TEXT;   // NOTE(review): redundant — already set at the top of this case
            break;
        case 2:
            chatMsg.type = MSG_GROUP;
            strncpy(chatMsg.sSender, g_sClientName, MAX_NAME_LEN);
            chatMsg.sSender[MAX_NAME_LEN - 1] = '\0';
            TEST_PRINT(EN_MSG, "Enter group messaging: ");
            fgets(chatMsg.sContent, MAX_MSG_LEN, stdin);
            chatMsg.sContent[strcspn(chatMsg.sContent, "\n")] = '\0';
            strncpy(chatMsg.sReceiver, "ALL", MAX_NAME_LEN);
            chatMsg.sReceiver[MAX_NAME_LEN - 1] = '\0';
            break;
        case 3:
            chatMsg.type = MSG_LIST;
            break;
        case 4:
            atomic_store_explicit(&g_bRunning, false, memory_order_release);
            goto end;
        default:
            TEST_PRINT(EN_ERROR, "invalid Choice");
            // NOTE(review): falls through to send() below, re-sending the previous message
        }
        TEST_PRINT(EN_DEBUG, "send msg");
        if (0 > send(g_sockfd, &chatMsg, sizeof(ChatMessage), 0)) {
            TEST_PRINT(EN_ERROR, "Send failed: %s", strerror(errno));
            goto end;
        }
    }
end:
    TEST_PRINT(EN_TRACE, "out");
    return NULL;
}

/*!
 \fn bool client_login()
 \brief client login function: loops over a login/register/exit menu, sends the
        request to the server and waits for the reply; succeeds when the server
        answers "Login success" (user name is then taken from sReceiver).
 \return login success : true; failed : false
*/
bool client_login()
{
    int iChoice = 0;
    ChatMessage chatMsg = {0};
    bool bRes = false;
    while (!atomic_load_explicit(&g_bClientShutdown, memory_order_acquire)) {
start:
        TEST_PRINT(EN_MSG, "\n1. Login in\n2. Register\n3. Exit");
        fflush(stdout);
        scanf("%d", &iChoice);
        getchar();   // consume newline left by scanf
        switch (iChoice) {
        case 1:
            chatMsg.type = MSG_LOGIN;
            TEST_PRINT(EN_MSG, "Enter user name");
            fflush(stdout);
            fgets(chatMsg.sSender, MAX_NAME_LEN, stdin);
            chatMsg.sSender[strcspn(chatMsg.sSender, "\n")] = '\0';
            TEST_PRINT(EN_DEBUG, "User name %s", chatMsg.sSender);
            TEST_PRINT(EN_MSG, "Enter user password");
            fflush(stdout);
            fgets(chatMsg.sContent, MAX_PASSWORD_LEN, stdin);
            // NOTE(review): unlike the register path, the password's trailing newline is NOT stripped here
            send(g_sockfd, &chatMsg, sizeof(ChatMessage), 0);
            break;
        case 2:
            chatMsg.type = MSG_REGIST;
            TEST_PRINT(EN_MSG, "Enter user name");
            fflush(stdout);
            fgets(chatMsg.sSender, MAX_NAME_LEN, stdin);
            chatMsg.sSender[strcspn(chatMsg.sSender, "\n")] = '\0';
            TEST_PRINT(EN_MSG, "Enter user password");
            fflush(stdout);
            fgets(chatMsg.sContent, MAX_PASSWORD_LEN, stdin);
            // NOTE(review): likely bug — writes into sSender using sContent's newline offset;
            // presumably meant chatMsg.sContent[strcspn(chatMsg.sContent, "\n")] = '\0';
            chatMsg.sSender[strcspn(chatMsg.sContent, "\n")] = '\0';
            send(g_sockfd, &chatMsg, sizeof(ChatMessage), 0);
            break;
        case 3:
            bRes = false;
            goto end;
        default:
            TEST_PRINT(EN_ERROR, "invalid Choice");
            goto start;   // re-prompt without touching the socket
        }
        int recvLen = recv(g_sockfd, &chatMsg, sizeof(ChatMessage), 0);
        if (0 < recvLen) {
            TEST_PRINT(EN_MSG, "\n[%s]:\n%s", chatMsg.sSender, chatMsg.sContent);
            fflush(stdout);
        } else if (0 == recvLen) {
            TEST_PRINT(EN_MSG, "Server disconnected");
            bRes = false;
            goto end;
        } else {
            if (errno != EAGAIN && errno != EWOULDBLOCK) {
                TEST_PRINT(EN_ERROR, "Recv error: %s", strerror(errno));
                bRes = false;
                goto end;
            }
        }
        if (0 == strncmp(chatMsg.sContent, "Login success", MAX_MSG_LEN)) {
            // Server echoes the accepted user name back in sReceiver
            strncpy(g_sClientName, chatMsg.sReceiver, MAX_NAME_LEN);
            g_sClientName[MAX_NAME_LEN - 1] = '\0';
            bRes = true;
            break;
        }
    }
end:
    TEST_PRINT(EN_TRACE, "out");
    return bRes;
}

/*!
 \brief Client entry point: connects to <server_ip>:SERVER_PORT, performs login,
        then runs one reader and one writer thread until either finishes.
 \param argv[1] server IPv4 address in dotted-decimal form
 \return EXIT_SUCCESS / EXIT_FAILURE
        NOTE(review): every `goto end` failure path still returns EXIT_SUCCESS.
*/
int main(int argc, char* argv[])
{
    TEST_PRINT(EN_TRACE, "in");
    struct timeval tv = {0};
    struct sockaddr_in servAddr = {0};
    signal(SIGINT, handle_signal);
    signal(SIGTSTP, handle_signal);
    if (2 != argc) {
        TEST_PRINT(EN_ERROR, "Usage: %s <server_ip>\n", argv[0]);
        return EXIT_FAILURE;
    }
    g_sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (0 > g_sockfd) {
        TEST_PRINT(EN_ERROR, "Socket creation failed");
        return EXIT_FAILURE;
    }
    // Bounded recv/send so worker loops can poll the shutdown flags
    tv.tv_sec = TIMEOUT_SEC;
    tv.tv_usec = TIMEOUT_USEC;
    setsockopt(g_sockfd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
    setsockopt(g_sockfd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv));
    servAddr.sin_family = AF_INET;
    servAddr.sin_port = htons(SERVER_PORT);
    if (0 >= inet_pton(AF_INET, argv[1], &servAddr.sin_addr)) {
        TEST_PRINT(EN_ERROR, "Invalid address");
        goto end;
    }
    if (0 > connect(g_sockfd, (struct sockaddr*)&servAddr, sizeof(servAddr))) {
        TEST_PRINT(EN_ERROR, "Connection failed: %s", strerror(errno));
        goto end;
    }
    TEST_PRINT(EN_MSG, "Connected to %s:%d", argv[1], SERVER_PORT);
    if (!client_login()) {
        TEST_PRINT(EN_ERROR, "Client login failed");
        goto end;
    }
    if (0 != pthread_create(&g_readThread, NULL, read_thread, NULL)) {
        TEST_PRINT(EN_ERROR, "Failed to create read thread");
        goto end;
    }
    if (0 != pthread_create(&g_writeThread, NULL, write_thread, NULL)) {
        TEST_PRINT(EN_ERROR, "Failed to create write thread");
        goto end;
    }
    if (0 != pthread_join(g_readThread, NULL)) {
        TEST_PRINT(EN_ERROR, "Failed to join read thread");
        goto end;
    }
    if (0 != pthread_join(g_writeThread, NULL)) {
        TEST_PRINT(EN_ERROR, "Failed to join write thread");
        goto end;
    }
end:
    cleanup_client();
    TEST_PRINT(EN_TRACE, "out");
    return EXIT_SUCCESS;
}
整理三个文件之间的关系
最新发布
08-22
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值