原文链接: https://zhuanlan.zhihu.com/p/2011494527026894039
本文主要参考资料:
One Kernel for All Your GPUs
https://github.com/HazyResearch/ThunderKittens
部分写作 powered by Kimi K2.5,学习过程 powered by Claude opus 4.6。
Ulysses 的基本思路
Transformer 里除了 Attention 都是 element-wise 的操作,这些部分切序列长度 N 很方便。但 Attention 切 N 比较麻烦,切 head 才顺手。
Ulysses 的做法是在进出 Attention 的时候做"切 N"到"切 head"的转换,用 all-to-all 通信来做这个分布式转置。
实际 PyTorch 代码里,Ulysses 需要在通信前后做数据重排:
1 2 3 4 5 6 7 8 9 10 input_t = input .view(B, N_per_rank, world_size, H_per_rank, D) \ .permute(2 , 1 , 0 , 3 , 4 ).contiguous() torch.distributed.all_to_all_single(output_t, input_t) output.copy_(output_t.permute(2 , 0 , 1 , 3 , 4 ) .reshape(B, N, H_per_rank, D))
all to all 要求输入必须连续,那两个 contiguous() 导致 all to all 的带宽根本没法跑满。用 https://github.com/HazyResearch/ThunderKittens/blob/main/kernels/parallel/all_to_all/benchmark.py 在 8p H20 上面跑,NCCL 的 Ulysses 方案只能跑到 155 GB/s,而不考虑拷贝,纯 all to all 能跑到 285 GB/s。
ParallelKittens
ParallelKittens: Simple and Fast Multi-GPU AI Kernels
ParallelKittens 是 ThunderKittens 的多 GPU 扩展版本,也是 ThunderKittens 2.0 的一部分。通过 ParallelKittens 做 Ulysses 能在 8p H20 上做到 344 GB/s 的带宽。
核心思路
PK 的做法很直接:kernel 内部通过坐标计算直接确定每个 tile 该去哪,不需要显式的数据重排。
1 2 3 4 5 6 每个 GPU 跑同一个 kernel,每个 block 1 个线程,处理 1 个 tile: GPU_i 的 block_j: 1. TMA load: 本 GPU HBM → SMEM(硬件 DMA) 2. 索引计算: 确定目标 GPU 和目标位置 3. TMA store: SMEM → 目标 GPU HBM(通过 NVLink)
kernel 也没几行
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 template <int SCATTER_AXIS, int GATHER_AXIS>__device__ inline void kernel (const globals &G) { static_assert (0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4 , "Scatter and gather axes must be 0, 1, 2, or 3" ); static_assert (SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different" ); extern __shared__ int __shm[]; tma_swizzle_allocator allocator ((int *)&__shm[0 ]) ; globals::shared_tile &tile = allocator.allocate <globals::shared_tile>(); int task_idx = blockIdx.x; int batch_idx = task_idx / (G.input.depth () * (G.input.rows () / globals::ROW_BLOCK_SIZE) * (G.input.cols () / globals::COL_BLOCK_SIZE)); task_idx %= (G.input.depth () * (G.input.rows () / globals::ROW_BLOCK_SIZE) * (G.input.cols () / globals::COL_BLOCK_SIZE)); int depth_idx = task_idx / (G.input.rows () / globals::ROW_BLOCK_SIZE * (G.input.cols () / globals::COL_BLOCK_SIZE)); task_idx %= (G.input.rows () / globals::ROW_BLOCK_SIZE * (G.input.cols () / globals::COL_BLOCK_SIZE)); int row_block_idx = task_idx / (G.input.cols () / globals::COL_BLOCK_SIZE); task_idx %= (G.input.cols () / globals::COL_BLOCK_SIZE); int col_block_idx = task_idx; __shared__ semaphore arrived; init_semaphore (arrived, 0 , 1 ); tma::expect_bytes (arrived, sizeof (tile)); tma::load_async (tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); int dst_dev_idx; if constexpr (SCATTER_AXIS == 0 ) { dst_dev_idx = batch_idx / G.output.batch (); batch_idx %= G.output.batch (); } else if constexpr (SCATTER_AXIS == 1 ) { dst_dev_idx = depth_idx / G.output.depth (); depth_idx %= G.output.depth (); } else if constexpr (SCATTER_AXIS == 2 ) { dst_dev_idx = row_block_idx / (G.output.rows () / globals::ROW_BLOCK_SIZE); row_block_idx %= (G.output.rows () / globals::ROW_BLOCK_SIZE); } else { dst_dev_idx = col_block_idx / (G.output.cols () / globals::COL_BLOCK_SIZE); col_block_idx %= (G.output.cols () / globals::COL_BLOCK_SIZE); } if constexpr (GATHER_AXIS == 0 ) { batch_idx += G.input.batch () * G.dev_idx; } else if constexpr (GATHER_AXIS == 1 ) { depth_idx += G.input.depth () * G.dev_idx; } else if constexpr (GATHER_AXIS == 2 ) { row_block_idx += (G.input.rows () / globals::ROW_BLOCK_SIZE) * G.dev_idx; } else { col_block_idx += (G.input.cols () / globals::COL_BLOCK_SIZE) * G.dev_idx; } wait (arrived, 0 ); tma::store_async (G.output[dst_dev_idx], tile, {batch_idx, depth_idx, row_block_idx, col_block_idx}); }
怎么做到的
1. 跨 GPU 内存访问(IPC + VMM)
要让 GPU 直接写别的 GPU 显存,先用 CUDA 的 Virtual Memory Management API:
1 2 3 1. VMM 分配 (cuMemCreate + cuMemMap + cuMemSetAccess) 2. 交换 IPC handle(Unix Domain Socket) 3. 映射远端内存(cuMemImportFromShareableHandle + cuMemMap)
最后每个 GPU 都能通过本地指针访问其他 7 个 GPU 的显存。博客里面有个图:
2. PGL(Parallel Global Layout)
TK 搞了个 pgl 结构管理跨 GPU tensor:
1 2 3 4 5 template <GL, NUM_DEVICES=8 >struct pgl { GL gls[NUM_DEVICES]; };
G.output[dst_dev_idx] 返回目标 GPU 的 GL,其 TMA 描述符指向远端显存。
3. Kernel 逻辑
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 __device__ void kernel(const globals &G) { // 从 blockIdx.x 解码 4D 坐标 (batch_idx, depth_idx, row_block_idx, col_block_idx) ← blockIdx.x // TMA load: 本 GPU HBM → SMEM tma::load_async(tile, G.input[G.dev_idx], coords, semaphore); // 计算目标 GPU 和目标坐标 dst_dev_idx = ...; output_coords = ...; // TMA store 到远端 wait(semaphore, 0); tma::store_async(G.output[dst_dev_idx], tile, output_coords); }
关键就是索引重映射替代了显式 permute 。kernel 里算好坐标,数据直接搬到目标位置。
4. TMA(Tensor Memory Accelerator)
用的 PTX 指令:
1 2 3 4 5 # Load: HBM → SMEM(异步,硬件 DMA) cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes # Store: SMEM → HBM(异步,可写远端) cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group
TMA 是硬件 DMA,不占用 SM 计算资源,1 个线程就能驱动。目标地址是 IPC 映射的远端指针时,TMA 自动走 NVLink 传输。
一些限制
主要限制在于跨 GPU 内存访问,只能在一个 NVLink 域内通信。
还有一点是对于超节点 (例如 GB200 NVL72, 一个容器只能拿到四个卡),没法通过单机进程通信交换 handle。不过 NV 对此早有考虑 ,CUDA 12.4+ 有 CU_MEM_HANDLE_TYPE_FABRIC,配合 nvidia-imex daemon 可以跨节点交换 handle。只是 TK 还没支持这个场景,能拿到超节点的大哥可以试试看效果。