[LLM] Loading an AWQ-Quantized Model with vLLM

1. Basic loading

The script below loads an AWQ-quantized Mixtral checkpoint through vLLM's AsyncLLMEngine and streams the completion to stdout as tokens arrive.

import asyncio
from transformers import AutoTokenizer, PreTrainedTokenizer
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs

model_path = "casperhansen/mixtral-instruct-awq"

# prompting
prompt = "You're standing on the surface of the Earth. "\
         "You walk one mile south, one mile west and one mile north. "\
         "You end up exactly where you started. Where are you?"

prompt_template = "[INST] {prompt} [/INST]"

# sampling params
sampling_params = SamplingParams(
    repetition_penalty=1.1,
    temperature=0.8,
    max_tokens=512
)

# tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

# async engine args for streaming
engine_args = AsyncEngineArgs(
    model=model_path,
    quantization="awq",
    dtype="float16",
    max_model_len=512,
    enforce_eager=True,
    disable_log_requests=True,
    disable_log_stats=True,
)

async def generate(model: AsyncLLMEngine, tokenizer: PreTrainedTokenizer):
    tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids

    # the engine accepts pre-tokenized input via prompt_token_ids;
    # request_id must be a string that is unique per request
    outputs = model.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id="0",
        prompt_token_ids=tokens,
    )

    print("\n** Starting generation!\n")
    last_index = 0

    # each streamed RequestOutput carries the cumulative text so far,
    # so print only the newly generated suffix
    async for output in outputs:
        print(output.outputs[0].text[last_index:], end="", flush=True)
        last_index = len(output.outputs[0].text)

    print("\n\n** Finished generation!\n")

if __name__ == '__main__':
    model = AsyncLLMEngine.from_engine_args(engine_args)
    asyncio.run(generate(model, tokenizer))
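
For quick, non-streaming experiments you can skip the asyncio plumbing and use vLLM's synchronous LLM class instead. A minimal sketch, assuming the same AWQ checkpoint and instruction template as above:

from vllm import LLM, SamplingParams

# synchronous offline inference with the same AWQ checkpoint
llm = LLM(
    model="casperhansen/mixtral-instruct-awq",
    quantization="awq",
    dtype="float16",
    max_model_len=512,
)

sampling_params = SamplingParams(temperature=0.8, max_tokens=512)

prompt = "[INST] You're standing on the surface of the Earth. "\
         "You walk one mile south, one mile west and one mile north. "\
         "You end up exactly where you started. Where are you? [/INST]"

# generate() blocks until the full completion is ready
outputs = llm.generate([prompt], sampling_params)
print(outputs[0].outputs[0].text)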

2. Exposing an HTTP API

2.1 Server

The server wraps the same engine in a FastAPI app. The JSON body is parsed into a Pydantic model so the fields the client sends map directly onto SamplingParams.

import uuid

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs

app = FastAPI()

model_path = "casperhansen/mixtral-instruct-awq"

# prompting
prompt_template = "[INST] {prompt} [/INST]"

# async engine args for streaming
engine_args = AsyncEngineArgs(
    model=model_path,
    quantization="awq",
    dtype="float16",
    max_model_len=512,
    enforce_eager=True,
    disable_log_requests=True,
    disable_log_stats=True,
)

# initialize the engine and tokenizer once at startup
model = AsyncLLMEngine.from_engine_args(engine_args)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# request schema; parsed from the JSON body the client sends
# (plain scalar parameters would be treated as query parameters by FastAPI)
class ChatRequest(BaseModel):
    prompt: str
    repetition_penalty: float = 1.1
    temperature: float = 0.8
    max_tokens: int = 512
    top_p: float = 0.9

@app.post("/v1/chat/completions")
async def generate_text(req: ChatRequest):
    try:
        sampling_params = SamplingParams(
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            max_tokens=req.max_tokens,
            top_p=req.top_p
        )
        tokens = tokenizer(prompt_template.format(prompt=req.prompt)).input_ids
        outputs = model.generate(
            prompt=req.prompt,
            sampling_params=sampling_params,
            # unique id per request; a fixed id would collide across
            # concurrent requests
            request_id=str(uuid.uuid4()),
            prompt_token_ids=tokens,
        )
        result = ""
        async for output in outputs:
            # each streamed output carries the cumulative text, so keep
            # the latest snapshot rather than concatenating chunks
            result = output.outputs[0].text
        return {
            "choices": [
                {
                    "message": {
                        "content": result
                    }
                }
            ]
        }
    except Exception as e:
        return {"error": str(e)}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
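
The endpoint above buffers the whole completion before responding, even though AsyncLLMEngine streams internally. To let clients see tokens as they are generated, FastAPI's StreamingResponse can forward each new chunk. A minimal sketch that could be added to the same app, reusing the ChatRequest model above (the /stream route name is only an illustration):

from fastapi.responses import StreamingResponse

@app.post("/v1/chat/completions/stream")
async def generate_stream(req: ChatRequest):
    sampling_params = SamplingParams(
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        max_tokens=req.max_tokens,
        top_p=req.top_p
    )
    tokens = tokenizer(prompt_template.format(prompt=req.prompt)).input_ids
    results = model.generate(
        prompt=req.prompt,
        sampling_params=sampling_params,
        request_id=str(uuid.uuid4()),
        prompt_token_ids=tokens,
    )

    async def token_stream():
        last_index = 0
        # yield only the newly generated suffix of the cumulative text
        async for output in results:
            text = output.outputs[0].text
            yield text[last_index:]
            last_index = len(text)

    return StreamingResponse(token_stream(), media_type="text/plain")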

2.2 Client

A plain requests client that posts the prompt and sampling parameters as JSON and prints the returned completion.

import requests

# API endpoint
url = "http://localhost:8000/v1/chat/completions"

# prompt and sampling parameters to send
data = {
    "prompt": "You're standing on the surface of the Earth. You walk one mile south, one mile west and one mile north. You end up exactly where you started. Where are you?",
    "repetition_penalty": 1.1,
    "temperature": 0.8,
    "max_tokens": 512,
    "top_p": 0.9
}

# send the POST request
response = requests.post(url, json=data)

# check the response status code
if response.status_code == 200:
    result = response.json()
    if "choices" in result:
        message_content = result["choices"][0]["message"]["content"]
        print("Generated text: ", message_content)
    else:
        print("Error: ", result.get("error"))
else:
    print(f"Request failed with status code: {response.status_code}")