1. Basic Loading
The example below loads an AWQ-quantized Mixtral instruct model with vLLM's AsyncLLMEngine and streams the generated tokens to the console.
import asyncio
from transformers import AutoTokenizer, PreTrainedTokenizer
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
model_path = "casperhansen/mixtral-instruct-awq"
# prompting
prompt = "You're standing on the surface of the Earth. " \
         "You walk one mile south, one mile west and one mile north. " \
         "You end up exactly where you started. Where are you?"
prompt_template = "[INST] {prompt} [/INST]"
# sampling params
sampling_params = SamplingParams(
    repetition_penalty=1.1,
    temperature=0.8,
    max_tokens=512
)
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# async engine args for streaming
engine_args = AsyncEngineArgs(
    model=model_path,
    quantization="awq",
    dtype="float16",
    max_model_len=512,
    enforce_eager=True,
    disable_log_requests=True,
    disable_log_stats=True,
)
async def generate(model: AsyncLLMEngine, tokenizer: PreTrainedTokenizer):
    # apply the chat template and tokenize the prompt
    tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids
    outputs = model.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id="1",  # must be unique per in-flight request
        prompt_token_ids=tokens,
    )
    print("\n** Starting generation!\n")
    last_index = 0
    # each yielded output holds the cumulative text; print only the new part
    async for output in outputs:
        print(output.outputs[0].text[last_index:], end="", flush=True)
        last_index = len(output.outputs[0].text)
    print("\n\n** Finished generation!\n")
if __name__ == '__main__':
    model = AsyncLLMEngine.from_engine_args(engine_args)
    asyncio.run(generate(model, tokenizer))
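The script above drives a single request. The main reason to use AsyncLLMEngine rather than the synchronous LLM class is that many requests can be in flight at once while the engine batches them continuously. As a rough sketch (the helpers generate_one and generate_many are illustrative names, reusing the tokenizer, sampling_params and prompt_template defined above), concurrent generation could look like this:
import uuid
async def generate_one(engine: AsyncLLMEngine, prompt_text: str) -> str:
    tokens = tokenizer(prompt_template.format(prompt=prompt_text)).input_ids
    final_text = ""
    async for output in engine.generate(
        prompt=prompt_text,
        sampling_params=sampling_params,
        request_id=str(uuid.uuid4()),  # each in-flight request needs its own id
        prompt_token_ids=tokens,
    ):
        final_text = output.outputs[0].text  # keep the latest cumulative text
    return final_text

async def generate_many(engine: AsyncLLMEngine, prompts: list) -> list:
    # the coroutines run concurrently; the engine batches their forward passes
    return await asyncio.gather(*(generate_one(engine, p) for p in prompts))
Calling asyncio.run(generate_many(model, [prompt_a, prompt_b])) would then return both completions in order.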
2. Exposing an API
2.1 Server
The same engine can be wrapped in a FastAPI app so that generation is available over HTTP.
import asyncio
import uuid
from fastapi import Body, FastAPI
from transformers import AutoTokenizer, PreTrainedTokenizer
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
app = FastAPI()
model_path = "casperhansen/mixtral-instruct-awq"
# prompting
prompt_template = "[INST] {prompt} [/INST]"
# async engine args for streaming
engine_args = AsyncEngineArgs(
    model=model_path,
    quantization="awq",
    dtype="float16",
    max_model_len=512,
    enforce_eager=True,
    disable_log_requests=True,
    disable_log_stats=True,
)
# initialize the async engine and the tokenizer
model = AsyncLLMEngine.from_engine_args(engine_args)
tokenizer = AutoTokenizer.from_pretrained(model_path)
@app.post("/v1/chat/completions")
async def generate_text(prompt: str = Body(...),
                        repetition_penalty: float = Body(1.1),
                        temperature: float = Body(0.8),
                        max_tokens: int = Body(512),
                        top_p: float = Body(0.9)):
    # parameters are read from the JSON request body (see the client in 2.2)
    try:
        sampling_params = SamplingParams(
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p
        )
        tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids
        outputs = model.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=str(uuid.uuid4()),  # must be unique per request
            prompt_token_ids=tokens,
        )
        # each yielded output holds the full text generated so far,
        # so keep the latest snapshot instead of concatenating the chunks
        result = ""
        async for output in outputs:
            result = output.outputs[0].text
        return {
            "choices": [
                {
                    "message": {
                        "content": result
                    }
                }
            ]
        }
    except Exception as e:
        return {"error": str(e)}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
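The endpoint above only responds once the whole completion is finished. For token-by-token streaming over HTTP, one option is an extra route built on FastAPI's StreamingResponse. The route name /v1/chat/stream and the plain-text chunk format below are assumptions for illustration, not part of vLLM or the original service; the sketch reuses model, tokenizer, prompt_template, Body and uuid from the server above.
from fastapi.responses import StreamingResponse
@app.post("/v1/chat/stream")
async def generate_stream(prompt: str = Body(...),
                          max_tokens: int = Body(512)):
    sampling_params = SamplingParams(max_tokens=max_tokens)
    tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids
    outputs = model.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id=str(uuid.uuid4()),
        prompt_token_ids=tokens,
    )
    async def token_stream():
        sent = 0
        async for output in outputs:
            text = output.outputs[0].text
            yield text[sent:]  # send only the delta since the last chunk
            sent = len(text)
    return StreamingResponse(token_stream(), media_type="text/plain")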
2.2 Client
A minimal client that posts the prompt and sampling parameters to the endpoint above:
import requests
# API endpoint
url = "http://localhost:8000/v1/chat/completions"
# prompt and sampling parameters to send
data = {
    "prompt": "You're standing on the surface of the Earth. You walk one mile south, one mile west and one mile north. You end up exactly where you started. Where are you?",
    "repetition_penalty": 1.1,
    "temperature": 0.8,
    "max_tokens": 512,
    "top_p": 0.9
}
# issue the POST request
response = requests.post(url, json=data)
# check the response status code
if response.status_code == 200:
    result = response.json()
    if "choices" in result:
        message_content = result["choices"][0]["message"]["content"]
        print("Generated text: ", message_content)
    else:
        print("Error: ", result["error"])
else:
    print(f"Request failed with status code: {response.status_code}")