ThunderKittens 2.0(1): 如何优化 Ulysses

原文链接: https://zhuanlan.zhihu.com/p/2011494527026894039

本文主要参考资料:

One Kernel for All Your GPUs

https://github.com/HazyResearch/ThunderKittens

部分写作 powered by Kimi K2.5,学习过程 powered by Claude opus 4.6。

Ulysses 的基本思路

Transformer 里除了 Attention 都是 element-wise 的操作,这些部分切序列长度 N 很方便。但 Attention 切 N 比较麻烦,切 head 才顺手。

Ulysses 的做法是在进出 Attention 的时候做"切 N"到"切 head"的转换,用 all-to-all 通信来做这个分布式转置。

实际 PyTorch 代码里,Ulysses 需要在通信前后做数据重排:

1
2
3
4
5
6
7
8
9
10
# 通信前(切的是 N, HEAD 是完整的)
input_t = input.view(B, N_per_rank, world_size, H_per_rank, D) \
.permute(2, 1, 0, 3, 4).contiguous() # 一次完整拷贝

# 通信
torch.distributed.all_to_all_single(output_t, input_t)

# 通信后 (切的是 HEAD,N 是完整的)
output.copy_(output_t.permute(2, 0, 1, 3, 4)
.reshape(B, N, H_per_rank, D)) # 又一次拷贝

all to all 要求输入必须连续,那两个 contiguous() 导致 all to all 的带宽根本没法跑满。用 https://github.com/HazyResearch/ThunderKittens/blob/main/kernels/parallel/all_to_all/benchmark.py 在 8p H20 上面跑,NCCL 的 Ulysses 方案只能跑到 155 GB/s,而不考虑拷贝,纯 all to all 能跑到 285 GB/s。

ParallelKittens

ParallelKittens: Simple and Fast Multi-GPU AI Kernels

ParallelKittens 是 ThunderKittens 的多 GPU 扩展版本,也是 ThunderKittens 2.0 的一部分。通过 ParallelKittens 做 Ulysses 能在 8p H20 上做到 344 GB/s 的带宽。

核心思路

PK 的做法很直接:kernel 内部通过坐标计算直接确定每个 tile 该去哪,不需要显式的数据重排。

1
2
3
4
5
6
每个 GPU 跑同一个 kernel,每个 block 1 个线程,处理 1 个 tile:

GPU_i 的 block_j:
1. TMA load: 本 GPU HBM → SMEM(硬件 DMA)
2. 索引计算: 确定目标 GPU 和目标位置
3. TMA store: SMEM → 目标 GPU HBM(通过 NVLink)

kernel 也没几行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
template <int SCATTER_AXIS, int GATHER_AXIS>
__device__ inline void kernel(const globals &G) {
static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4,
"Scatter and gather axes must be 0, 1, 2, or 3");
static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different");

extern __shared__ int __shm[];
tma_swizzle_allocator allocator((int*)&__shm[0]);
globals::shared_tile &tile = allocator.allocate<globals::shared_tile>();

// Calculate the input indices
int task_idx = blockIdx.x;
int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE));
task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE));
int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE));
task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE));
int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE);
task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE);
int col_block_idx = task_idx;

// Load input data (assume a single-threaded block)
__shared__ semaphore arrived;
init_semaphore(arrived, 0, 1);
tma::expect_bytes(arrived, sizeof(tile));
tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived);

// Calculate the output indices
int dst_dev_idx;

if constexpr (SCATTER_AXIS == 0) {
dst_dev_idx = batch_idx / G.output.batch();
batch_idx %= G.output.batch();
} else if constexpr (SCATTER_AXIS == 1) {
dst_dev_idx = depth_idx / G.output.depth();
depth_idx %= G.output.depth();
} else if constexpr (SCATTER_AXIS == 2) {
dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE);
row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE);
} else {
dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE);
col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE);
}

if constexpr (GATHER_AXIS == 0) {
batch_idx += G.input.batch() * G.dev_idx;
} else if constexpr (GATHER_AXIS == 1) {
depth_idx += G.input.depth() * G.dev_idx;
} else if constexpr (GATHER_AXIS == 2) {
row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx;
} else {
col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx;
}

// Wait for inputs to arrive and store data to destination device
wait(arrived, 0);
tma::store_async(G.output[dst_dev_idx], tile,
{batch_idx, depth_idx, row_block_idx, col_block_idx});
}

怎么做到的

1. 跨 GPU 内存访问(IPC + VMM)

要让 GPU 直接写别的 GPU 显存,先用 CUDA 的 Virtual Memory Management API:

1
2
3
1. VMM 分配 (cuMemCreate + cuMemMap + cuMemSetAccess)
2. 交换 IPC handle(Unix Domain Socket)
3. 映射远端内存(cuMemImportFromShareableHandle + cuMemMap)

最后每个 GPU 都能通过本地指针访问其他 7 个 GPU 的显存。博客里面有个图:

2. PGL(Parallel Global Layout)

TK 搞了个 pgl 结构管理跨 GPU tensor:

1
2
3
4
5
template<GL, NUM_DEVICES=8>
struct pgl {
GL gls[NUM_DEVICES]; // 8 个 GL,每个包含 raw_ptr 和 TMA 描述符
};

G.output[dst_dev_idx] 返回目标 GPU 的 GL,其 TMA 描述符指向远端显存。

3. Kernel 逻辑

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
__device__ void kernel(const globals &G) {
// 从 blockIdx.x 解码 4D 坐标
(batch_idx, depth_idx, row_block_idx, col_block_idx) ← blockIdx.x

// TMA load: 本 GPU HBM → SMEM
tma::load_async(tile, G.input[G.dev_idx], coords, semaphore);

// 计算目标 GPU 和目标坐标
dst_dev_idx = ...;
output_coords = ...;

// TMA store 到远端
wait(semaphore, 0);
tma::store_async(G.output[dst_dev_idx], tile, output_coords);
}

关键就是索引重映射替代了显式 permute。kernel 里算好坐标,数据直接搬到目标位置。

4. TMA(Tensor Memory Accelerator)

用的 PTX 指令:

1
2
3
4
5
# Load: HBM → SMEM(异步,硬件 DMA)
cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes

# Store: SMEM → HBM(异步,可写远端)
cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group

TMA 是硬件 DMA,不占用 SM 计算资源,1 个线程就能驱动。目标地址是 IPC 映射的远端指针时,TMA 自动走 NVLink 传输。

一些限制

主要限制在于跨 GPU 内存访问,只能在一个 NVLink 域内通信。

还有一点是对于超节点 (例如 GB200 NVL72, 一个容器只能拿到四个卡),没法通过单机进程通信交换 handle。不过 NV 对此早有考虑 ,CUDA 12.4+ 有 CU_MEM_HANDLE_TYPE_FABRIC,配合 nvidia-imex daemon 可以跨节点交换 handle。只是 TK 还没支持这个场景,能拿到超节点的大哥可以试试看效果。