vLLM 分页注意力#

  • 目前,vLLM 使用其自己的多头查询注意力内核实现 (csrc/attention/attention_kernels.cu)。此内核旨在与 vLLM 的分页 KV 缓存兼容,其中键和值缓存存储在单独的块中(请注意,此块概念不同于 GPU 线程块。因此,在后面的文档中,我将 vLLM 分页注意力块称为“块”,而将 GPU 线程块称为“线程块”)。

  • 为了实现高性能,此内核依赖于专门设计的内存布局和访问方法,特别是在线程从全局内存读取数据到共享内存时。本文档的目的是逐步提供对内核实现的高级解释,帮助那些希望了解 vLLM 多头查询注意力内核的人。阅读完本文档后,用户可能会更好地理解并更容易地理解实际实现。

  • 请注意,本文档可能不会涵盖所有细节,例如如何计算对应数据的正确索引或点乘实现。但是,在阅读完本文档并熟悉高级逻辑流程后,你应该更容易阅读实际代码并理解细节。

输入#

  • 内核函数接收当前线程执行其分配工作的一系列参数。三个最重要的参数是输入指针``q``、k_cache``和``v_cache,它们指向需要读取和处理的全局内存上的查询、键和值数据。输出指针``out``指向应写入结果的全局内存。这四个指针实际上是指向多维数组,但每个线程只访问分配给它的数据部分。为了简单起见,我省略了这里的所有其他运行时参数。

    template<
    typename scalar_t,
    int HEAD_SIZE,
    int BLOCK_SIZE,
    int NUM_THREADS,
    int PARTITION_SIZE = 0>
    __device__ void paged_attention_kernel(
    ... // Other side args.
    const scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
    ... // Other side args.
    )
    
  • 函数签名上方还有一些在编译时确定的模板参数。``scalar_t``表示查询、键和值数据元素的数据类型,例如 FP16。``HEAD_SIZE``表示每个头中的元素数量。``BLOCK_SIZE``指的是每个块中的标记数量。``NUM_THREADS``表示每个线程块中的线程数量。``PARTITION_SIZE``表示张量并行 GPU 的数量(为简单起见,我们假设它为 0 且张量并行被禁用)。

  • 有了这些参数,我们需要执行一系列准备工作。这包括计算当前头索引、块索引和其他必要的变量。但是,现在我们可以忽略这些准备工作,直接进行实际计算。一旦我们掌握了整个流程,就会更容易理解它们。

概念#

  • 在我们深入计算流程之前,我想描述一些后面部分需要的概念。但是,如果你遇到任何令人困惑的术语,你可以跳过本节,稍后再回来。

  • 序列: 序列代表一个客户端请求。例如,由 q 指向的数据形状为 [num_seqs, num_heads, head_size]。这意味着 q 指向的查询序列数据共有 num_seqs 个。由于此内核是一个单查询注意力内核,因此每个序列只有一个查询令牌。因此,num_seqs 等于批处理中处理的令牌总数。

  • 上下文: 上下文由序列生成的令牌组成。例如,["What", "is", "your"] 是上下文令牌,输入查询令牌是 "name"。模型可能会生成令牌 "?"

  • 向量: 向量是按顺序获取和计算的一组元素。对于查询和键数据,向量大小 (VEC_SIZE) 的确定是为了让每个线程组能够一次获取和计算 16 字节的数据。对于值数据,向量大小 (V_VEC_SIZE) 的确定是为了让每个线程能够一次获取和计算 16 字节的数据。例如,如果 scalar_t 是 FP16 (2 字节) 且 THREAD_GROUP_SIZE 为 2,则 VEC_SIZE 将为 4,而 V_VEC_SIZE 将为 8。

  • 线程组: 线程组是一个由少量线程 (THREAD_GROUP_SIZE) 组成的组,它们一次获取和计算一个查询令牌和一个键令牌。每个线程只处理令牌数据的一部分。一个线程组处理的元素总数称为 x。例如,如果线程组包含 2 个线程,且头部大小为 8,则线程 0 处理索引为 0、2、4、6 的查询和键元素,而线程 1 处理索引为 1、3、5、7 的元素。

  • : vLLM 中的键和值缓存数据被分成块。每个块存储一个头部中固定数量 (BLOCK_SIZE) 的令牌的数据。每个块可能只包含整个上下文令牌的一部分。例如,如果块大小为 16,头部大小为 128,则对于一个头部,一个块可以存储 16 * 128 = 2048 个元素。

  • Warp: 一个 warp 是一个包含 32 个线程(WARP_SIZE)的组,它们在流式多处理器 (SM) 上同时执行。在这个内核中,每个 warp 每次处理一个查询令牌与一个完整块的键令牌之间的计算(它可能在多个迭代中处理多个块)。例如,如果一个上下文有 4 个 warp 和 6 个块,则分配方式如下:warp 0 处理第 0 个和第 4 个块,warp 1 处理第 1 个和第 5 个块,warp 2 处理第 2 个块,warp 3 处理第 3 个块。

  • 线程块: 一个线程块是一个包含多个线程(NUM_THREADS)的组,它们可以访问相同的共享内存。每个线程块包含多个 warp(NUM_WARPS),在这个内核中,每个线程块处理一个查询令牌与整个上下文的键令牌之间的计算。

  • 网格: 一个网格是线程块的集合,定义了集合的形状。在这个内核中,形状为 (num_heads, num_seqs, max_num_partitions)。因此,每个线程块只处理一个头、一个序列和一个分区上的计算。

查询#

  • 本节将介绍查询数据如何在内存中存储以及如何被每个线程获取。如上所述,每个线程组获取一个查询令牌数据,而每个线程本身只处理一个查询令牌数据的一部分。在每个 warp 内,每个线程组将获取相同的查询令牌数据,但会将其与不同的键令牌数据相乘。

    const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
    
    查询

    一个头上的一个令牌的查询数据#

  • 每个线程定义自己的 q_ptr,它指向全局内存中分配的查询令牌数据。例如,如果 VEC_SIZE 为 4,HEAD_SIZE 为 128,则 q_ptr 指向包含总共 128 个元素的数据,这些元素被分成 128 / 4 = 32 个向量。

    一个线程组的 ``q_vecs``

    q_vecs for one thread group#

    __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
    
  • 接下来,我们需要将 q_ptr 指向的全局内存数据读入共享内存作为 q_vecs。需要注意的是,每个 vecs 对应不同的行。例如,如果 THREAD_GROUP_SIZE 为 2,线程 0 将处理第 0 行 vecs,而线程 1 处理第 1 行 vecs。通过这种方式读取查询数据,相邻线程(如线程 0 和线程 1)可以读取相邻内存,实现内存合并以提高性能。

#

  • 与“查询”部分类似,本部分介绍了键的内存布局和分配。虽然每个线程组在一个内核运行中只处理一个查询令牌,但它可能在多个迭代中处理多个键令牌。同时,每个 warp 将在多个迭代中处理多个键令牌块,确保在内核运行结束后,所有上下文令牌都被整个线程组处理。在此背景下,“处理”指的是执行查询数据和键数据之间的点乘运算。

    const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                        + kv_head_idx * kv_head_stride
                        + physical_block_offset * x;
    
  • q_ptr 不同,每个线程中的 k_ptr 在不同的迭代中将指向不同的键令牌。如上所示,k_ptr 根据分配的块、分配的头和分配的令牌指向键令牌数据,这些数据基于 k_cache

    键

    一个头中所有上下文令牌的键数据#

  • 上图说明了键数据的内存布局。假设 BLOCK_SIZE 为 16,HEAD_SIZE 为 128,x 为 8,THREAD_GROUP_SIZE 为 2,并且总共有 4 个 warp。每个矩形代表一个头中一个键令牌的所有元素,这些元素将由一个线程组处理。左侧显示了 warp 0 的总共 16 个键令牌数据块,而右侧代表其他 warp 或迭代的剩余键令牌数据。在每个矩形内部,总共有 32 个 vecs(一个令牌的 128 个元素),将由 2 个线程(一个线程组)分别处理。

    k_vecs

    一个线程的 k_vecs#

    K_vec k_vecs[NUM_VECS_PER_THREAD]
    
  • 接下来,我们需要从 k_ptr 中读取关键令牌数据并将其存储在寄存器内存中作为 k_vecs。我们使用寄存器内存来存储 k_vecs,因为它们只会被一个线程访问一次,而 q_vecs 会被多个线程多次访问。每个 k_vecs 将包含多个向量,用于后续计算。每个向量将在每次内部迭代中设置。向量的分配允许一个 warp 中的相邻线程一起读取相邻内存,这再次促进了内存合并。例如,线程 0 将读取向量 0,而线程 1 将读取向量 1。在下一个内部循环中,线程 0 将读取向量 2,而线程 1 将读取向量 3,依此类推。

  • 你可能仍然对整体流程有点困惑。别担心,请继续阅读下一节“QK”。它将以更清晰、更高级的方式说明查询和键计算流程。

QK#

  • 如下面的伪代码所示,在整个 for 循环块之前,我们获取一个令牌的查询数据并将其存储在 q_vecs 中。然后,在外部 for 循环中,我们遍历指向不同令牌的不同 k_ptrs,并在内部 for 循环中准备 k_vecs。最后,我们执行 q_vecs 和每个 k_vecs 之间的点乘。

    q_vecs = ...
    for ... {
       k_ptr = ...
       for ... {
          k_vecs[i] = ...
       }
       ...
       float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
    }
    
  • 如前所述,对于每个线程,它一次只获取部分查询和关键令牌数据。但是,在 Qk_dot<>::dot 中会发生跨线程组的归约。因此,这里返回的 qk 不仅仅是查询和关键令牌点乘的一部分,而是实际上是整个查询和关键令牌数据之间的完整结果。

  • 例如,如果 HEAD_SIZE 的值为 128,而 THREAD_GROUP_SIZE 为 2,则每个线程的 k_vecs 将包含总共 64 个元素。但是,返回的 qk 实际上是 128 个查询元素与 128 个键元素之间的点乘结果。如果你想了解更多关于点乘和约简的细节,可以参考 Qk_dot<>::dot 的实现。但是,为了简单起见,我不会在本文件中介绍它。

Softmax#

  • 接下来,我们需要计算所有 qk 的归一化 softmax,如上所示,其中每个 \(x\) 代表一个 qk。为此,我们必须获得 qk_max (\(m(x)\)) 的约简值和所有 qkexp_sum (\(\ell(x)\))。约简应该在整个线程块中执行,涵盖查询标记和所有上下文键标记之间的结果。

    \begin{gather*} m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\ \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} \end{gather*}

qk_maxlogits#

  • 在我们获得 qk 结果后,我们可以使用 qk 设置临时 logits 结果(最终,logits 应该存储归一化 softmax 结果)。我们还可以比较和收集当前线程组计算的所有 qkqk_max

    if (thread_group_offset == 0) {
       const bool mask = token_idx >= context_len;
       logits[token_idx - start_token_idx] = mask ? 0.f : qk;
       qk_max = mask ? qk_max : fmaxf(qk_max, qk);
    }
    
  • 请注意,这里的 logits 位于共享内存中,因此每个线程组将为其分配的上下文标记设置字段。总的来说,logits 的大小应该是上下文标记的数量。

    for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
        qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
    }
    
    if (lane == 0) {
       red_smem[warp_idx] = qk_max;
    }
    
  • 然后我们需要在每个 warp 中获取约简后的 qk_max。主要思想是让 warp 中的线程相互通信,并获得最终的最大 qk

    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
        qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
    }
    qk_max = VLLM_SHFL_SYNC(qk_max, 0);
    
  • 最后,我们可以通过比较该线程块中所有 warp 的 qk_max 来获得整个线程块的约简后的 qk_max。然后我们需要将最终结果广播到每个线程。

exp_sum#

  • qk_max 类似,我们也需要从整个线程块中获取缩减后的求和值。

    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
        float val = __expf(logits[i] - qk_max);
        logits[i] = val;
        exp_sum += val;
    }
    ...
    exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
    
  • 首先,将每个线程组中的所有 exp 值求和,同时将 logits 中的每个条目从 qk 转换为 exp(qk - qk_max)。请注意,这里的 qk_max 已经是整个线程块中最大的 qk。然后,我们可以像 qk_max 一样对整个线程块的 exp_sum 进行缩减。

    const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
       logits[i] *= inv_sum;
    }
    
  • 最后,使用缩减后的 qk_maxexp_sum,我们可以获得最终的归一化 softmax 结果作为 logits。这个 logits 变量将在后面的步骤中用于与值数据进行点乘。现在,它应该存储所有分配的上下文标记的 qk 的归一化 softmax 结果。

#

值

一个头部所有上下文标记的值数据#

logits_vec

一个线程的 logits_vec#

v_vec

一个线程的 v_vec 列表#

  • 现在我们需要检索值数据并与 logits 进行点乘。与查询和键不同,值数据没有线程组的概念。如示意图所示,与键标记内存布局不同,同一列中的元素对应于同一个值标记。对于一个值数据块,有 HEAD_SIZE 行和 BLOCK_SIZE 列,它们被分成多个 v_vecs

  • 每个线程总是从同一 V_VEC_SIZE 个标记中一次获取 V_VEC_SIZE 个元素。因此,单个线程通过多个内部迭代从不同行和同一列中检索多个 v_vec。对于每个 v_vec,它需要与相应的 logits_vec 进行点乘,该 logits_vec 也是来自 logitsV_VEC_SIZE 个元素。总的来说,通过多个内部迭代,每个 warp 将处理一个值标记块。通过多个外部迭代,整个上下文值标记被处理。

    float accs[NUM_ROWS_PER_THREAD];
    for ... { // Iteration over different blocks.
        logits_vec = ...
        for ... { // Iteration over different rows.
            v_vec = ...
            ...
            accs[i] += dot(logits_vec, v_vec);
        }
    }
    
  • 如上伪代码所示,在外层循环中,类似于 k_ptrlogits_vec 遍历不同的块并从 logits 中读取 V_VEC_SIZE 个元素。在内层循环中,每个线程从相同 token 中读取 V_VEC_SIZE 个元素作为 v_vec 并执行点乘。需要注意的是,在每次内循环迭代中,线程会为相同 token 获取不同的头部位置元素。点乘结果随后累积到 accs 中。因此,accs 中的每个条目都映射到分配给当前线程的头部位置。

  • 例如,如果 BLOCK_SIZE 为 16 且 V_VEC_SIZE 为 8,则每个线程一次性为 8 个 token 获取 8 个值元素。每个元素都来自相同头部位置的不同 token。如果 HEAD_SIZE 为 128 且 WARP_SIZE 为 32,则对于每个内循环,一个 warp 需要获取 WARP_SIZE * V_VEC_SIZE = 256 个元素。这意味着一个 warp 需要进行 128 * 16 / 256 = 8 次内循环迭代才能处理一整块值 token。并且每个线程中的每个 accs 都包含 8 个元素,这些元素是在 8 个不同的头部位置累积的。对于线程 0,accs 变量将包含 8 个元素,分别是值头的第 0 个、第 32 个……第 224 个元素,这些元素是从所有分配的 8 个 token 中累积的。

LV#

  • 现在,我们需要在每个 warp 内对 accs 进行归约。此过程允许每个线程累积分配给一个块中所有 token 的头部位置的 accs

    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       float acc = accs[i];
       for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
          acc += VLLM_SHFL_XOR_SYNC(acc, mask);
       }
       accs[i] = acc;
    }
    
  • 接下来,我们对所有 warp 中的 accs 进行归约,允许每个线程拥有分配给所有上下文 token 的头部位置的 accs 的累积值。请注意,每个线程中的每个 accs 只存储所有上下文 token 的整个头的部分元素的累积值。但是,总体而言,所有输出结果都已计算出来,只是存储在不同的线程寄存器内存中。

    float* out_smem = reinterpret_cast<float*>(shared_mem);
    for (int i = NUM_WARPS; i > 1; i /= 2) {
        // Upper warps write to shared memory.
        ...
            float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
                    ...
            dst[row_idx] = accs[i];
        }
    
        // Lower warps update the output.
            const float* src = &out_smem[warp_idx * HEAD_SIZE];
        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
                    ...
            accs[i] += src[row_idx];
        }
    
            // Write out the accs.
    }
    

输出#

  • 现在,我们可以将所有计算结果从本地寄存器内存写入最终输出全局内存。

    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                    + head_idx * max_num_partitions * HEAD_SIZE
                    + partition_idx * HEAD_SIZE;
    
  • 首先,我们需要定义 out_ptr 变量,它指向分配序列和分配头的起始地址。

    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
    if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
    }
    }
    
  • 最后,我们需要遍历不同的分配头位置,并根据 out_ptr 写出相应的累积结果。