Skip to content

Commit 035d138

Browse files
committed
更新pyproject.toml版本至0.2.2,优化vllm_infer函数以支持前缀缓存,调整qa_generator.py以处理结果并保存为JSON格式。
1 parent 5fea291 commit 035d138

File tree

5 files changed

+22
-6
lines changed

5 files changed

+22
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dependencies = [
2020

2121
[tool.weclone]
2222
# 配置文件的版本号,当配置文件结构或重要默认值发生变化时,应增加此版本号
23-
config_version = "0.2.1"
23+
config_version = "0.2.2"
2424

2525
# 配置文件更新日志
2626
config_changelog = """

weclone/core/inference/vllm_infer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def vllm_infer(
139139
"pipeline_parallel_size": pipeline_parallel_size,
140140
"disable_log_stats": True,
141141
"enable_lora": model_args.adapter_name_or_path is not None,
142+
"enable_prefix_caching": True, # 是否启用前缀缓存
142143
}
143144
if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
144145
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
@@ -155,5 +156,3 @@ def vllm_infer(
155156
print("*" * 70)
156157
print(f"{len(prompts)} generated results have been saved at {save_name}.")
157158
print("*" * 70)
158-
159-

weclone/data/clean/clean_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#请检查sorce字段分布是否正常 调整accept_score 调整后按Y保存继续清洗

weclone/data/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ class CutMessage:
2222
CreateTime: Timestamp
2323

2424

25+
# TODO 未使用QaPair
26+
@dataclass
27+
class QaPair:
28+
id: int
29+
system: str
30+
instruction: str
31+
output: str
32+
history: list[ChatMessage]
33+
time: Timestamp
34+
score: int
35+
36+
2537
skip_type_list = [
2638
"添加好友",
2739
"推荐公众号",

weclone/data/qa_generator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def main(self):
7575
template=self.c["template"],
7676
interval=self.c["cutoff_len"],
7777
)
78-
logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到./dataset/res_csv/sft/sft-my.json")
78+
logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到 ./dataset/res_csv/sft/sft-my.json")
7979

8080
def get_csv_files(self):
8181
"""遍历文件夹获取所有CSV文件路径"""
@@ -348,13 +348,17 @@ def process_text(self, chat_message: ChatMessage):
348348
pass
349349

350350
def save_result(self, qa_res: List[Dict]):
351-
# 保存结果
351+
processed_qa_res = []
352+
for idx, item in enumerate(qa_res):
353+
if isinstance(item, dict):
354+
item = {"id": idx, **item}
355+
processed_qa_res.append(item)
352356
with open(
353357
"./dataset/res_csv/sft/sft-my.json",
354358
"w",
355359
encoding="utf-8",
356360
) as f:
357-
json.dump(qa_res, f, ensure_ascii=False)
361+
json.dump(processed_qa_res, f, ensure_ascii=False)
358362

359363

360364
if __name__ == "__main__":

0 commit comments

Comments
 (0)