vLLM 中的推测性解码#

警告

请注意,vLLM 中的推测性解码尚未优化,并且通常不会为所有提示数据集或采样参数产生令牌间延迟减少。优化它的工作正在进行中,可以在 此问题 中跟踪。

本文档介绍了如何在 vLLM 中使用“推测解码 <https://x.com/karpathy/status/1697318534555336961>”。推测解码是一种技术,可以提高内存受限的 LLM 推理中的令牌间延迟。

使用草稿模型进行推测#

以下代码将 vLLM 配置为离线模式,以使用草稿模型进行推测解码,每次推测 5 个令牌。

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

要使用在线模式执行相同的操作,请启动服务器:

   python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
   --seed 42 -tp 1 --speculative_model facebook/opt-125m --use-v2-block-manager \
   --num_speculative_tokens 5 --gpu_memory_utilization 0.8

Then use a client:
from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt="The future of AI is",
    echo=False,
    n=1,
    stream=stream,
)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)

通过匹配提示中的 n 元语法进行推测#

以下代码将 vLLM 配置为使用推测解码,其中提案是通过匹配提示中的 n 元语法生成的。有关更多信息,请阅读“此线程 <https://x.com/joao_gante/status/1747322413006643259>”。

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,
    use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

使用 MLP 推测器进行推测#

以下代码将 vLLM 配置为使用推测解码,其中提案是通过草稿模型生成的,这些草稿模型将草稿预测条件化为上下文向量和采样令牌。有关更多信息,请参阅“此博客 <https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/>”或“此技术报告 <https://arxiv.org/abs/2404.19124>”。

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,
    speculative_model="ibm-fms/llama3-70b-accelerator",
    speculative_draft_tensor_parallel_size=1,
    use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

请注意,这些推测模型目前需要在没有张量并行的情况下运行,尽管可以使用张量并行运行主模型(参见上面的示例)。由于推测模型相对较小,我们仍然看到了显著的加速。但是,此限制将在未来版本中修复。

各种类型的推测模型在 HF 集线器上可用:

推测解码的无损保证#

在 vLLM 中,推测解码旨在提高推理效率,同时保持准确性。本节讨论推测解码的无损保证,将保证细分为三个关键领域:

  1. 理论无损性 - 推测解码采样在理论上是无损的,直到硬件数值的精度限制。浮点误差可能会导致输出分布略有变化,如 Accelerating Large Language Model Decoding with Speculative Sampling 中所述。

  2. 算法无损性 - vLLM 对推测解码的实现经过算法验证,证明是无损的。关键验证测试包括:

    • 拒绝采样收敛: 确保来自 vLLM 拒绝采样的样本与目标分布一致。查看测试代码

    • 贪婪采样等式: 确认带有推测解码的贪婪采样与不带推测解码的贪婪采样匹配。这验证了 vLLM 的推测解码框架在与 vLLM 前向传递和 vLLM 拒绝采样器集成时,提供了无损保证。此目录 中几乎所有测试都使用`此断言实现 <vllm-project/vllm>`_ 验证了此属性。

  3. vLLM 对数概率稳定性 - vLLM 目前不保证稳定的令牌对数概率(logprob)。这可能导致相同请求在不同运行中产生不同的输出。有关更多详细信息,请参阅`常见问题解答 <../serving/faq.rst>`_ 中标题为“vLLM 中提示的输出是否会在不同运行中有所不同?”的常见问题解答部分。

结论

虽然 vLLM 努力确保推测解码中的无损性,但由于以下因素,带有和不带推测解码的生成输出可能会出现差异:

  • 浮点精度: 硬件数值精度的差异可能会导致输出分布的细微差异。

  • 批次大小和数值稳定性: 批次大小的变化可能会导致 logprob 和输出概率的变化,这可能是由于批次操作中的非确定性行为或数值不稳定性造成的。

缓解策略

有关缓解策略,请参阅`常见问题解答 <../serving/faq.rst>`_ 中的常见问题解答条目“vLLM 中提示的输出是否会在不同运行中有所不同?”。

面向 vLLM 贡献者的资源#