<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <title>Reku</title>
  
  
  <link href="https://wyc-ruiker.github.io/atom.xml" rel="self"/>
  
  <link href="https://wyc-ruiker.github.io/"/>
  <updated>2026-04-19T17:08:50.178Z</updated>
  <id>https://wyc-ruiker.github.io/</id>
  
  <author>
    <name>Reku</name>
    
  </author>
  
  <generator uri="https://hexo.io/">Hexo</generator>
  
  <entry>
    <title>Claude opus 4.6 独立发现的 DeepEP 时序问题</title>
    <link href="https://wyc-ruiker.github.io/2026/03/08/zhihu/Claude-opus-4.6-%E7%8B%AC%E7%AB%8B%E5%8F%91%E7%8E%B0%E7%9A%84-DeepEP-%E6%97%B6%E5%BA%8F%E9%97%AE%E9%A2%98/"/>
    <id>https://wyc-ruiker.github.io/2026/03/08/zhihu/Claude-opus-4.6-%E7%8B%AC%E7%AB%8B%E5%8F%91%E7%8E%B0%E7%9A%84-DeepEP-%E6%97%B6%E5%BA%8F%E9%97%AE%E9%A2%98/</id>
    <published>2026-03-07T16:00:00.000Z</published>
    <updated>2026-04-19T17:08:50.178Z</updated>
    
    <content type="html"><![CDATA[<blockquote><p>原文链接: <a href="https://zhuanlan.zhihu.com/p/2013987217081644893">https://zhuanlan.zhihu.com/p/2013987217081644893</a></p></blockquote><p>DeepEP Issue：<a href="https://github.com/deepseek-ai/DeepEP/issues/589">https://github.com/deepseek-ai/DeepEP/issues/589</a></p><h2 id="背景">背景</h2><p>前几天同事在跑实验，发现某些场景下 EP 通信方式用 DeepEP 会导致梯度变成 nan，通信方式改成 allgather 就好了。组里大哥们一起尝试复现了一下，发现了几点：</p><ul><li>可以把网络规模缩小的很小，单机就可以复现</li><li>开启 CUDA_LAUNCH_BLOCKING=1 后问题消失</li><li>有时候添加打印或者同步，问题又不复现</li><li>EP 越小越容易复现</li></ul><p>顿时觉得头皮发麻，很显然是一种复杂的时序问题，不清楚是 megatron 的问题还是 DeepEP 的问题，也可能是在某种神秘场景，联合作用下导致的 bug。而 DeepEP 我又不怎么精通，感觉 debug 会非常困难。</p><h2 id="vibe-debug-流水账">vibe debug 流水账</h2><p>本周三临睡前，准备试一下用 cc (Claude Code) 去定位这个问题，因为我有个稳定的复现脚本（agent 可以直接观察他是不是 nan，信号很明确），各种源码也都有（agent 能拿到所有的上下文），理论上非常适合 vibe。把复现脚本、megatron 代码、DeepEP 代码都给 claude 准备好，不过我们还没什么项目 skills，cc 的领域知识只能从源码获取，只告诉 cc 这个脚本有 bug，去研究一下为什么。yolo 一开直接睡觉。</p><p>周四感觉 vibe 情况不是非常乐观，cc 给出了几个错误的猜想都被我否定了，就开始自暴自弃，说什么问题很难定位不出来。我准备换个思路，让他先定界是 megatron 的问题还是 DeepEP 的问题，尝试把这个网络的 DeepEP 相关业务逻辑抠出来，脱离 megatron 进行复现。</p><p>周四周五公司有活动，大部分时间都在坐飞机/吃饭/喝酒/吹牛逼/听别人吹牛逼，只能每隔几个小时去看一下 cc 的 debug 情况，我和同事戏称为“收菜”。但实际上 cc 的定位进度还是堪忧，脱离 megatron 也没法复现，而且特别容易放弃，只能反复让他尝试。</p><p>复现到了周五晚上也没什么进展，准备再换个思路，不尝试脱离 megatron 复现，而是让他简化 megatron，一步步的把没用的组件剥离出去，yolo 一开继续睡觉。</p><p>周六一早，发现很牛逼，直接和我说复现出来了，我自己一跑发现确实如此，脚本的现象是这样的：</p><ul><li>torch 开启 deterministic + DeepEP 有 nan</li><li>torch 关闭 deterministic + DeepEP 没有 nan</li><li>torch 开启 deterministic + DeepEP + CUDA_LAUNCH_BLOCKING 没有 nan</li></ul><p>需要 deterministic + DeepEP + cpu 异步同时满足，非常神秘啊。yolo 一开，继续让 cc 找一下问题根因。</p><p>周六吃喝玩乐一天，晚上回家发现已经搞定了，根因如下：</p><ul><li>同事为了对精度，开了 <code>torch.use_deterministic_algorithms(True)</code></li><li>开启确定性之后，torch 会自动把新分配出来的显存写成 nan，参考 <a href="https://docs.pytorch.org/docs/stable/deterministic.html">torch.utils.deterministic.fill_uninitialized_memory</a></li><li>而 DeepEP 的业务逻辑是这样的：</li></ul><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// Step 1: stream_wait synchronizes comm_stream with compute_stream</span></span><br><span class="line"><span class="comment">//         comm_stream will wait for everything *currently* enqueued on compute_stream</span></span><br><span class="line"><span class="built_in">stream_wait</span>(comm_stream, compute_stream);</span><br><span class="line"></span><br><span class="line"><span class="comment">// Step 2: torch::empty() allocates tensors AFTER the sync point</span></span><br><span class="line"><span class="comment">//         When fill_uninitialized_memory=True, this launches a NaN-fill kernel</span></span><br><span class="line"><span class="comment">//         on compute_stream — which comm_stream does NOT wait for</span></span><br><span class="line"><span class="keyword">auto</span> recv_x = torch::<span class="built_in">empty</span>(&#123;num_recv_tokens, hidden&#125;, x.<span class="built_in">options</span>());</span><br><span class="line"></span><br><span class="line"><span class="comment">// Step 3: communication kernel runs on comm_stream, writing to recv_x</span></span><br><span class="line">intranode::<span class="built_in">combine</span>(..., recv_x.<span class="built_in">data_ptr</span>(), ..., comm_stream, ...);</span><br><span class="line"></span><br></pre></td></tr></table></figure><ul><li>DeepEP 分配一块 empty 当输出，本来只有通信流在写数据，现在主流也会往里面写 nan，导致两个写操作并发了，制造了时序问题。这个问题的复现频率与 EP 执行速度、算子下发速度都有关系，所以看起来很复杂。</li><li>本质是 torch 框架在 deterministic 的场景下破坏了约定，在 empty 的时候有写操作。但很难说是 bug 还是 feature。</li></ul><h2 id="思考">思考</h2><p>claude debug 成功之后，我是非常震撼的，因为这不是一个很简单的 bug，涉及到很多功能的交叉。各位 ai infra 程序员可以扪心自问一下，直接面对这个问题，大家要花几个工作日能定位出来，定位过程中心态会不会出现问题。但现在，我只是在娱乐过程中随便 prompt 了几下，claude 就全搞定了。</p><p>当然整个 vibe debug 过程没有这么顺利，如果公司没有活动，是正常的工作日，我可能对 claude 也没有这么多的耐心。这说明我们还是需要 agent 工程、需要 skills。但模型的智能已经够了，我觉得我不算很菜的程序员，但 opus 4.6 比我要强。</p><p>如何把智能充分发挥出来只是个时间问题，tipping point 已经来临了。</p>]]></content>
    
    
      
      
    <summary type="html">&lt;blockquote&gt;
&lt;p&gt;原文链接: &lt;a href=&quot;https://zhuanlan.zhihu.com/p/2013987217081644893&quot;&gt;https://zhuanlan.zhihu.com/p/2013987217081644893&lt;/a&gt;&lt;/p&gt;
&lt;/</summary>
      
    
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/categories/zhihu/"/>
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/tags/zhihu/"/>
    
  </entry>
  
  <entry>
    <title>ThunderKittens 2.0(1): 如何优化 Ulysses</title>
    <link href="https://wyc-ruiker.github.io/2026/03/01/zhihu/ThunderKittens-2.0(1)_-%E5%A6%82%E4%BD%95%E4%BC%98%E5%8C%96-Ulysses/"/>
    <id>https://wyc-ruiker.github.io/2026/03/01/zhihu/ThunderKittens-2.0(1)_-%E5%A6%82%E4%BD%95%E4%BC%98%E5%8C%96-Ulysses/</id>
    <published>2026-02-28T16:00:00.000Z</published>
    <updated>2026-04-19T17:08:50.201Z</updated>
    
    <content type="html"><![CDATA[<blockquote><p>原文链接: <a href="https://zhuanlan.zhihu.com/p/2011494527026894039">https://zhuanlan.zhihu.com/p/2011494527026894039</a></p></blockquote><p>本文主要参考资料：</p><p><a href="https://hazyresearch.stanford.edu/blog/2025-09-22-pgl">One Kernel for All Your GPUs</a></p><p><a href="https://github.com/HazyResearch/ThunderKittens">https://github.com/HazyResearch/ThunderKittens</a></p><p>部分写作 powered by Kimi K2.5，学习过程 powered by Claude opus 4.6。</p><h2 id="ulysses-的基本思路">Ulysses 的基本思路</h2><p>Transformer 里除了 Attention 都是 element-wise 的操作，这些部分切序列长度 N 很方便。但 Attention 切 N 比较麻烦，切 head 才顺手。</p><p>Ulysses 的做法是在进出 Attention 的时候做"切 N"到"切 head"的转换，用 all-to-all 通信来做这个分布式转置。</p><p>实际 PyTorch 代码里，Ulysses 需要在通信前后做数据重排：</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 通信前（切的是 N， HEAD 是完整的）</span></span><br><span class="line">input_t = <span class="built_in">input</span>.view(B, N_per_rank, world_size, H_per_rank, D) \</span><br><span class="line">               .permute(<span class="number">2</span>, <span class="number">1</span>, <span class="number">0</span>, <span class="number">3</span>, <span class="number">4</span>).contiguous()   <span class="comment"># 一次完整拷贝</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 通信</span></span><br><span class="line">torch.distributed.all_to_all_single(output_t, input_t)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 通信后 （切的是 HEAD，N 是完整的）</span></span><br><span class="line">output.copy_(output_t.permute(<span class="number">2</span>, <span class="number">0</span>, <span class="number">1</span>, <span class="number">3</span>, <span class="number">4</span>)</span><br><span class="line">                      .reshape(B, N, H_per_rank, D))   <span class="comment"># 又一次拷贝</span></span><br></pre></td></tr></table></figure><p>all to all 要求输入必须连续，那两个 <code>contiguous()</code> 导致 all to all 的带宽根本没法跑满。用 <a href="https://github.com/HazyResearch/ThunderKittens/blob/main/kernels/parallel/all_to_all/benchmark.py">https://github.com/HazyResearch/ThunderKittens/blob/main/kernels/parallel/all_to_all/benchmark.py</a> 在 8p H20 上面跑，NCCL 的 Ulysses 方案只能跑到 155 GB/s，而不考虑拷贝，纯 all to all 能跑到 285 GB/s。</p><h2 id="parallelkittens">ParallelKittens</h2><p><a href="https://hazyresearch.stanford.edu/blog/2025-11-17-pk">ParallelKittens: Simple and Fast Multi-GPU AI Kernels</a></p><p>ParallelKittens 是 ThunderKittens 的多 GPU 扩展版本，也是 ThunderKittens 2.0 的一部分。通过 ParallelKittens 做 Ulysses 能在 8p H20 上做到 344 GB/s 的带宽。</p><h3 id="核心思路">核心思路</h3><p>PK 的做法很直接：kernel 内部通过坐标计算直接确定每个 tile 该去哪，不需要显式的数据重排。</p><figure class="highlight text"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">每个 GPU 跑同一个 kernel，每个 block 1 个线程，处理 1 个 tile：</span><br><span class="line"></span><br><span class="line">GPU_i 的 block_j:</span><br><span class="line">  1. TMA load:  本 GPU HBM → SMEM（硬件 DMA）</span><br><span class="line">  2. 索引计算:  确定目标 GPU 和目标位置</span><br><span class="line">  3. TMA store: SMEM → 目标 GPU HBM（通过 NVLink）</span><br></pre></td></tr></table></figure><p>kernel 也没几行</p><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">template</span> &lt;<span class="type">int</span> SCATTER_AXIS, <span class="type">int</span> GATHER_AXIS&gt;</span><br><span class="line"><span class="function">__device__ <span class="keyword">inline</span> <span class="type">void</span> <span class="title">kernel</span><span class="params">(<span class="type">const</span> globals &amp;G)</span> </span>&#123;</span><br><span class="line">    <span class="built_in">static_assert</span>(<span class="number">0</span> &lt;= SCATTER_AXIS &amp;&amp; SCATTER_AXIS &lt; <span class="number">4</span> &amp;&amp; <span class="number">0</span> &lt;= GATHER_AXIS &amp;&amp; GATHER_AXIS &lt; <span class="number">4</span>, </span><br><span class="line">        <span class="string">&quot;Scatter and gather axes must be 0, 1, 2, or 3&quot;</span>);</span><br><span class="line">    <span class="built_in">static_assert</span>(SCATTER_AXIS != GATHER_AXIS, <span class="string">&quot;Scatter and gather axes must be different&quot;</span>);</span><br><span class="line"></span><br><span class="line">    <span class="keyword">extern</span> __shared__ <span class="type">int</span> __shm[];</span><br><span class="line">    <span class="function">tma_swizzle_allocator <span class="title">allocator</span><span class="params">((<span class="type">int</span>*)&amp;__shm[<span class="number">0</span>])</span></span>;</span><br><span class="line">    globals::shared_tile &amp;tile = allocator.<span class="built_in">allocate</span>&lt;globals::shared_tile&gt;();</span><br><span class="line"></span><br><span class="line">    <span class="comment">// Calculate the input indices</span></span><br><span class="line">    <span class="type">int</span> task_idx = blockIdx.x;</span><br><span class="line">    <span class="type">int</span> batch_idx = task_idx / (G.input.<span class="built_in">depth</span>() * (G.input.<span class="built_in">rows</span>() / globals::ROW_BLOCK_SIZE) * (G.input.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE));</span><br><span class="line">    task_idx %= (G.input.<span class="built_in">depth</span>() * (G.input.<span class="built_in">rows</span>() / globals::ROW_BLOCK_SIZE) * (G.input.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE));</span><br><span class="line">    <span class="type">int</span> depth_idx = task_idx / (G.input.<span class="built_in">rows</span>() / globals::ROW_BLOCK_SIZE * (G.input.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE));</span><br><span class="line">    task_idx %= (G.input.<span class="built_in">rows</span>() / globals::ROW_BLOCK_SIZE * (G.input.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE));</span><br><span class="line">    <span class="type">int</span> row_block_idx = task_idx / (G.input.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE);</span><br><span class="line">    task_idx %= (G.input.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE);</span><br><span class="line">    <span class="type">int</span> col_block_idx = task_idx;</span><br><span class="line"></span><br><span class="line">    <span class="comment">// Load input data (assume a single-threaded block)</span></span><br><span class="line">    __shared__ semaphore arrived;</span><br><span class="line">    <span class="built_in">init_semaphore</span>(arrived, <span class="number">0</span>, <span class="number">1</span>);</span><br><span class="line">    tma::<span class="built_in">expect_bytes</span>(arrived, <span class="built_in">sizeof</span>(tile));</span><br><span class="line">    tma::<span class="built_in">load_async</span>(tile, G.input[G.dev_idx], &#123;batch_idx, depth_idx, row_block_idx, col_block_idx&#125;, arrived);</span><br><span class="line"></span><br><span class="line">    <span class="comment">// Calculate the output indices</span></span><br><span class="line">    <span class="type">int</span> dst_dev_idx;</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">if</span> <span class="title">constexpr</span> <span class="params">(SCATTER_AXIS == <span class="number">0</span>)</span> </span>&#123;</span><br><span class="line">        dst_dev_idx = batch_idx / G.output.<span class="built_in">batch</span>();</span><br><span class="line">        batch_idx %= G.output.<span class="built_in">batch</span>();</span><br><span class="line">    &#125; <span class="keyword">else</span> <span class="keyword">if</span> <span class="built_in">constexpr</span> (SCATTER_AXIS == <span class="number">1</span>) &#123;</span><br><span class="line">        dst_dev_idx = depth_idx / G.output.<span class="built_in">depth</span>();</span><br><span class="line">        depth_idx %= G.output.<span class="built_in">depth</span>();</span><br><span class="line">    &#125; <span class="keyword">else</span> <span class="keyword">if</span> <span class="built_in">constexpr</span> (SCATTER_AXIS == <span class="number">2</span>) &#123;</span><br><span class="line">        dst_dev_idx = row_block_idx / (G.output.<span class="built_in">rows</span>() / globals::ROW_BLOCK_SIZE);</span><br><span class="line">        row_block_idx %= (G.output.<span class="built_in">rows</span>() / globals::ROW_BLOCK_SIZE);</span><br><span class="line">    &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">        dst_dev_idx = col_block_idx / (G.output.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE);</span><br><span class="line">        col_block_idx %= (G.output.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE);</span><br><span class="line">    &#125;</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">if</span> <span class="title">constexpr</span> <span class="params">(GATHER_AXIS == <span class="number">0</span>)</span> </span>&#123;</span><br><span class="line">        batch_idx += G.input.<span class="built_in">batch</span>() * G.dev_idx;</span><br><span class="line">    &#125; <span class="keyword">else</span> <span class="keyword">if</span> <span class="built_in">constexpr</span> (GATHER_AXIS == <span class="number">1</span>) &#123;</span><br><span class="line">        depth_idx += G.input.<span class="built_in">depth</span>() * G.dev_idx;</span><br><span class="line">    &#125; <span class="keyword">else</span> <span class="keyword">if</span> <span class="built_in">constexpr</span> (GATHER_AXIS == <span class="number">2</span>) &#123;</span><br><span class="line">        row_block_idx += (G.input.<span class="built_in">rows</span>() / globals::ROW_BLOCK_SIZE) * G.dev_idx;</span><br><span class="line">    &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">        col_block_idx += (G.input.<span class="built_in">cols</span>() / globals::COL_BLOCK_SIZE) * G.dev_idx;</span><br><span class="line">    &#125;</span><br><span class="line"></span><br><span class="line">    <span class="comment">// Wait for inputs to arrive and store data to destination device</span></span><br><span class="line">    <span class="built_in">wait</span>(arrived, <span class="number">0</span>);</span><br><span class="line">    tma::<span class="built_in">store_async</span>(G.output[dst_dev_idx], tile, </span><br><span class="line">        &#123;batch_idx, depth_idx, row_block_idx, col_block_idx&#125;);</span><br><span class="line">&#125;</span><br><span class="line"></span><br></pre></td></tr></table></figure><h3 id="怎么做到的">怎么做到的</h3><p><strong>1. 跨 GPU 内存访问（IPC + VMM）</strong></p><p>要让 GPU 直接写别的 GPU 显存，先用 CUDA 的 Virtual Memory Management API：</p><figure class="highlight text"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">1. VMM 分配 (cuMemCreate + cuMemMap + cuMemSetAccess)</span><br><span class="line">2. 交换 IPC handle（Unix Domain Socket）</span><br><span class="line">3. 映射远端内存（cuMemImportFromShareableHandle + cuMemMap）</span><br></pre></td></tr></table></figure><p>最后每个 GPU 都能通过本地指针访问其他 7 个 GPU 的显存。博客里面有个图：</p><p><img src="/zhihu/images/ThunderKittens-2.0(1)_-如何优化-Ulysses_img_0.jpg"></p><p><strong>2. PGL（Parallel Global Layout）</strong></p><p>TK 搞了个 <code>pgl</code> 结构管理跨 GPU tensor：</p><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">template</span>&lt;GL, NUM_DEVICES=<span class="number">8</span>&gt;</span><br><span class="line"><span class="keyword">struct</span> pgl &#123;</span><br><span class="line">    GL gls[NUM_DEVICES];  <span class="comment">// 8 个 GL，每个包含 raw_ptr 和 TMA 描述符</span></span><br><span class="line">&#125;;</span><br><span class="line"></span><br></pre></td></tr></table></figure><p><code>G.output[dst_dev_idx]</code> 返回目标 GPU 的 GL，其 TMA 描述符指向远端显存。</p><p><strong>3. Kernel 逻辑</strong></p><figure class="highlight text"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line">__device__ void kernel(const globals &amp;G) &#123;</span><br><span class="line">    // 从 blockIdx.x 解码 4D 坐标</span><br><span class="line">    (batch_idx, depth_idx, row_block_idx, col_block_idx) ← blockIdx.x</span><br><span class="line"></span><br><span class="line">    // TMA load: 本 GPU HBM → SMEM</span><br><span class="line">    tma::load_async(tile, G.input[G.dev_idx], coords, semaphore);</span><br><span class="line"></span><br><span class="line">    // 计算目标 GPU 和目标坐标</span><br><span class="line">    dst_dev_idx = ...;</span><br><span class="line">    output_coords = ...;</span><br><span class="line"></span><br><span class="line">    // TMA store 到远端</span><br><span class="line">    wait(semaphore, 0);</span><br><span class="line">    tma::store_async(G.output[dst_dev_idx], tile, output_coords);</span><br><span class="line">&#125;</span><br><span class="line"></span><br></pre></td></tr></table></figure><p>关键就是<strong>索引重映射替代了显式 permute</strong>。kernel 里算好坐标，数据直接搬到目标位置。</p><p><strong>4. TMA（Tensor Memory Accelerator）</strong></p><p>用的 PTX 指令：</p><figure class="highlight text"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"># Load: HBM → SMEM（异步，硬件 DMA）</span><br><span class="line">cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes</span><br><span class="line"></span><br><span class="line"># Store: SMEM → HBM（异步，可写远端）</span><br><span class="line">cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group</span><br></pre></td></tr></table></figure><p>TMA 是硬件 DMA，不占用 SM 计算资源，1 个线程就能驱动。目标地址是 IPC 映射的远端指针时，TMA 自动走 NVLink 传输。</p><h2 id="一些限制">一些限制</h2><p>主要限制在于跨 GPU 内存访问，只能在一个 NVLink 域内通信。</p><p>还有一点是对于超节点 （例如 GB200 NVL72， 一个容器只能拿到四个卡），没法通过单机进程通信交换 handle。不过 NV 对此早有考虑 ，CUDA 12.4+ 有 <code>CU_MEM_HANDLE_TYPE_FABRIC</code>，配合 nvidia-imex daemon 可以跨节点交换 handle。只是 TK 还没支持这个场景，能拿到超节点的大哥可以试试看效果。</p>]]></content>
    
    
      
      
    <summary type="html">&lt;blockquote&gt;
&lt;p&gt;原文链接: &lt;a href=&quot;https://zhuanlan.zhihu.com/p/2011494527026894039&quot;&gt;https://zhuanlan.zhihu.com/p/2011494527026894039&lt;/a&gt;&lt;/p&gt;
&lt;/</summary>
      
    
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/categories/zhihu/"/>
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/tags/zhihu/"/>
    
  </entry>
  
  <entry>
    <title>深度学习框架中的虚拟显存/vllm sleep mode</title>
    <link href="https://wyc-ruiker.github.io/2025/06/21/zhihu/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E6%A1%86%E6%9E%B6%E4%B8%AD%E7%9A%84%E8%99%9A%E6%8B%9F%E6%98%BE%E5%AD%98_vllm-sleep-mode/"/>
    <id>https://wyc-ruiker.github.io/2025/06/21/zhihu/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E6%A1%86%E6%9E%B6%E4%B8%AD%E7%9A%84%E8%99%9A%E6%8B%9F%E6%98%BE%E5%AD%98_vllm-sleep-mode/</id>
    <published>2025-06-20T16:00:00.000Z</published>
    <updated>2026-04-19T17:00:43.062Z</updated>
    
    <content type="html"><![CDATA[<blockquote><p>原文链接: <a href="https://zhuanlan.zhihu.com/p/1919814348047623458">https://zhuanlan.zhihu.com/p/1919814348047623458</a></p></blockquote><h3 id="深度学习框架中的虚拟显存">深度学习框架中的虚拟显存</h3><p><a href="https://developer.nvidia.com/blog/introducing-low-level-gpu-virtual-memory-management/">CUDA</a>和<a href="https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/apiref/appdevgapi/aclcppdevg_03_0114.html">AscendCL</a>都支持虚拟内存管理。两个硬件的API也基本是对齐的，以AscendCL为例，大概需要这么几个API：</p><p><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_0.jpg"></p><p>CUDA和AscendCL的对应关系如下图：</p><p><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_1.jpg"></p><p>核心就是解耦了虚拟内存和实际的物理内存，实际的物理内存被搞成handle，可以map到提前预留好的虚拟地址，也可以从现成的虚拟地址上面unmap掉。有了这些接口，就可以很方便的实现碎片整理等功能。</p><p>具体而言，例如这样的场景，我们需要申请一个连续的2G的显存，但是显存池中只有两块不连续的1G的显存，在没有虚拟内存这一套接口的时候，只能抛出来OOM（Out Of Memory）这样的报错。</p><figure><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_2.jpg" alt="OOM"><figcaption aria-hidden="true">OOM</figcaption></figure><p>但有了虚拟内存这一套接口，我们就可以让两个1G的虚拟地址对应的handle从这两个虚拟地址上面unmap掉，然后map到一个连续的2G的虚拟地址。</p><figure><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_3.jpg" alt="碎片整理"><figcaption aria-hidden="true">碎片整理</figcaption></figure><p>这就是碎片整理，而且不需要框架层做类似传统碎片整理那样的搬运操作。torch里面的expandable_segments，mindspore的VMM，都是这个原理。缺点在于unmap和map这个过程有点慢，而且需要做流同步，需要打断device和host的流水。</p><p>除了碎片整理，其实虚拟内存在显存管理上还有很多帮助。例如现代的深度学习框架，为了方便多个进程使用同一张卡，显存都是慢慢拓展出来的，那最后的显存就会是这种形式：</p><p><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_4.jpg"></p><p>很显然，不论是cudaMalloc还是aclrtMalloc，都没法保证多次malloc出来的显存是连续的，这就会因为显存的不连续导致碎片。</p><p>而虚拟内存就能做到，提前预留好一大块物理地址，不断的申请handle map到当前的物理地址上，这样不管怎么拓展显存，都是一个连续的大块显存。</p><p><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_5.jpg"></p><p>当然如果有多池子的显存优化（大小显存池、param单独分池子），虚拟显存也可以做到用多少占用多少，不需要提前约定好每个池子占用的大小。而且经过一次碎片整理，不同池子的显存互相复用也会比较方便。（只要不在乎性能）</p><p>对于深度学习框架来说，虚拟内存比较好用的点就是这些。其实虚拟内存还有个很大的好处是他可以在虚拟地址（或者叫上层感知到的地址）完全不变的情况下，在底层做一些特殊操作。例如<a href="https://www.hiascend.com/graph-engine">Graph Engine</a>和<a href="https://developer.nvidia.com/blog/cuda-graphs/">CUDA Graph</a>这种硬件底层的图调度对于地址使用都有些不可变的约束，如果能够结合虚拟内存，就能玩一些花活。但因为map/unmap这种操作确实代价有点大，一直没有想到一个比较好的整活场景。</p><h3 id="vllm-sleep-mode">vllm sleep mode</h3><p>最近研究了一下<a href="https://github.com/vllm-project/vllm/pull/11743">vllm sleep mode</a>的实现，发现他通过虚拟内存的特性，在底层深度学习框架不感知的情况下完成了kvcache/weights的卸载/加载功能。</p><p>sleep mode主要是在强化学习共部署的场景下，做完推理之后，得把推理的显存都释放出来给训练用。在没有这个功能的时候，适配强化学习共部署，想卸载kvcache，都是通过<a href="https://objgraph.readthedocs.io/en/stable/">objgraph</a>等方式找到所有挂kvcache的python对象，然后一个个把持有kvcache的python对象都干掉，让python对象析构去触发框架显存的释放。跑推理之前还要一个个恢复回来，维护成本很高，每次升级vllm版本还要重新来。</p><p>vllm实现sleep mode这个功能，首先是在csrc/cumem_allocator.cpp里面实现了一个简易的显存申请释放的接口，但都是通过虚拟内存的方式进行的。也就是说cumem_allocator里面每申请一块显存，都会包含虚拟地址和底层对应的handle两个信息。这些信息会通过callback记录到python层。</p><p><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_6.jpg"></p><p>然后在sleep的时候，就申请一个cpu的tensor，把显存里面的内容都拷贝过去。虚拟地址不变，但是虚拟地址下面的handle都unmap并且释放掉。</p><figure><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_7.jpg" alt="sleep（卸载）流程"><figcaption aria-hidden="true">sleep（卸载）流程</figcaption></figure><p>同理wake up的时候，就申请新的handle map到虚拟地址上，然后再把cpu里面的内容拷贝上来。</p><figure><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_8.jpg" alt="wake up（恢复）流程"><figcaption aria-hidden="true">wake up（恢复）流程</figcaption></figure><p>那怎么让torch的显存申请走到cumem_allocator里面呢？vllm的大佬利用了<a href="https://docs.pytorch.org/docs/stable/notes/cuda.html#using-custom-memory-allocators-for-cuda">torch外挂显存池</a>的能力，如果在cumem定义的context中，就会自动走到对应cumem_allocator的显存申请流程：</p><figure><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_9.jpg" alt="外置显存池context"><figcaption aria-hidden="true">外置显存池context</figcaption></figure><p>这样只要做好context管理就行了，总比管理python对象的引用关系方便很多：</p><figure><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_10.jpg" alt="使用方法"><figcaption aria-hidden="true">使用方法</figcaption></figure><p>下面来分析，为什么这个东西非常的巧妙。先考虑我如果想实现这个功能会怎么搞？首先做好python对象的管理，然后基于python对象析构去触发显存释放是比较恶心的，也很容易出bug；还有一种方式是能不能把需要释放的tensor都调用一下to('cpu')？但pytorch的tensor to语义<a href="https://stackoverflow.com/questions/53570334/documentation-for-pytorch-tocpu-or-tocuda">不是inplace的</a>。如果想要释放显存需要类似a=a.to('cpu')这样的书写方式，那就和第一种做好python对象的管理一模一样了，没有解决根本的问题。</p><p>那我们自己管理显存呢？虽然torch做了外挂显存池的功能，但malloc和free接口都是torch框架来触发的。tensor.data_ptr()可以获取tensor的物理地址，但外部不能修改他。这也很好理解，如果外部能随便改data_ptr，框架的显存管理要怎么做呢？</p><p>这时候，虚拟内存的好处就来了。刚刚说虚拟内存还有个很大的好处是他可以在虚拟地址（或者叫上层感知到的地址）完全不变的情况下，在底层做一些特殊操作。vllm sleep mode通过在torch框架完全不感知的情况下，完成了对handle的卸载/加载。torch还以为自己的显存管理的好好的，殊不知在底层已经被偷偷的unmap掉了。vllm社区的大佬水平确实高，这个方案完美结合了torch/虚拟内存的底层各种机制，做到了vllm和torch的解耦。</p><p>当然，这套方案也不是完全没缺点，一个主要的缺点就是，因为是对torch框架的隐瞒，torch只要在sleep之后做了任何对这些tensor的读写操作，都会直接触发底层的ERROR，因为虚拟地址并没有映射到对应的物理地址。python的灵活性这么高，这个报错或者校验在vllm层是没法做的。</p><p>一个可能的完美解决方案是让torch框架提供inplace to/原地offload这样的语义，这样也不需要做任何python对象的管理，推理框架适配起来也很轻松，报错也能做的非常清晰。问题就是这个语义可能对框架的冲击比较大，对于成图更是重量级：）</p><p>从代码注释上看，这个方案也踩中了torch框架的一堆bug，蛮不容易的。比较令人在意的是这个：</p><p><img src="/zhihu/images/深度学习框架中的虚拟显存_vllm-sleep-mode_img_11.jpg"></p><p>用这个就不能用torch的expandable_segments了，而强化学习场景的load/offload又特别频繁，碎片会比较严重，长稳训练可能会出现一些问题。</p>]]></content>
    
    
      
      
    <summary type="html">&lt;blockquote&gt;
&lt;p&gt;原文链接: &lt;a href=&quot;https://zhuanlan.zhihu.com/p/1919814348047623458&quot;&gt;https://zhuanlan.zhihu.com/p/1919814348047623458&lt;/a&gt;&lt;/p&gt;
&lt;/</summary>
      
    
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/categories/zhihu/"/>
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/tags/zhihu/"/>
    
  </entry>
  
  <entry>
    <title>LLM RL入门</title>
    <link href="https://wyc-ruiker.github.io/2025/06/10/zhihu/LLM-RL%E5%85%A5%E9%97%A8/"/>
    <id>https://wyc-ruiker.github.io/2025/06/10/zhihu/LLM-RL%E5%85%A5%E9%97%A8/</id>
    <published>2025-06-09T16:00:00.000Z</published>
    <updated>2026-04-19T17:00:43.061Z</updated>
    
    <content type="html"><![CDATA[<blockquote><p>原文链接: <a href="https://zhuanlan.zhihu.com/p/27172237359">https://zhuanlan.zhihu.com/p/27172237359</a></p></blockquote><p>工作需要，最近在入门强化学习。这个笔记面向工程老哥们，知道老哥们看不懂数学（因为我就不懂），忽略了大量推导，意会为主。其中可能有些理解错误，望大家多多指正。</p><h2 id="llm中的强化学习定义">LLM中的强化学习定义</h2><p><img src="/zhihu/images/LLM-RL入门_img_0.jpg"></p><ul><li>智能体（Agent）: 经过预训练和SFT的LLM，可以对外界的输入做出回应，有一定的智能。</li><li>环境（Environment）: 其实是整个物理世界。如果是写算法题的话那就是leetcode平台，如果是RLHF的话，就是人类本身。</li><li>状态（State）: 环境的状态，这个东西在LLM RL里面有点模糊，不像游戏AI那么清晰；可以认为LLM之前的所有输出作为环境的状态吗？</li><li>动作（Action) : 对于LLM来说，就是decode过程中输出的每一个token。</li><li>奖励（Reward）: 就是环境的反馈，比如人类喜不喜欢这个回答或者生成的代码能不能通过leetcode的评测，这里是搞头比较大的地方。</li></ul><h2 id="强化学习的基础做法">强化学习的基础做法</h2><p><img src="/zhihu/images/LLM-RL入门_img_1.jpg"></p><p>基础的强化学习分成两个角色，actor <span class="math inline">\(\pi\)</span> 和critic <span class="math inline">\(V_{\pi}\)</span> 。其中actor代表的就是agent，基于环境的状态和奖励能够做出决策；critic类似教师，基于当前的环境状态和actor的水平，判断一下这个状态的价值。有一点很关键，critic是『因材施教』的，critic的输出一定是针对当前的actor而言的。</p><p>既然有两个概念，根据二进制，就有三种对应的做法：</p><ol type="1"><li><p>只有critic的实体。如果critic能学的很好的话，很明显，actor只需要看看他的每一个决策critic的反应是什么，选critic反应最好的决策就行（假设action是离散的）。这个方法就叫做value-based，比较典型的是Q-learning。</p></li><li><p>只有actor的实体。人教人教不会，事教人一下就会。有时候不需要critic，actor只需要和环境做互动，从环境拿反馈就行。这种方法叫做policy-based，就是一般的policy gradient方法。</p></li><li><p>两个实体都有，两个模型可以互相迭代。critic的学习可能是有偏的，连续的action不好处理（value-based），而直接从环境拿反馈可能方差太大，训练不稳定（policy-based）。所以两个模型互相迭代的方法被广为采用，叫做actor-critic。</p></li></ol><h2 id="actor-critic">actor-critic</h2><h3 id="数据来源">数据来源</h3><p>RL是在LLM的后训练环节，数据是从actor <span class="math inline">\(\pi\)</span> 采样出来的。因为actor就是要训练的LLM本身，所以这个过程就是给一些prompt，decode出一些回答，基于回答的反馈去更新LLM。</p><h3 id="actor的优化">actor的优化</h3><p>我们先用SFT来理解。如果我们把所有的输入用来做SFT，那所有样本对于LLM都是要学习的。 先写个SFT的梯度：</p><p><span class="math display">\[SFT_{\nabla}=\frac{1}{N}\sum_{i=1}^{n}\sum_{t=1}^{T_n-1}\nabla_{\theta}\pi_{\theta}(o_t|q,o_{&lt;t})\]</span></p><p>其中 nn 代表seq数量， <span class="math inline">\(T_n\)</span> 代表seq长度。这个就是很经典的预测next token的loss。 那强化学习里面要怎么优化actor呢？很简单，就是在每个梯度前面乘上一个reward。</p><p><span class="math display">\[Actor_{\nabla}=\frac{1}{N}\sum_{i=1}^{n}\sum_{t=1}^{T_n-1}R(\tau_n)\nabla_{\theta}\pi_{\theta}(o_t|q,o_{&lt;t})\]</span></p><p>其中 <span class="math inline">\(R(\tau_n)\)</span> 代表的是每个seq对应的reward，这个东西怎么算先按下不表。感性理解就是reward越大，我越希望拟合这个seq；reward是负的，就需要极力避免这个seq。 写成强化学习的符号，就是：</p><p><span class="math display">\[Actor_{\nabla}=\frac{1}{N}\sum_{i=1}^{n}\sum_{t=1}^{T_n-1}R(\tau_n)\nabla_{\theta}\pi_{\theta}(a_t|s_t)\]</span></p><p>其中 <span class="math inline">\(a_t\)</span> 代表action， <span class="math inline">\(s_t\)</span> 代表状态。就是基于此，我猜想LLM RL中的状态就是之前输出的token，但这个定义其实很古怪，随着交互环境的复杂，后面应该会有变化。</p><p>更加泛化一点， <span class="math inline">\(R(\tau_n)\)</span> 可以定义为 <span class="math inline">\(A_{\pi}(s_t, a_t)\)</span> ，因为这个奖励和actor本身、当前状态、当前动作都有关系。在很多资料中，这个也被叫做优势函数。</p><p><span class="math display">\[Actor_{\nabla}=\frac{1}{N}\sum_{i=1}^{n}\sum_{t=1}^{T_n-1}A_{\pi}(s_t, a_t)\nabla_{\theta}\pi_{\theta}(a_t|s_t)\]</span></p><h3 id="优势函数">优势函数</h3><p><span class="math inline">\(A_{\pi}(s_t, a_t)\)</span> 怎么算？这个推导涉及到Q-learning，虽然这个推导不算难，但看完感觉工程同学根本没啥学的必要。只需要知道他是通过critic计算出来的：</p><p><span class="math display">\[A_{\pi}(s_t, a_t)=r_t+\gamma V_{\pi}(s_{t+1})-V_{\pi}(s_t)\]</span></p><p>其中 <span class="math inline">\(r_t\)</span> 是当前步骤的反馈，一般是通过reward model计算得来； <span class="math inline">\(V_{\pi}\)</span> 就是critic模型； <span class="math inline">\(s_{t+1}\)</span> 是从 <span class="math inline">\(s_t\)</span> 基于 <span class="math inline">\(a_t\)</span> 这个东西转移过来的。这个公式其实很直观，我们想一下，怎么让actor满足crtiic的需求？当然就是找到一个 <span class="math inline">\(a_t\)</span> ，使得转移过去的 <span class="math inline">\(s_{t+1}\)</span> 尽可能的好。</p><h3 id="critic的优化">critic的优化</h3><p>接上面的逻辑，对于一个完美的 <span class="math inline">\(V_{\pi}\)</span> ，他的递推式子就应该是（但和上面的东西结合起来有点怪异，actor需要讨好critic，但是critic又要尽量客观）：</p><p><span class="math display">\[V_{\pi}(s_t)=r_t+\gamma V_{\pi}(s_{t+1})\]</span></p><p>为什么有个 <span class="math inline">\(\gamma\)</span> ，因为后续状态对前面的影响应该是越来越小的，这个也比较直观。 那critic的loss就天然应该是一个MSE：</p><p><span class="math display">\[Critic_{loss}=(r_t+\gamma V_{\pi}(s_{t+1})-V_{\pi}(s_t))^2\]</span></p><h2 id="ppo">PPO</h2><p>actor-critic的主要问题是，每次需要actor <span class="math inline">\(\pi\)</span> 去做采样，然后再回头更新出新的 <span class="math inline">\(\pi^*\)</span> ，再用 <span class="math inline">\(\pi^*\)</span> 去采样，循环往复。这个方法叫做on-policy。</p><p>采样后要计算reward function、要计算ref model，这些都挺慢的。很容易就想到能不能采一次样，后面多迭代几轮，利用好之前的采样信息。这种方法就叫做off-policy。</p><h3 id="重要性采样">重要性采样</h3><p>off-policy的核心问题是，之前采样的 <span class="math inline">\(\pi\)</span> 和训练的 <span class="math inline">\(\pi^*\)</span> 分布是不同的，数学原理上难以保证训练的有效性。从某个分布采样的数据对另一个分布做训练，有个解决方法叫做重要性采样。数学原理很多地方都有，这里就不加赘述了，直接给出带重要性采样的更新公式，只是给之前的actor梯度加了一项：</p><p><span class="math display">\[Actor_{\nabla}=\frac{1}{N}\sum_{i=1}^{n}\sum_{t=1}^{T_n-1}\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta old}(a_t|s_t)}A_{\pi}(s_t, a_t)\nabla_{\theta}\pi_{\theta}(a_t|s_t)\]</span></p><h3 id="gae">GAE</h3><p>再回头看优势函数，我们之前的优势函数 <span class="math inline">\(A_{\pi}(s_t, a_t)\)</span> 是完全依赖 <span class="math inline">\(V_{\pi}\)</span> 的，但训练的初期 <span class="math inline">\(V_{\pi}\)</span> 一定是非常不准的。 <span class="math inline">\(A^{GAE}_{\pi}(s_t, a_t)\)</span> 就是平衡了 <span class="math inline">\(V_{\pi}\)</span> 的重要性和之后所有 <span class="math inline">\(r_t\)</span> 的重要性：</p><p><span class="math display">\[A^{GAE}_{\pi}(s_t, a_t)=\sum_{l=0}^{\infty}(\gamma\lambda)^l\delta_{t+l}\]</span></p><p>其中：</p><p><span class="math display">\[\delta_{t}=r_t+\gamma V_{\pi}(s_{t+1})-V_{\pi}(s_t)\]</span></p><p>可以很容易看出来 <span class="math inline">\(\lambda=0\)</span> 就是原来的 <span class="math inline">\(A_{\pi}(s_t, a_t)\)</span> 。如果 <span class="math inline">\(\lambda=1\)</span> ，那就是 <span class="math inline">\(-V_{\pi}(s_t)+\sum_{l=0}^{\infty}\gamma^lr_{t+l}\)</span> （可以展开几项写一写）。 <span class="math inline">\(\lambda\)</span> 的取值代表有多少后续的 <span class="math inline">\(r_{t+l}\)</span> 对优势函数产生了影响。</p><h3 id="裁切设计">裁切设计</h3><p>上面的两个内容都不难理解，结合起来，actor的梯度应该是：</p><p><span class="math display">\[Actor_{\nabla}=\frac{1}{N}\sum_{i=1}^{n}\sum_{t=1}^{T_n-1}\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta old}(a_t|s_t)}A^{GAE}_{\pi}(s_t, a_t)\nabla_{\theta}\pi_{\theta}(a_t|s_t)\]</span></p><p>但通常来讲，PPO的loss形式都看起来很复杂，这是因为在更新的时候 <span class="math inline">\(\pi_{\theta}\)</span> 和 <span class="math inline">\(\pi_{\theta old}\)</span> 的差距不能太大，需要各种方式去限制两个分布的差异。效果最好的是对梯度进行裁切，如果两个分布偏差过大的话，直接加个裁切，不让他们更新的太远：</p><p><span class="math display">\[Actor_{\nabla}=\frac{1}{N}\sum_{i=1}^{n}\sum_{t=1}^{T_n-1}min[\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta old}(a_t|s_t)}A^{GAE}_{\pi}(s_t, a_t), clip(\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta old}(a_t|s_t)}, 1 - \epsilon, 1 + \epsilon)A^{GAE}_{\pi}(s_t, a_t)]\nabla_{\theta}\pi_{\theta}(a_t|s_t)\]</span></p><p>对于这个裁切的理解，下面的参考资料里面（李宏毅老师公开课和猛猿大神的文章）有基于折线图的很出色的解释，可以参考一下。</p><p>上面说critic是『因材施教』的，critic的输出一定是针对当前的actor而言的，所以在actor更新的过程中，critic也要跟着更新。回忆一下之前的critic loss：</p><p><span class="math display">\[Critic_{loss}=(r_t+\gamma V_{\pi}(s_{t+1})-V_{\pi}(s_t))^2\]</span></p><p>在PPO中，给这地方增加了一项：</p><p><span class="math display">\[Critic_{loss}=(r_t+\gamma V_{\pi}(s_{t+1})+\gamma\lambda A^{GAE}_{t+1}-V_{\pi}(s_t))^2\]</span></p><p>我没怎么明白为什么，可能是为了增大 <span class="math inline">\(r_t\)</span> 的影响力。左边的 <span class="math inline">\(r_t+\gamma V_{\pi}(s_{t+1})+\gamma\lambda A^{GAE}_{t+1}\)</span> 我们简写为 <span class="math inline">\(R_t\)</span> ，这里 <span class="math inline">\(R_t\)</span> 都是提前算好的，所以用的 <span class="math inline">\(V_{\pi}\)</span> 都是 <span class="math inline">\(V_{\pi}^{old}\)</span> 。类似的，PPO也不希望 <span class="math inline">\(V_{\pi}^{new}\)</span> 和 <span class="math inline">\(V_{\pi}^{old}\)</span> 离得太远，所以最后也加了个裁切：</p><p><span class="math display">\[Critic_{loss}=max([(V^{new}_{\pi}-R_t)^2, (clip(V^{new}_{\pi}, V^{old}_{\pi}-\epsilon, V^{old}_{\pi}+\epsilon)-R_t)^2])\]</span></p><h2 id="ref-model">ref model</h2><p>在LLM RL中，ref model是一个有点特殊的设计，之前RL提的不怎么多。主要目的是不让actor跑的太偏，依然保留一些之前预训练获得的能力。ref model和actor的初始化权重相同，一般训练很多步再更新一下权重或者根本不更新。在计算actor loss的时候，增加一项（比如KL散度），不让ref model和actor偏离的太远。</p><h2 id="grpo">GRPO</h2><p>GRPO因为deepseek成为了现在最火爆的强化学习算法。AIQL大神的回答里有个神图：</p><p><img src="/zhihu/images/LLM-RL入门_img_2.jpg"></p><p>GRPO最核心的点有两个：</p><ol type="1"><li><p>干掉了critic model，直接通过一条prompt rollout一堆回复来采样数据集。通过增大采样来抵消方差。</p></li><li><p>reward model也做了改动。对于很多问题，token level的奖励是不够合理的，DeepSeekMath引入了过程监督。DeepSeek-r1更是直接改成了rule-base。</p></li></ol><p>GRPO是对PPO算法的简化。之前的PPO训练模型中有actor推理、actor训练、critic、ref、reward 5个模型（因为训练和推理需要用不同的框架来加速，策略一般也不同）。GRPO一下子把critic和reward全干掉了，大力出奇迹。其实很多人在复现GRPO的过程中表示和ref model的对齐会影响效果，去掉ref model反而出来深度思考的过程了。那是不是可以幻想一下，以后只需要actor推理、actor训练，actor就和环境做互动就行了？大道至简。</p><p>下一篇文章会分析一下当前业界比较优秀的RL框架（openrlhf/verl/chatlearn），结合代码再深入理解一下LLM RL的细节。</p><h2 id="参考">参考</h2><ol type="1"><li><p><a href="https://www.bilibili.com/video/BV1ou411874G/">B站首推！李宏毅大佬花一周讲完！2023公认最通俗易懂的【强化学习教程】小白也能信手拈来（人工智能|机器学习|深度学习|强化学习）_哔哩哔哩_bilibili</a></p></li><li><p><a href="https://arxiv.org/pdf/2402.03300">[2402.03300] DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models</a></p></li><li><p><a href="https://zhuanlan.zhihu.com/p/13467768873">【LLM】PPO理论推导+代码实战 - 知乎</a></p></li><li><p><a href="https://zhuanlan.zhihu.com/p/15677409107">【LLM】GRPO：改进PPO增强推理能力 - 知乎</a></p></li><li><p><a href="https://zhuanlan.zhihu.com/p/7461863937">人人都能看懂的RL-PPO理论知识 - 知乎</a></p></li><li><p><a href="https://www.zhihu.com/question/10766825126/answer/88583863333">DeepSeek的GRPO算法是什么？ - 知乎</a></p></li><li><p><a href="https://zhuanlan.zhihu.com/p/25067791857">GRPO简化Tricks, 性能暴涨10%, 只改一个参数? - 知乎</a></p></li></ol>]]></content>
    
    
      
      
    <summary type="html">&lt;blockquote&gt;
&lt;p&gt;原文链接: &lt;a href=&quot;https://zhuanlan.zhihu.com/p/27172237359&quot;&gt;https://zhuanlan.zhihu.com/p/27172237359&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;工</summary>
      
    
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/categories/zhihu/"/>
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/tags/zhihu/"/>
    
  </entry>
  
  <entry>
    <title>[Attention]FlashAttention/Ring-Attention/混合序列并行的统一原理</title>
    <link href="https://wyc-ruiker.github.io/2025/03/23/zhihu/[Attention]FlashAttention_Ring-Attention_%E6%B7%B7%E5%90%88%E5%BA%8F%E5%88%97%E5%B9%B6%E8%A1%8C%E7%9A%84%E7%BB%9F%E4%B8%80%E5%8E%9F%E7%90%86/"/>
    <id>https://wyc-ruiker.github.io/2025/03/23/zhihu/[Attention]FlashAttention_Ring-Attention_%E6%B7%B7%E5%90%88%E5%BA%8F%E5%88%97%E5%B9%B6%E8%A1%8C%E7%9A%84%E7%BB%9F%E4%B8%80%E5%8E%9F%E7%90%86/</id>
    <published>2025-03-22T16:00:00.000Z</published>
    <updated>2026-04-19T17:00:43.062Z</updated>
    
    <content type="html"><![CDATA[<blockquote><p>原文链接: <a href="https://zhuanlan.zhihu.com/p/1887098218866794901">https://zhuanlan.zhihu.com/p/1887098218866794901</a></p></blockquote><p>个人学习记录。这三个东西本质的原理都差不多，就是<strong>attention怎么沿着序列维度切</strong>。</p><h2 id="online-softmax">online softmax</h2><p>回顾一下最初的attention：</p><p><span class="math inline">\(O=softmax(QK^T)V\)</span></p><p>先不考虑融合算子，对于原始的数学公式来说就是两次matmul+一次softmax。matul是最容易优化的，因为良好的数学定义，这个东西可以横着切也可以竖着切也可以分成一小块一小块的，无论是算子优化还是做分布式切分（其实算子优化就是在device内做并行，分布式切分就是在device间做并行），都非常方便。</p><figure><img src="/zhihu/images/%5BAttention%5DFlashAttention_Ring-Attention_混合序列并行的统一原理_img_0.jpg" alt="https://www.cs.sfu.ca/~ashriram/Courses/CS7ARCH/hw/hw4.html"><figcaption aria-hidden="true">https://www.cs.sfu.ca/~ashriram/Courses/CS7ARCH/hw/hw4.html</figcaption></figure><p>但是softmax就比较麻烦：</p><p><span class="math display">\[softmax(x_i)=\frac{e^{x_i-m}}{\sum^{N}_{j=1}e^{x_j-m}},m=\underset{1\leq i \leq N}{\max}(x_j)\]</span></p><p>如果用最naive的方式去算softmax，需要遍历三遍。第一次遍历求一下m，第二次遍历把分母求出来，第三次把每个元素求出来。如果基于这个计算模式去做分布式，整个的效率会特别的低。因为每次遍历都需要拿到整个N^2大小的矩阵，对于算子来说会频繁的做IO，对于分布式就是频繁的通信。</p><p>但指数函数有个很棒的性质，可以只遍历一次数据，就同时得到m和分母：</p><p><span class="math display">\[e^{x_j-m_{i-1}}e^{m_{i-1}-m_i}=e^{x_j-m_{i-1}+m_{i-1}-m_i}=e^{x_j-m_i}\]</span></p><p>通过上面的式子，不停的基于新元素，更新分母和m即可。</p><p>不仅如此，我们也基于这个性质把softmax分成两块去计算，用到的符号如下：</p><p><span class="math inline">\(A_l\)</span> 左边一半的softmax分子，<span class="math inline">\(A_r\)</span> 右边一半的softmax分子，都是个向量</p><p><span class="math inline">\(B_l\)</span> 左边一半的softmax分母，<span class="math inline">\(B_r\)</span> 右边一半的softmax分母，都是个数值</p><p><span class="math inline">\(m_l\)</span> 左边一半的最大 xx ，<span class="math inline">\(m_r\)</span> 右边一半的最大 xx，很容易得到全局 xx 的最大值 mm</p><p><span class="math display">\[softmax_{全局}=\frac{A_{全局}}{B_{全局}}=\frac{[e^{m_l-m}A_l,e^{m_r-m}A_r]}{e^{m_l-m}B_l+e^{m_r-m}B_r}\]</span></p><p>通过这种方式，打开了softmax的并行空间。</p><h2 id="online-attention">online attention</h2><p>再回头看attention的公式，实际上他的物理意义在于基于QK矩阵乘的值，对V的每一列做一下线性组合。</p><p><span class="math display">\[O=softmax(QK^T)V\]</span></p><p>把线性组合展开一下，对于O的每一行来说：</p><p><span class="math display">\[o=\sum_{i=1}^{N}(\frac{e^{x_i-m}}{\sum^{N}_{j=1}e^{x_j-m}} \times v_i)\]</span></p><p>把这一步也分成两块：</p><p><span class="math display">\[o_l=\frac{A_l\cdot v_l}{B_l},o_r=\frac{A_r \cdot v_r}{B_r}\]</span></p><p>左右两边合并的方式如下，符号含义和softmax的合并一致：</p><p><span class="math display">\[o=\frac{B_le^{m_l-m}o_l+B_re^{m_r-m}o_r}{B_le^{m_l-m}+B_re^{m_r-m}}\]</span></p><p>结合上面softmax的并行方式，整个attention的计算也可以像matmul一样做纵向的切分了。</p><p>再对着图理解一下，切Q的话，相对容易：</p><p><img src="/zhihu/images/%5BAttention%5DFlashAttention_Ring-Attention_混合序列并行的统一原理_img_1.jpg"></p><p>切KV，经过我们上面的推导，也是可以合并起来的：</p><p><img src="/zhihu/images/%5BAttention%5DFlashAttention_Ring-Attention_混合序列并行的统一原理_img_2.jpg"></p><p>RingAttention/FlashAttention正是基于上面的机制，实现了attention算子的device内并行和device间并行:</p><ol type="1"><li><p>device间的并行可以减少N（seq长度）对于训练的限制，将训练拓展到更长的序列上去。对于RA来说，不同设备分别做一下seq切分后的attention，然后通过send/recv，在下一个卡上完成对上一个attention的合并。</p></li><li><p>对于FA来说，就是利用切块去减小仿存，从而拿到融合算子的性能收益。</p></li></ol><p>这里突然想到一个手撕代码题，用线段树做区间softmax/attention，想用的大哥找我交一下版权费:)</p><h2 id="序列并行">序列并行</h2><p>说混合序列并行前，需要先了解一下Megatron和DeepSpeed两大巨头最开始是怎么做序列并行的。</p><figure><img src="/zhihu/images/%5BAttention%5DFlashAttention_Ring-Attention_混合序列并行的统一原理_img_3.jpg" alt="Megatron"><figcaption aria-hidden="true">Megatron</figcaption></figure><p>Megatron最开始的seq并行相对来说比较简单，transformer没切tp的地方都是element wise的。所以可以把tp的allreduce改成reducescatter，直接把N分到不同的卡上去做element wise的操作。序列并行转tensor并行的时候需要allgather，tensor并行转序列并行的时候需要reducescatter。这个方法的优点就是简单，缺点就是通信量比较大，而且序列并行的通信域被TP域的限制住，拓展不了特别大。</p><figure><img src="/zhihu/images/%5BAttention%5DFlashAttention_Ring-Attention_混合序列并行的统一原理_img_4.jpg" alt="deepspeed ulysses"><figcaption aria-hidden="true">deepspeed ulysses</figcaption></figure><p>DeepSpeed采用的方式叫做deepspeed ulysses。核心点就是除了attention之外的部分（rmsnorm、matmul）切N都比较方便，只有attention不好切N，切head比较方便。那就是在进出attention的时候做一下切N和切head的转换，这里转换的方式使用了alltoall的通信源语，就是个分布式转置：</p><figure><img src="/zhihu/images/%5BAttention%5DFlashAttention_Ring-Attention_混合序列并行的统一原理_img_5.jpg" alt="来自www.mindspore.cn"><figcaption aria-hidden="true">来自www.mindspore.cn</figcaption></figure><p>deepspeed ulysses的优点是通信量，成比例增加序列长度和device数量的时候，可以维持一个稳定的通信量。但和megatron的方式类似，切分数量被head num限制住，没法把序列切的特别小。</p><h2 id="混合序列并行">混合序列并行</h2><p>实际上megatron sp和deepspeed ulysses都可以叠加ring-attention，把attention做进一步细分。megatron官网给了一个tp2cp2的例子，cp和tp共用四张卡，可以发现同一个序列的attention被切到两个部分了，这里的合并就是借助ring-attention的方式：</p><figure><img src="/zhihu/images/%5BAttention%5DFlashAttention_Ring-Attention_混合序列并行的统一原理_img_6.jpg" alt="Megatron Context Parallelism"><figcaption aria-hidden="true">Megatron Context Parallelism</figcaption></figure><p>和deepspeed ulysses的结合思路也类似，在attention部分进一步切分。这样就可以解除head num对deepspeed ulysses的约束，可以把序列进一步细分。</p><figure><img src="/zhihu/images/%5BAttention%5DFlashAttention_Ring-Attention_混合序列并行的统一原理_img_7.jpg" alt="USP"><figcaption aria-hidden="true">USP</figcaption></figure><p>后面遇到其他并行优化，比如seqpp之类的，就知道attention部分是怎么切分，需要处理些什么了。当然想实现最好的性能还需要非常多的细节。</p><h2 id="参考">参考</h2><p><a href="https://zhuanlan.zhihu.com/p/668888063">DefTruth：[Attention优化][2w字] 原理篇: 从Online-Softmax到FlashAttention V1/V2/V3</a></p><p><a href="https://zhuanlan.zhihu.com/p/683714620">朱小霖：ring attention + flash attention：超长上下文之路</a></p><p><a href="https://zhuanlan.zhihu.com/p/698031151">方佳瑞：序列并行做大模型训练，你需要知道的六件事</a></p><p><a href="https://zhuanlan.zhihu.com/p/4496065391">猛猿：图解大模型训练系列：序列并行2，DeepSpeed Ulysses</a></p><p><a href="https://zhuanlan.zhihu.com/p/689067888">方佳瑞：大模型训练之序列并行双雄：DeepSpeed Ulysses &amp; Ring-Attention</a></p><p><a href="https://zhuanlan.zhihu.com/p/5502876106">猛猿：图解大模型训练系列：序列并行4，Megatron Context Parallel</a></p><p><a href="https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md">DeepSpeed/blogs/deepspeed-ulysses/chinese/README.md at master · deepspeedai/DeepSpeed</a></p><p><a href="https://arxiv.org/abs/2405.07719">USP: A Unified Sequence Parallelism Approach for Long Context Generative AI</a></p><p><a href="https://arxiv.org/abs/2205.05198">Reducing Activation Recomputation in Large Transformer Models</a></p><p><a href="https://arxiv.org/abs/2310.01889">Ring Attention with Blockwise Transformers for Near-Infinite Context</a></p><p><a href="https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/context_parallel.html">context_parallel package</a></p>]]></content>
    
    
      
      
    <summary type="html">&lt;blockquote&gt;
&lt;p&gt;原文链接: &lt;a href=&quot;https://zhuanlan.zhihu.com/p/1887098218866794901&quot;&gt;https://zhuanlan.zhihu.com/p/1887098218866794901&lt;/a&gt;&lt;/p&gt;
&lt;/</summary>
      
    
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/categories/zhihu/"/>
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/tags/zhihu/"/>
    
  </entry>
  
  <entry>
    <title>品鉴一下OpenRLHF和verl的系统设计</title>
    <link href="https://wyc-ruiker.github.io/2025/03/11/zhihu/%E5%93%81%E9%89%B4%E4%B8%80%E4%B8%8BOpenRLHF%E5%92%8Cverl%E7%9A%84%E7%B3%BB%E7%BB%9F%E8%AE%BE%E8%AE%A1/"/>
    <id>https://wyc-ruiker.github.io/2025/03/11/zhihu/%E5%93%81%E9%89%B4%E4%B8%80%E4%B8%8BOpenRLHF%E5%92%8Cverl%E7%9A%84%E7%B3%BB%E7%BB%9F%E8%AE%BE%E8%AE%A1/</id>
    <published>2025-03-10T16:00:00.000Z</published>
    <updated>2026-04-19T17:00:43.062Z</updated>
    
    <content type="html"><![CDATA[<blockquote><p>原文链接: <a href="https://zhuanlan.zhihu.com/p/29046833667">https://zhuanlan.zhihu.com/p/29046833667</a></p></blockquote><h2 id="openrlhf">OpenRLHF</h2><h3 id="spmd单程序多数据">SPMD（单程序多数据）</h3><p><a href="https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/cli/train_ppo.py">https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/cli/train_ppo.py</a></p><p>用SPMD做LLM RL不需要太复杂的系统设计，因为当前深度学习最广泛使用的分布式范式就是SPMD，遵循大部分算法框架的设计方式就好了。OpenRLHF SPMD ppo的系统架构很简单：</p><figure><img src="/zhihu/images/品鉴一下OpenRLHF和verl的系统设计_img_0.jpg" alt="train_ppo.py"><figcaption aria-hidden="true">train_ppo.py</figcaption></figure><p>基于各种配置项初始化出对应的模型后，将这些模型传入PPOTrainer中。PPOTrainer负责整个PPO算法的控制逻辑。此时，不同的模型在同一组卡和同一组进程上按照不同的时间片运行SPMD。这些共享同一组计算资源并按时间交替使用的模型被称为<strong>colocate models</strong>。</p><h3 id="mpmd多程序多数据">MPMD（多程序多数据）</h3><p><a href="https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/cli/train_ppo_ray.py">https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/cli/train_ppo_ray.py</a></p><p>SPMD虽然实现简单，但它要求不同的模型只能串行执行，即使没有数据依赖的模型也难以实现并发。由于强化学习涉及的模型数量较多，如果某些模型不需要占用全部计算卡，就会导致部分计算资源的闲置。此外，SPMD需要将多个模型的参数同时加载到一张计算卡上，如果不结合offload等技术，很容易引发显存OOM问题。</p><figure><img src="/zhihu/images/品鉴一下OpenRLHF和verl的系统设计_img_1.jpg" alt="https://arxiv.org/abs/2405.11143"><figcaption aria-hidden="true">https://arxiv.org/abs/2405.11143</figcaption></figure><p>所以，OpenRLHF还支持使用ray进行拉起。使用ray的好处是可以通过配置placement group，让模型绑定到不同的卡上，并通过ray完成不同进程的数据交换。这里参考<a href="https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/cli/train_ppo_ray.py">train_ppo_ray.py</a>，画一下critic和actor-ref分离部署的场景示意图：</p><figure><img src="/zhihu/images/品鉴一下OpenRLHF和verl的系统设计_img_2.jpg" alt="train_ppo_ray.py"><figcaption aria-hidden="true">train_ppo_ray.py</figcaption></figure><p>其实能很明显的看出来，OpenRLHF的ray流程基本上就是在SPMD流程上硬改过来的，并不像下面的verl一样，是基于ray的原生系统设计。有两点比较别扭：</p><ol type="1"><li>PPORayActorGroup都在主进程（或者叫driver process）实例化，但是算法的控制逻辑不在主进程上，而是在Actor对应的PPORayActorGroup里面（ActorPPOTrainer）。不同的PPORayActorGroup在逻辑上不是对等的，Actor所在PPORayActorGroup需要把RL算法中的所有组件串起来。当然了，这样实现Trainer的逻辑不需要大改，只需要从SPMD的PPOTrainer继承出来一个ActorPPOTrainer就行了，仅仅是架构概念上不太符合单一职责原则，真的要去理解流程还是比较清晰易懂的。</li><li>colocate的模型不能放在同一个进程。参考上面的图，actor和ref共部署在相同的placement group上，但因为主体控制逻辑在ActorPPOTrainer里面，他也不知道ref是不是和他共部署，所以critic和ref都只能通过.remote的方式去调用。最后的效果就是，ref虽然和actor跑在一张卡上，但是二者不在同一个进程里面。这个设计会影响很多优化的开展，之前讲了深度学习最广泛使用的分布式范式就是SPMD，从深度学习框架到底层的device，都认为大部分场景（或者极致性能的场景）下device和process是一对一的，通信、显存资源都按照process级别去做共享。以显存为例，<a href="https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDACachingAllocator.cpp">把用过的显存缓存下来</a>是最基本的性能优化手段，但colocate的模型不在同一个进程上，就需要频繁的empty_cache来释放显存给卡上的另一个进程。老调重弹，对于国产芯片和框架，这种设计对架构更是巨大冲撞。</li></ol><p>当然，好处就是系统设计的很清晰，相比verl的层层封装，OpenRLHF想动手改点东西是很简单的，算法工程师们也可以轻松理解。</p><p>OpenRLHF的简洁设计有一个很重要的前提，就是模型基本都是dp分片的（vllm里面有tp分片），训练基于FSDP或者deepspeed，分布式优化靠的都是zero系列。这种设计的好处是数据流通起来很方便，就算RL框架里面最复杂的actor推理训练权重同步，在OpenRLHF里面也只需要一个broadcast（因为vllm只会再多个TP分片）。因为都是dp分片，不同rank是完全对等的，不同模型的调度可以直接轮询。但正是这种选择，导致OpenRLHF在大集群上训练超大规模的网络很难用，只用dp是没法跑满血的deepseek v3的。</p><h2 id="verl">verl</h2><p>verl的论文写了single controller/multi-controller、zero redundancy model resharding之类的贡献点。但我这里直接恶意揣测一下，verl最核心的动机以及设计上最漂亮的点是colocate模型的<strong>共进程</strong>，这一点对系统优化非常关键，但是不好发论文吹牛，所以包装了几个点出来发论文。</p><figure><img src="/zhihu/images/品鉴一下OpenRLHF和verl的系统设计_img_3.jpg" alt="共进程的秘密"><figcaption aria-hidden="true">共进程的秘密</figcaption></figure><p><a href="https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py">https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py</a></p><p><a href="https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py">https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py</a></p><p>上面OpenRLHF的ray流程最大缺点就是colocate模型没法共进程。想一下，要实现共进程这个特性，得让colocate的不同模型共享同一个ray remote实例。为了达成这个目的，并且让上层的编程接口尽量干净，verl做了巨复杂的封装，我上面画了个图，展示了层层封装之后最终的调用形态。</p><p>用户需要编写的是ActorRolloutRefWorker或CriticWorker等worker，并将这些worker传入RayPPOTrainer中。verl会自动将colocate的模型集中到一个WorkerDict中，并通过setattr为每个worker设置所需的方法，从而实现worker层面的任务分发。最终，系统会为每个worker生成一个RayWorkerGroup，这些RayWorkerGroup的对外接口与其对应的worker完全一致，但经过了多次转发。colocate的RayWorkerGroup的成员方法会转发到同一个WorkerDict中，以此实现共进程的执行机制。</p><p>此外，verl支持FSDP/Megatron等多种后端，也支持3D并行策略的配置，这样数据流动的方式就会很复杂。为了实现这个功能，verl搞了一套协议，自动在函数的前后插上对应并行方式的dispatch/collect方法。反正都把函数调用搞成这么复杂的闭包了，ray相关的操作（ray.gets/.remote）也可以隐藏在闭包里面（OpenRLHF要直接写在脚本上），主体逻辑就看起来很简洁。</p><figure><img src="/zhihu/images/品鉴一下OpenRLHF和verl的系统设计_img_4.jpg" alt="data proto（https://arxiv.org/pdf/2409.19256v2）"><figcaption aria-hidden="true">data proto（https://arxiv.org/pdf/2409.19256v2）</figcaption></figure><figure><img src="/zhihu/images/品鉴一下OpenRLHF和verl的系统设计_img_5.jpg" alt="主逻辑（https://arxiv.org/pdf/2409.19256v2）"><figcaption aria-hidden="true">主逻辑（https://arxiv.org/pdf/2409.19256v2）</figcaption></figure><p>这个封装合不合适，见仁见智，对verl可能大部分人会有个螺旋上升的认识过程。首先看到fit函数这么简洁，肯定觉得很舒服；但后来发现每个函数怎么都点不进去，各种调用怎么一层套一层，就觉得这又是大公司开源狗屎给大家；最后明白了verl的设计理念，感叹人类的智慧。当然，如果不用大集群，7B/30B这种规模做做RL，还想自己改改东西，无脑推荐OpenRLHF。</p><h2 id="one-more-thingray真的很重要吗">One more thing：ray真的很重要吗？</h2><p>回归前提，<a href="https://zhuanlan.zhihu.com/p/27172237359">上一篇文章</a>我猜RL算法的演进会越来越简洁，GRPO干掉了reward和critic。按照历史经验来看，RL需要的算力会越来越大，各种算法设计需要在大算力下充分验证。这一定会让RL走向力大砖飞的方向，各种小技巧会越来越少。在RL过程中，actor的推理和训练会越来越重，二者的系统差异也会越来越大，这两模块会吃掉所有的算力。为了达成更好的MFU，如果还是这种训练模式，actor训推共部署是必然的选择。在这个前提下，ray很可能是伪需求，最后还是回归到SPMD的怀抱。</p><h2 id="参考">参考</h2><p><a href="https://github.com/OpenRLHF/OpenRLHF">https://github.com/OpenRLHF/OpenRLHF</a></p><p><a href="https://github.com/volcengine/verl">https://github.com/volcengine/verl</a></p><p><a href="https://arxiv.org/abs/2405.11143">OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework</a></p><p><a href="https://arxiv.org/pdf/2409.19256v2">arXiv reCAPTCHA</a></p><p><a href="https://zhuanlan.zhihu.com/p/635757674">低级炼丹师：强化学习从零到RLHF（八）一图拆解RLHF中的PPO</a></p><p><a href="https://zhuanlan.zhihu.com/p/27676081245">不关岳岳的事：[AI Infra] VeRL 框架入门&amp;代码带读</a></p><p><a href="https://zhuanlan.zhihu.com/p/26833089345">杨远航：基于 Ray 的分离式架构：veRL、OpenRLHF 工程设计</a></p><p><a href="https://zhuanlan.zhihu.com/p/12871616401">猛猿：图解OpenRLHF中基于Ray的分布式训练流程</a></p><p><a href="https://verl.readthedocs.io/en/latest/hybrid_flow.html">HybridFlow Programming Guide</a></p>]]></content>
    
    
      
      
    <summary type="html">&lt;blockquote&gt;
&lt;p&gt;原文链接: &lt;a href=&quot;https://zhuanlan.zhihu.com/p/29046833667&quot;&gt;https://zhuanlan.zhihu.com/p/29046833667&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2 </summary>
      
    
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/categories/zhihu/"/>
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/tags/zhihu/"/>
    
  </entry>
  
  <entry>
    <title>书籍推荐《深度学习入门2:自制框架》</title>
    <link href="https://wyc-ruiker.github.io/2025/02/17/zhihu/%E4%B9%A6%E7%B1%8D%E6%8E%A8%E8%8D%90%E3%80%8A%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A82_%E8%87%AA%E5%88%B6%E6%A1%86%E6%9E%B6%E3%80%8B/"/>
    <id>https://wyc-ruiker.github.io/2025/02/17/zhihu/%E4%B9%A6%E7%B1%8D%E6%8E%A8%E8%8D%90%E3%80%8A%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A82_%E8%87%AA%E5%88%B6%E6%A1%86%E6%9E%B6%E3%80%8B/</id>
    <published>2025-02-16T16:00:00.000Z</published>
    <updated>2026-04-19T17:00:43.062Z</updated>
    
    <content type="html"><![CDATA[<blockquote><p>原文链接: <a href="https://zhuanlan.zhihu.com/p/22879442357">https://zhuanlan.zhihu.com/p/22879442357</a></p></blockquote><p>作者使用纯Python实现了一个名为DeZero的深度学习框架，其编程风格与PyTorch类似。通过结合代码实现，由浅入深地拆解了深度学习框架的前端功能，包括自动微分、Layer、Optimizer、DataLoader等模块。</p><p><a href="https://book.douban.com/subject/36303408/"><img src="/zhihu/images/书籍推荐《深度学习入门2_自制框架》_img_0.jpg">深度学习入门2book.douban.com/subject/36303408/</a>这本书在国内比较冷门，但是我读过的最好的关于深度学习框架的书籍。这里所说的“好”，并不是指内容面面俱到，而是因为它的逻辑非常清晰，对入门十分友好。同时又不算很粗浅，像高阶微分这样的深度话题也有探讨。</p><p>本书的作者是Chainer的核心开发者Koki Saitoh。Chainer可能是最早在Python实现的define by run的深度学习框架（更常见的叫法是<strong>动态图</strong>），非常具有前瞻性，可惜在2019年的时候母公司（Preferred Networks）转投了PyTorch，项目停止了维护<a href="#ref_1">[1]</a>。</p><p>这个母公司Preferred Networks也值得聊一聊。从主页来看啥都搞，AI外包公司，像是日本的商汤科技。</p><p><a href="https://www.preferred.jp/en/projects/"><img src="/zhihu/images/书籍推荐《深度学习入门2_自制框架》_img_1.jpg">Projects - Preferred Networks, Inc.www.preferred.jp/en/projects/</a>创始人兼CEO叫Toru Nishikawa，代表东京大学参加了2006年ACM ICPC世界总决赛，荣获第19名。</p><p><img src="/zhihu/images/书籍推荐《深度学习入门2_自制框架》_img_2.jpg"></p><p>比较有意思的一点是前一年（2005）的ACM ICPC世界总决赛冠军是戴文渊大神，后来创建了第四范式，也是个AI外包公司。大神们的水平都不需要质疑，但最后都成为了AI外包公司，可能这就是这代人的宿命吧。。。</p><p>Preferred Networks还有个很出名的项目叫CuPy<a href="#ref_2">[2]</a>，相当于numpy在cuda上的实现，也为书里的DeZero框架提供了GPU能力（不过这个前后端解耦设计其实不怎么合理，后端能拿到的前端信息太少了）。这个项目现在还活着，而且还在发版本。从产出来看，Preferred Networks确实聚集了一批日本AI仙人，做出了一些很有前瞻性的且很有影响力的软件，希望该公司业务蒸蒸日上。</p><p>再说一下书的内容，有几点我觉得是写的非常出色的：</p><ol type="1"><li>每一章都基于代码讲解，接着上一章的代码一步步从简单到复杂，阅读体验很流畅。而且看DeZero的代码实现能学到很多技巧（对于我这种没怎么写过Python的人来说），例如Layer这一节通过重写__setattr__的魔术方法实现了每个深度学习层的Parameter自动统计。</li><li>概念的逐步引入。DeZero不是一股脑给一堆概念，对着概念一顿实现就完事了，而是一步步的基于功能的需要引入这些概念。比如实现自动微分的时候为什么需要引入Variable和Function；为什么要进行运算符重载、Layer/Module/Optimizer/DataLoader的封装，通过这些方式能把代码优化到什么程度。这种写作方式类似OSTEP，围绕着虚拟化、并发和持久性，一层一层的展示每个功能为什么要做，做完的好处是什么。一个好教材甚至于说一个好项目都应该是这样的，不应该有太多无用的概念，每一层的抽象都得有明确的目的。</li><li>组合的力量。看高阶微分这一章的时候让我想起了SICP，SICP的很多细节我早就忘得一干二净了，唯一剩下的就是对组合力量的深刻印象。</li></ol><figure><img src="/zhihu/images/书籍推荐《深度学习入门2_自制框架》_img_3.jpg" alt="高阶微分"><figcaption aria-hidden="true">高阶微分</figcaption></figure><p>组合其实就是通过简单的、模块化的组件组合，构建出复杂且强大的系统。这里的高阶微分就是一个例子。从一阶微分到高阶微分，通过完善的抽象设计，核心改动只需要将反向实现从调用numpy接口变成调用DeZero的运算符重载。后面的Layer设计其实也是组合的力量，从一个玩具框架最后组合出了rnn、cnn等网络，到真正的加载vgg的预训练模型，这些组合的力量让人印象深刻。</p><p>从这几点来看，我觉得这本书不仅是一个深度学习框架入门书籍，还是个极好的第二本Python书籍，适合学习完基本语法之后来看。他实现的是一个实打实的项目，肯定比实现什么算法之类的有意思多了，并且在实现该项目的同时，触及了一些计算机科学的本质。</p><p>最后说一下Chainer为什么消亡。其实我在网上没搜到太多对于Chainer缺点的评论，在reddit上有一些关于Chainer为什么打不过PyTorch的讨论<a href="#ref_3">[3]</a> <a href="#ref_4">[4]</a> <a href="#ref_5">[5]</a>。</p><p>从讨论来看，基本输给了生态。从易用性上Chainer没理由比PyTorch要差，只是日本的计算机圈子还是太小了，Preferred Networks比不上FB财大气粗。而且有人说大量的资料都是日语，导致Chainer更像是日本计算机小圈子的自娱自乐。Chainer的失败对国内的基础软件走向世界还是有很大启发的，虽然国内公司的深度学习基础软件看起来还没有Chainer成功。。。</p><h2 id="参考">参考</h2><ol type="1"><li><a href="#ref_1_0">^</a>Chainer/CuPy v7 release and Future of Chainer <a href="https://chainer.org/announcement/2019/12/05/released-v7.html">https://chainer.org/announcement/2019/12/05/released-v7.html</a></li><li><a href="#ref_2_0">^</a><a href="https://cupy.dev/">https://cupy.dev/</a></li><li><a href="#ref_3_0">^</a><a href="https://www.reddit.com/r/MachineLearning/comments/e0pir2/dwhy_isnt_chainer_more_useddiscussed/">https://www.reddit.com/r/MachineLearning/comments/e0pir2/dwhy_isnt_chainer_more_useddiscussed/</a></li><li><a href="#ref_4_0">^</a><a href="https://www.reddit.com/r/MachineLearning/comments/e6dd7x/d_preferred_networks_creators_of_chainer/">https://www.reddit.com/r/MachineLearning/comments/e6dd7x/d_preferred_networks_creators_of_chainer/</a></li><li><a href="#ref_5_0">^</a><a href="https://www.reddit.com/r/MachineLearning/comments/7lb5n1/d_chainer_vs_pytorch/">https://www.reddit.com/r/MachineLearning/comments/7lb5n1/d_chainer_vs_pytorch/</a></li></ol>]]></content>
    
    
      
      
    <summary type="html">&lt;blockquote&gt;
&lt;p&gt;原文链接: &lt;a href=&quot;https://zhuanlan.zhihu.com/p/22879442357&quot;&gt;https://zhuanlan.zhihu.com/p/22879442357&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;作</summary>
      
    
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/categories/zhihu/"/>
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/tags/zhihu/"/>
    
  </entry>
  
  <entry>
    <title>对DualPipe的一些想法</title>
    <link href="https://wyc-ruiker.github.io/2025/02/06/zhihu/%E5%AF%B9DualPipe%E7%9A%84%E4%B8%80%E4%BA%9B%E6%83%B3%E6%B3%95/"/>
    <id>https://wyc-ruiker.github.io/2025/02/06/zhihu/%E5%AF%B9DualPipe%E7%9A%84%E4%B8%80%E4%BA%9B%E6%83%B3%E6%B3%95/</id>
    <published>2025-02-05T16:00:00.000Z</published>
    <updated>2026-04-19T17:00:43.062Z</updated>
    
    <content type="html"><![CDATA[<blockquote><p>原文链接: <a href="https://zhuanlan.zhihu.com/p/21525151726">https://zhuanlan.zhihu.com/p/21525151726</a></p></blockquote><p>DualPipe是DeepSeek V3里面infra部分的重头戏，既能减少bubble还能做到通信掩盖，整个编排方式看起来神乎其技。而且对于一些国产芯片复现DeepSeek V3（FP8很难用起来、也没法用SM去控制网卡），DualPipe几乎是论文里面infra部分唯一能参考的优化点。我用我半吊子的大模型并行知识，猜想一下这个东西是怎么构造出来的，以及落地上可能遇到的难度。</p><h2 id="动机猜想">动机猜想</h2><p>对于DeepSeek V3，性能的最大瓶颈就是正反向里面巨大的AllToAll通信，根据论文里面的说法，计算和通信大概是1:1的水平。而且整体拓扑关系是attention-&gt;alltoall-&gt;mlp-&gt;alltoall这样的顺序，前后有严格的依赖，根本没什么掩盖空间。</p><p>对于这种类型的通信，比较常见的做法是bs或者seq维度切一切，把1层拆成两个没有数据依赖的小层，互相做掩盖，类似下图：</p><p><img src="/zhihu/images/对DualPipe的一些想法_img_0.jpg"></p><p>但是每个micro的bs为了省activation的显存，基本都是1；切seq的话，MLP会被切小，可能导致算子本身性能下降，都不算是最好的做法。</p><p>这是两个正向或者两个反向互相做掩盖的场景，如果这个不好搞的话，就可能会想到，一个正向和一个反向能不能做掩盖，因为对于1f1b的调度来说，一定有正反向交替的阶段：</p><p><img src="/zhihu/images/对DualPipe的一些想法_img_1.jpg"></p><p><img src="/zhihu/images/对DualPipe的一些想法_img_2.jpg"></p><p>画一画，会发现不怎么对劲，如果强行掩盖相邻部分的正反向通信，会导致bubble变的巨大：</p><p><img src="/zhihu/images/对DualPipe的一些想法_img_3.jpg"></p><p>那现在的问题就是，需要一个合适的pipeline排布，既可以正反向做掩盖，bubble也不会增大。估计DeepSeek的大神们很快想到了Chimera，因为这个是双向的pp，上面rank灌进来的micro和下面rank灌进来的micro，相互之间的依赖关系没有这么重，如果相邻的正反向是不同方向的micro，是可以掩盖起来的，而且不会增大bubble：</p><figure><img src="/zhihu/images/对DualPipe的一些想法_img_4.jpg" alt="想象一下上面的图micro再多一点"><figcaption aria-hidden="true">想象一下上面的图micro再多一点</figcaption></figure><p>结合上zero bubble的优化点，把dw拆出来，最后就形成了DualPipe：</p><p><img src="/zhihu/images/对DualPipe的一些想法_img_5.jpg"></p><h2 id="一些问题">一些问题</h2><p>DualPipe的编排相比最常用的pp或者vpp甚至zero bubble，都复杂许多，想在工程落地会有很多问题：</p><ul><li>实现逻辑的就很复杂。我不确定他们这个20micro的排布是基于策略写出来的还是基于一套算法搜出来的，想泛化感觉很不容易。</li><li>因为pp的编排复杂，会导致在超大集群下类似精度对比（需要沿着stage往上找输入）、快慢卡识别（会沿着pp传递）、首报错节点识别（要分析通信关系）都变的复杂很多。可能DeepSeek V3用的是2048卡，还不够大，这些问题还没那么重要，上了万卡规模，这些东西定位起来都很痛苦。</li><li>因为是双向的pp，需要每个rank放两个stage，导致param要翻个两倍。这个对于国产芯片可能很致命，显存没有那么多不说，暂时还不能用fp8训练，雪上加霜。</li></ul><h2 id="抛砖引玉">抛砖引玉</h2><p>虽然普通的pp不行，但是对于vpp来说，很多地方都是可以做掩盖的：</p><p><img src="/zhihu/images/对DualPipe的一些想法_img_6.jpg"></p><p>红框标注的地方都可以掩盖，而且这个方式相比普通的vpp显存没有任何增加，实现逻辑也比DualPipe简单多了，感觉是个比较靠谱的落地方向。</p><p>或者更本质的，能不能用类似<a href="https://arxiv.org/abs/2406.06858">FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion</a>的思路搞一个AllToAll MLP AllToAll的融合算子，这样对上层业务来说是最轻松的。但这个领域不太熟，不确定能不能做到。</p><p><img src="/zhihu/images/对DualPipe的一些想法_img_7.jpg"></p>]]></content>
    
    
      
      
    <summary type="html">&lt;blockquote&gt;
&lt;p&gt;原文链接: &lt;a href=&quot;https://zhuanlan.zhihu.com/p/21525151726&quot;&gt;https://zhuanlan.zhihu.com/p/21525151726&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;D</summary>
      
    
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/categories/zhihu/"/>
    
    
    <category term="zhihu" scheme="https://wyc-ruiker.github.io/tags/zhihu/"/>
    
  </entry>
  
  <entry>
    <title>MIT-6.S081 2020</title>
    <link href="https://wyc-ruiker.github.io/2020/10/13/mit-os/"/>
    <id>https://wyc-ruiker.github.io/2020/10/13/mit-os/</id>
    <published>2020-10-13T07:17:57.000Z</published>
    <updated>2021-12-17T15:26:22.801Z</updated>
    
    <content type="html"><![CDATA[<p>在几个月前，试图开了一下MIT 6.828的坑，但是因为各种原因，只做了lab1就搁置了。前几天突然在知乎上看到了<a href="https://zhuanlan.zhihu.com/p/251366985">二十八画生征友：一起来通关6.S081/6.828吧~</a>，发现这个MIT 6.S081是MIT 6.828的简化版，而且梯度更加的平滑。 <span id="more"></span> 之前的一段时间都因为各种事情，忙的飞起。最近正好有一段空闲时间，希望可以在一个月内通关MIT 6.S081。</p><p>一个比较好的<a href="https://mit-public-courses-cn-translatio.gitbook.io/mit6-s081/">参考资料</a>。</p><p>首先按照<a href="https://pdos.csail.mit.edu/6.828/2020/tools.html">tools</a>配环境，因为我用的是Ubuntu 18的虚拟机，所以需要自己手动去编译riscv-gnu-toolchain，编译的过程其实非常简单，但是下载的过程非常痛苦。这边有个<a href="https://blog.csdn.net/zhayujie5200/article/details/106374189/">老哥</a>直接把riscv-gnu-toolchain放到百度云盘上了，这样下载就会快很多而且不容易中断。</p><h1 id="lab-1-utilities">Lab-1 Utilities</h1><p>Lab1其实就是实现一些shell指令，做实验前一定要把<a href="https://pdos.csail.mit.edu/6.828/2020/xv6/book-riscv-rev1.pdf">xv6 book</a>的第一章通读一遍，<strong>非常关键！！！</strong></p><h2 id="sleep-pingpong">sleep &amp; pingpong</h2><p>相对简单，略过。这个pingpong我是用两个pipe从两个方向传输的。</p><h2 id="primes">primes</h2><p>一个非常风骚的多进程筛法。主要流程如下图：</p><p><img src="/2020/10/13/mit-os/1.PNG"></p><p>其实就是每个素数开一个进程，多个进程形成一个pipeline。按顺序输入数字，如果一个数被前面所有进程漏掉，那他显然是一个素数，就再开一个进程进入pipeline。</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">&quot;kernel/types.h&quot;</span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">&quot;kernel/stat.h&quot;</span></span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">&quot;user/user.h&quot;</span></span></span><br><span class="line"></span><br><span class="line"><span class="type">int</span> p[<span class="number">36</span>][<span class="number">2</span>];</span><br><span class="line"></span><br><span class="line"><span class="type">void</span> <span class="title function_">prime_fork</span><span class="params">(<span class="type">int</span> last, <span class="type">int</span> from)</span> &#123;</span><br><span class="line">    <span class="keyword">if</span> (fork() == <span class="number">0</span>) &#123;</span><br><span class="line">        <span class="type">int</span> prime, now;</span><br><span class="line">        now = <span class="number">-1</span>; prime = <span class="number">-1</span>;</span><br><span class="line">        <span class="keyword">if</span> (last != <span class="number">-1</span>) &#123;</span><br><span class="line">            close(p[last][<span class="number">0</span>]);</span><br><span class="line">        &#125;</span><br><span class="line">        close(p[from][<span class="number">1</span>]);</span><br><span class="line">        <span class="keyword">while</span> (read(p[from][<span class="number">0</span>], &amp;now, <span class="number">4</span>)) &#123;</span><br><span class="line">            <span class="keyword">if</span> (prime == <span class="number">-1</span>) &#123;</span><br><span class="line">                prime = now;</span><br><span class="line">                pipe(p[prime]);</span><br><span class="line">                prime_fork(from, prime);</span><br><span class="line">                close(p[prime][<span class="number">0</span>]);</span><br><span class="line">                <span class="built_in">printf</span>(<span class="string">&quot;prime %d\n&quot;</span>, prime);</span><br><span class="line">                <span class="keyword">continue</span>;</span><br><span class="line">            &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">                <span class="keyword">if</span> (now % prime == <span class="number">0</span>) &#123;</span><br><span class="line">                    <span class="keyword">continue</span>;</span><br><span class="line">                &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">                    write(p[prime][<span class="number">1</span>], &amp;now, <span class="number">4</span>);</span><br><span class="line">                &#125;</span><br><span class="line">            &#125;</span><br><span class="line">        &#125;</span><br><span class="line">        <span class="keyword">if</span> (prime != <span class="number">-1</span>) &#123;</span><br><span class="line">            close(p[prime][<span class="number">1</span>]);</span><br><span class="line">            wait(<span class="number">0</span>);</span><br><span class="line">        &#125;</span><br><span class="line">        <span class="built_in">exit</span>(<span class="number">0</span>);</span><br><span class="line">    &#125;</span><br><span class="line">    <span class="keyword">return</span>;</span><br><span class="line">&#125;</span><br><span class="line"></span><br><span class="line"><span class="type">int</span> <span class="title function_">main</span><span class="params">()</span> &#123;</span><br><span class="line">    pipe(p[<span class="number">1</span>]);</span><br><span class="line">    prime_fork(<span class="number">-1</span>, <span class="number">1</span>);</span><br><span class="line">    close(p[<span class="number">1</span>][<span class="number">0</span>]);</span><br><span class="line">    <span class="keyword">for</span> (<span class="type">int</span> i = <span class="number">2</span>; i &lt;= <span class="number">35</span>; i++) &#123;</span><br><span class="line">        write(p[<span class="number">1</span>][<span class="number">1</span>], &amp;i, <span class="number">4</span>);</span><br><span class="line">    &#125;</span><br><span class="line">    close(p[<span class="number">1</span>][<span class="number">1</span>]);</span><br><span class="line">    wait(<span class="number">0</span>);</span><br><span class="line">    <span class="built_in">exit</span>(<span class="number">0</span>);</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>我直接暴力的用了一个p数组，代表每个进程向下传输所需要的pipe。一个素数的进程的pipe就直接使用p[素数]的数组就好了。因为fd空间不够35以内的素数开满，所以要管理fd，这个地方要小心处理，我因为close了一些回收过的fd错了好几次。</p><p>一个值得注意的地方是read前的close(p[from][1])，这个bug我也卡了一小会儿，这里直接引用xv6 book中的原文：</p><blockquote><p>If no data is available, a read on a pipe waits for either data to be written or for all file descriptors referring to the write end to be closed.</p></blockquote><p>没错，是all file descriptors都要close才行，不然read就会一直挂在那里，导致死循环。</p><h2 id="find">find</h2><p>这个本身其实挺复杂的，幸好ls指令已经写好了，所以直接抄就完事儿了。</p><h2 id="xargs">xargs</h2><p>这个理论上其实很简单，就是解析标准读入进来的字符串，先按照'\n' split一下，再按照' ' split一下，最后拼成argv就好了。但是有个地方非常坑爹，我最开始的时候，从标准读入读字符串，简单写成了这样：</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">buf_size = read(0, buf, (sizeof buf))</span><br><span class="line">do something with buf...</span><br></pre></td></tr></table></figure><p>这么写完之后，make grade大概五次会错一次，非常诡异。一般来说这种情况都是多进程出了问题，这时再掏出xv6 book的原文：</p><blockquote><p>The child process creates a pipe to connect the left end of the pipeline with the right end. Then it calls fork and runcmd for the left end of the pipeline and fork and runcmd for the right end, and waits for both to finish.</p></blockquote><p>一个pipe的左右指令是多进程，所以当左边指令输出，右边指令输入的时候，应该写成阻塞式：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">while</span> ((buf_size = read(<span class="number">0</span>, buf + offset, (<span class="keyword">sizeof</span> buf) - offset)) != <span class="number">0</span>) &#123;</span><br><span class="line">    offset += buf_size;</span><br><span class="line">&#125;</span><br><span class="line">buf_size = offset;</span><br></pre></td></tr></table></figure><p>我跑了十次，全都通过了testcase。</p><h1 id="lab-2-system-calls">Lab-2 System calls</h1><p>这次实验是实现两个系统调用trace和sysinfo，感觉在实现上没什么好说的，两个任务都比较简单。这里想重点总结一下整个xv6的boot和system call的流程。</p><p>首先讲讲boot的整个流程。我们的整个实验环境是搭建在RISC-V上面的，在RISC-V中，CPU运行的指令分成三种模式：machine mode, supervisor mode和user mode。最开始的时候，CPU是处于machine mode的。开机后，CPU运行boot loader中的指令，把整个kernel都给load进来。然后进入entry.S：</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line">_entry:</span><br><span class="line"># set up a stack for C.</span><br><span class="line">        # stack0 is declared in start.c,</span><br><span class="line">        # with a 4096-byte stack per CPU.</span><br><span class="line">        # sp = stack0 + (hartid * 4096)</span><br><span class="line">        la sp, stack0</span><br><span class="line">        li a0, 1024*4</span><br><span class="line">csrr a1, mhartid</span><br><span class="line">        addi a1, a1, 1</span><br><span class="line">        mul a0, a0, a1</span><br><span class="line">        add sp, sp, a0</span><br><span class="line"># jump to start() in start.c</span><br><span class="line">        call start</span><br></pre></td></tr></table></figure><p>entry.S的任务就是初始化stack，然后开始call start.c:</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">void</span></span><br><span class="line"><span class="title function_">start</span><span class="params">()</span></span><br><span class="line">&#123;</span><br><span class="line">  <span class="comment">// set M Previous Privilege mode to Supervisor, for mret.</span></span><br><span class="line">  <span class="type">unsigned</span> <span class="type">long</span> x = r_mstatus();</span><br><span class="line">  x &amp;= ~MSTATUS_MPP_MASK;</span><br><span class="line">  x |= MSTATUS_MPP_S;</span><br><span class="line">  w_mstatus(x);</span><br><span class="line"></span><br><span class="line">  <span class="comment">// set M Exception Program Counter to main, for mret.</span></span><br><span class="line">  <span class="comment">// requires gcc -mcmodel=medany</span></span><br><span class="line">  w_mepc((uint64)main);</span><br><span class="line"></span><br><span class="line">  <span class="comment">// disable paging for now.</span></span><br><span class="line">  w_satp(<span class="number">0</span>);</span><br><span class="line"></span><br><span class="line">  <span class="comment">// delegate all interrupts and exceptions to supervisor mode.</span></span><br><span class="line">  w_medeleg(<span class="number">0xffff</span>);</span><br><span class="line">  w_mideleg(<span class="number">0xffff</span>);</span><br><span class="line">  w_sie(r_sie() | SIE_SEIE | SIE_STIE | SIE_SSIE);</span><br><span class="line"></span><br><span class="line">  <span class="comment">// ask for clock interrupts.</span></span><br><span class="line">  timerinit();</span><br><span class="line"></span><br><span class="line">  <span class="comment">// keep each CPU&#x27;s hartid in its tp register, for cpuid().</span></span><br><span class="line">  <span class="type">int</span> id = r_mhartid();</span><br><span class="line">  w_tp(id);</span><br><span class="line"></span><br><span class="line">  <span class="comment">// switch to supervisor mode and jump to main().</span></span><br><span class="line">  <span class="keyword">asm</span> <span class="title function_">volatile</span><span class="params">(<span class="string">&quot;mret&quot;</span>)</span>;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>在RISC-V中，mret、sret和uret分别用于从machine、supervisor和user模式中的trap返回。这里调用的是mret，其实就是从machine mode返回到supervisor mode。</p><p>在main.c中，userinit函数首先allocproc()出来第一个用户进程，然后执行initcode.S中的指令，这个指令其实就是调用SYS_exec执行init.c：</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line">.globl start</span><br><span class="line">start:</span><br><span class="line">        la a0, init</span><br><span class="line">        la a1, argv</span><br><span class="line">        li a7, SYS_exec</span><br><span class="line">        ecall</span><br><span class="line"></span><br><span class="line"># for(;;) exit();</span><br><span class="line">exit:</span><br><span class="line">        li a7, SYS_exit</span><br><span class="line">        ecall</span><br><span class="line">        jal exit</span><br><span class="line"></span><br><span class="line"># char init[] = &quot;/init\0&quot;;</span><br><span class="line">init:</span><br><span class="line">  .string &quot;/init\0&quot;</span><br><span class="line"></span><br><span class="line"># char *argv[] = &#123; init, 0 &#125;;</span><br><span class="line">.p2align 2</span><br><span class="line">argv:</span><br><span class="line">  .long init</span><br><span class="line">  .long 0</span><br></pre></td></tr></table></figure><p>在init.c中，系统生成一个名为"console"的device，并重载标准输出跟标准错误输出到这个device上，之后启动sh.c，shell启动，正式进入操作系统。</p><p>然后再讲讲整个system call的流程。这里我们可以用作业中的Sysinfo作为例子。</p><p>Sysinfo的功能是统计一下当前系统的剩余内存和已经使用的进程数量，并把这两个参数组装成一个结构体，传回给用户。这里就要分为三步，首先在user mode进行system call，进入supervisor mode，然后统计两个信息，最后system call return需要的信息回来。</p><p>我们要在user mode中调用这个sysinfo函数，但是因为这是一个系统调用，肯定不能在user mode里面实现，所以需要一个user mode的函数与supervisor mode的system call的绑定。具体到代码中，其实就是user/usys.pl中会有个entry("sysinfo")，这个语句会生成一段汇编代码：</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">sysinfo:</span><br><span class="line"> li a7, SYS_sysinfo</span><br><span class="line"> ecall</span><br><span class="line"> ret</span><br></pre></td></tr></table></figure><p>这个ecall指令会让我们进入supervisor mode并运行SYS_sysinfo函数。那么问题来了，我们看到系统调用的参数都是(void)的，并且我们当前user mode的各种寄存器要怎么保存呢？要解决这个问题，首先要知道ecall指令到底干了什么，在xv6 book中，是这样说的：</p><blockquote><p>The ecall instruction traps into the kernel and executes uservec, usertrap, and then syscall, as we saw above.</p></blockquote><p>uservec是一段汇编，主要是保护当前的user mode现场，将user进程中的寄存器数据放到TRAPFRAME中，同时加载kernel的页表。然后调用trap.c/usertrap()。</p><p>在trap.c/usertrap()会进入syscall(void)函数，syscall(void)会根据a7寄存器中的值调用对应的system call，system call参数的传递通过TRAPFRAME中的a0-a5寄存器，运行过后的返回值放到p-&gt;trapframe-&gt;a0寄存器中。</p><p>然后调用userret，也是一段汇编代码。userret会恢复user mode的页表，并且从TRAPFRAME中恢复过去的寄存器，这样就完成了一次系统调用。但是我们发现，system call并没有传递任何值给user mode，其实system call的返回值要通过copyout函数，直接写入该进程在user mode能使用的地址中。</p><h1 id="lab-3-page-tables">Lab-3 Page tables</h1><p>这个实验相对复杂，要比较清晰的理解页表才能顺利的完成。</p><p>首先，我们要了解页表是个什么东西。其实页表就是个(key, value)的pair集合，key是虚拟地址，value是物理地址，在程序使用的地址跟计算机实际运行的内存地址之间产生一层隔离关系。</p><p>这种隔离关系的好处是，不同进程之间无法直接访问对方的变量，恶意程序也无法破坏整个操作系统，只能破坏自己的进程。</p><p>其次，还可以使得物理上不连续的、不从0开始的物理地址变成连续的、从0开始的虚拟地址，方便内核做统一的内存管理，而不是每个进程自己乱管理。</p><h2 id="print-a-page-table">Print a page table</h2><p><img src="/2020/10/13/mit-os/2-1.PNG"></p><p>完成打印页表的任务，只需要理解这个三级页表的架构就好了。从图片可以轻松了解这个数据结构，但是作为一个kv系统，为什么要这么设计呢？</p><p>其实很简单，如果页表只有一级，那就是空间换时间，一个巨大的桶存着页表。如果页表级数特别多，那访问时间就会很慢。三级页表就是个设计上的trade-off，而且这么设计，每一级页表就是一个page的大小，我觉得代码会更加规整。</p><p>判断一个页表的内容是否合法，其实就是对应一些掩码的位置，核心逻辑如下：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">void</span></span><br><span class="line">_vmprint(<span class="type">pagetable_t</span> pagetable, <span class="type">int</span> level)</span><br><span class="line">&#123;</span><br><span class="line">  <span class="comment">// there are 2^9 = 512 PTEs in a page table.</span></span><br><span class="line">  <span class="keyword">for</span>(<span class="type">int</span> i = <span class="number">0</span>; i &lt; <span class="number">512</span>; i++) &#123;</span><br><span class="line">    <span class="type">pte_t</span> pte = pagetable[i];</span><br><span class="line">    <span class="keyword">if</span> (pte &amp; PTE_V) &#123;</span><br><span class="line">        uint64 child = PTE2PA(pte);</span><br><span class="line">        <span class="keyword">for</span> (<span class="type">int</span> j = <span class="number">0</span>; j &lt; level; j++) &#123;</span><br><span class="line">            <span class="built_in">printf</span>(<span class="string">&quot;..&quot;</span>);</span><br><span class="line">            <span class="keyword">if</span> (j + <span class="number">1</span> != level) <span class="built_in">printf</span>(<span class="string">&quot; &quot;</span>);</span><br><span class="line">        &#125;</span><br><span class="line">        <span class="built_in">printf</span>(<span class="string">&quot;%d: pte %p pa %p\n&quot;</span>, i, pte, child);</span><br><span class="line">        <span class="keyword">if</span> ((pte &amp; (PTE_R|PTE_W|PTE_X)) == <span class="number">0</span>) &#123;</span><br><span class="line">            _vmprint((<span class="type">pagetable_t</span>)child, level + <span class="number">1</span>);</span><br><span class="line">        &#125;</span><br><span class="line">    &#125;</span><br><span class="line">  &#125;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><h2 id="a-kernel-page-table-per-process">A kernel page table per process</h2><p>这个环节又涉及到一个新的问题：用户级页表和内核级页表。</p><p>在上个环节我们已经知道，页表的主要功能是为了控制访问权限。那么理论上来说，每个用户级进程都要有自己的页表，从而只能访问自己的地址，不能访问别人的地址。内核既然是随便访问，那么用一个固定不变的内核页表就好了。所以在用户级进程与系统调用之间涉及到一个页表切换的问题。</p><p>kernel page table per process要求我们实现一个新功能，内核页表不再是全局的，而是每个进程独立的。这个任务分成两步：建立每个进程自己的内核页表、实现页表切换。</p><p>首先看一下内核页表在哪里进行了改动，我只发现了两个地方，一个是kvminit()，将低地址的IO device和trampoline映射到页表中。第二个是procinit()，这里声明了进程池中每个进程的栈空间，并跟内核页表进行绑定，方便内核从每个进程的栈中获取参数之类的。</p><p>那这个建立每个进程自己的内核页表就很简单的，先把kvminit()里面的东西都映射到每个进程的内核页表中，再绑定自己进程的栈就可以了。</p><p>实现页表切换分为两步，将页表写入satp寄存器，然后用sfence.vma刷新TLB，所以页表命中是完全靠硬件，页表缺失会引起中断，需要操作系统跟硬件的配合。在scheduler()中就可以进行页表的切换，需要注意的是，进程的内核部分运行结束后要立刻把页表切换成全局的内核页表。</p><h2 id="simplify-copyincopyinstr">Simplify copyin/copyinstr</h2><p>经过上一轮的准备之后，我们可以简化copyin/copyinstr的流程。copyin是一个系统调用，要把进程中内存的一些东西copy到内核中，这里的问题在于，系统调用传入的地址是一个虚拟地址，在没有进程中的用户页表的情况下，我们无从得知该虚拟地址对应的物理地址的位置。所以在xv6中，copyin的函数签名为：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">int</span> <span class="title function_">copyin</span><span class="params">(<span class="type">pagetable_t</span> pagetable, <span class="type">char</span> *dst, uint64 srcva, uint64 len)</span></span><br></pre></td></tr></table></figure><p>这里问题的关键是，因为内核页表是全局的，我们无法修改内核页表，让硬件帮助我们完成虚拟地址到物理地址的转换，所以要先进行一步软件层面的转换，然后通过内核页表的直接映射的性质去完成系统调用。当然，我们在上一个环节实现了每个进程自己维护的内核页表，所以我们可以将虚拟地址的映射加入到自己维护的内核页表中。</p><p>这里为了保持一致性，要小心的修改，在fork(), exec()和sbrk()的函数中，都有对虚拟地址映射的修改，主要增加或是改变对应内存大小。</p><p>这里其实有个问题，就是内核页表是采用直接映射的方式。也就是在映射用户的虚拟地址之前，低位其实已经映射一些内容了，这就会导致remap。一方面要调整我们的函数去忽略remap，另一方面lab指导中告诉我们，PLIC寄存器之前的空间都可以随便映射。</p><h1 id="lab-4-traps">Lab-4 Traps</h1><p>该实验系统的总结了trap的流程，其中大致的流程我们在上一篇文章lab2中都已经叙述过了，通过lab4，我们会对整个流程理解的更加深刻。</p><p>trap来自三种情况：syscall对ecall的调用、程序运行的错误、设备中断（比如时钟）。当中断发生时，我们会自动运行一段运行好的程序，当然，kernel中断跟user中断运行的程序应该是不同的，所以要用一个寄存器储存这段程序的位置。user中断的过程会相对复杂，因为user process会发生各种事情，所以我们就从user中断讲起。user中断运行的位置就是trampoline中的uservec。</p><p>uservec的功能我们lab2中已经分析过了，主要是切换页表跟保护现场两个功能。因为要使得切换页表后，trampoline中的程序在用户态和内核态都能运行，所以trampoline在内核页表和用户页表中要完全相同。</p><p>之后在usertrap中判断trap的类型，systemcall就去准备systemcall，时钟中断就去修改process的状态，错误就直接kill整个process，之后再调用userret将之前保护的现场还原回去以及将页表改回用户页表。</p><p>内核的trap跟用户的trap原理相似，但是一个很重要也很繁琐的细节就是，当user trap进入内核态的时候，内核也有可能继续trap。这就要求我们在进入user trap的时候，要准备好所有kernel trap可能用到的东西，以及保存kernel trap可能破坏的位置，方便kernel trap结束后进行恢复。</p><h2 id="risc-v-assembly">RISC-V assembly</h2><p>作业让我们熟悉一下RISC-V，群里大佬们推荐了一本包云岗翻译的RISC-V的教材，可以用来查询。</p><p>询问了一些简单的RISC-V的问题，但是要注意，他的编译器是默认开优化的，比如第一问：</p><blockquote><p>Which registers contain arguments to functions? For example, which register holds 13 in main's call to printf?</p></blockquote><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">  <span class="built_in">printf</span>(<span class="string">&quot;%d %d\n&quot;</span>, f(<span class="number">8</span>)+<span class="number">1</span>, <span class="number">13</span>);</span><br><span class="line"><span class="number">24</span>:<span class="number">4635</span>                lia2,<span class="number">13</span></span><br><span class="line"><span class="number">26</span>:<span class="number">45b</span>1                lia1,<span class="number">12</span></span><br><span class="line"><span class="number">28</span>:<span class="number">00000517</span>          auipca0,<span class="number">0x0</span></span><br><span class="line"><span class="number">2</span>c:<span class="number">7b</span>050513          addia0,a0,<span class="number">1968</span> # <span class="number">7</span>d8 &lt;<span class="built_in">malloc</span></span><br></pre></td></tr></table></figure><p>看指令，这个13可以理解是什么，但是这个12是哪来的。再一看f跟g函数，估计是太简单被直接常量折叠了，那我们就把编译选项加上-O0再来一遍。</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">printf</span>(<span class="string">&quot;%d %d\n&quot;</span>, f(<span class="number">8</span>)+<span class="number">1</span>, <span class="number">13</span>);</span><br><span class="line"><span class="number">4</span>c:<span class="number">4521</span>                lia0,<span class="number">8</span></span><br><span class="line"><span class="number">4</span>e:<span class="number">00000097</span>          auipcra,<span class="number">0x0</span></span><br><span class="line"><span class="number">52</span>:fce080e7          jalr<span class="number">-50</span>(ra) # <span class="number">1</span>c &lt;f&gt;</span><br><span class="line"><span class="number">56</span>:<span class="number">87</span>aa                mva5,a0</span><br><span class="line"><span class="number">58</span>:<span class="number">2785</span>                addiwa5,a5,<span class="number">1</span></span><br><span class="line"><span class="number">5</span>a:<span class="number">2781</span>                sext.wa5,a5</span><br><span class="line"><span class="number">5</span>c:<span class="number">4635</span>                lia2,<span class="number">13</span></span><br><span class="line"><span class="number">5</span>e:<span class="number">85b</span>e                mva1,a5</span><br><span class="line"><span class="number">60</span>:<span class="number">00001517</span>          auipca0,<span class="number">0x1</span></span><br><span class="line"><span class="number">64</span>:d0850513          addia0,a0,<span class="number">-760</span> # d68 &lt;<span class="built_in">malloc</span>+<span class="number">0x13e</span>&gt;</span><br><span class="line"><span class="number">68</span>:<span class="number">00001097</span>          auipcra,<span class="number">0x1</span></span><br><span class="line"><span class="number">6</span>c:<span class="number">9</span>d0080e7          jalr<span class="number">-1584</span>(ra) # a38 &lt;<span class="built_in">printf</span>&gt;</span><br></pre></td></tr></table></figure><p>可以比较清楚的看出来，第一个f(8)+1在a1寄存器中，第二个13在a2寄存器中。</p><h2 id="backtrace">Backtrace</h2><p>这个任务就是输出调用栈信息，这些任务都存在stack中，stack结构在slide中可以看到：</p><p><img src="/2020/10/13/mit-os/3-1.PNG"></p><p>其中初始的frame-pointer值，实验指导已经给出了代码。那问题就在于什么时候停止往前回溯。其实只要frame-pointer等于PGROUNDUP(fp)就可以跳出循环了。因为最浅层的那个调用既没有frame-pointer，也没有return address。</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">void</span></span><br><span class="line"><span class="title function_">backtrace</span><span class="params">(<span class="type">void</span>)</span></span><br><span class="line">&#123;</span><br><span class="line">  <span class="built_in">printf</span>(<span class="string">&quot;backtrace:\n&quot;</span>);</span><br><span class="line">  uint64 fp = r_fp();</span><br><span class="line">  uint64 upper_bound = PGROUNDUP(fp);</span><br><span class="line">  uint64 ra;</span><br><span class="line">  <span class="keyword">while</span> (<span class="number">1</span>)</span><br><span class="line">  &#123;</span><br><span class="line">    ra = *((uint64*)(fp - <span class="number">8</span>));</span><br><span class="line">    <span class="built_in">printf</span>(<span class="string">&quot;%p\n&quot;</span>, ra);</span><br><span class="line">    fp = *((uint64*)(fp - <span class="number">16</span>));</span><br><span class="line">    <span class="keyword">if</span> (fp == upper_bound) <span class="keyword">break</span>;</span><br><span class="line">  &#125;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><h2 id="alarm">Alarm</h2><p>这个实验难度较大，让我们实现一个syscall，可以监控cpu在该进程上的运行时间。单独监控运行时间其实比较简单，大致思路就是在process里面加个变量，当产生时间片中断时，就把这个变量++，然后输出。但问题是，实验让我们完成的函数是sigalarm(int ticks, void (*handler)())这种形式的。每次经过ticks时间片，就调用一下handler函数。</p><p>我们根据实验材料中的提示，可以简单得到这样一个思路：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// give up the CPU if this is a timer interrupt.</span></span><br><span class="line"><span class="keyword">if</span>(which_dev == <span class="number">2</span>)</span><br><span class="line">&#123;</span><br><span class="line">p-&gt;timer ++;</span><br><span class="line"><span class="keyword">if</span> (p-&gt;timer == p-&gt;ticks) &#123;</span><br><span class="line">    p-&gt;timer = <span class="number">0</span>;</span><br><span class="line">    p-&gt;trapframe-&gt;epc = (uint64)p-&gt;handler;</span><br><span class="line">&#125;</span><br><span class="line">yield();</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>这样当调用usertrapret()之后，会自动运行handler位置的程序。但是当我们运行test程序之后，发现第一次监控输出之后，后面的程序会发生混乱。</p><p>简单分析一下，可以发现，我们回到handler之后，并没有再回到原本process运行的位置。这样就会产生补救的思路，先把trapframe存起来，然后当handler运行结束后，通过另一个系统调用，将trapframe恢复：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">p-&gt;timer ++;</span><br><span class="line"><span class="keyword">if</span> (p-&gt;timer == p-&gt;ticks) &#123;</span><br><span class="line">    p-&gt;timer = <span class="number">0</span>;</span><br><span class="line">    <span class="keyword">if</span> (p-&gt;ishandle == <span class="number">0</span>)</span><br><span class="line">    &#123;</span><br><span class="line">        p-&gt;ishandle = <span class="number">1</span>;</span><br><span class="line">        memmove(p-&gt;dump_trapframe, p-&gt;trapframe, <span class="keyword">sizeof</span>(<span class="keyword">struct</span> trapframe));</span><br><span class="line">        p-&gt;trapframe-&gt;epc = (uint64)p-&gt;handler;</span><br><span class="line">    &#125;</span><br><span class="line">&#125;</span><br><span class="line"></span><br><span class="line">uint64</span><br><span class="line"><span class="title function_">sys_sigreturn</span><span class="params">(<span class="type">void</span>)</span></span><br><span class="line">&#123;</span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">proc</span> *<span class="title">p</span> =</span> myproc();</span><br><span class="line">  p-&gt;ishandle = <span class="number">0</span>;</span><br><span class="line">  memmove(p-&gt;trapframe, p-&gt;dump_trapframe, <span class="keyword">sizeof</span>(<span class="keyword">struct</span> trapframe));</span><br><span class="line">  <span class="keyword">return</span> <span class="number">0</span>;  </span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>因为有个slow_handler()测试点，暗示我们handler运行的过程中，计时功能应该关闭，所以我们要加一个p-&gt;ishandle变量来把门。</p><h1 id="lab-5-lazy-page-allocation">Lab-5 Lazy Page Allocation</h1><p>该实验主要是对sbrk的修改，就是当一个process要求操作系统分配更多内存给他的时候，并不直接进行分配，而是当process真的访问到该内存的时候，再进行分配，整个过程都是由缺页中断进行驱动的。</p><h2 id="eliminate-allocation-from-sbrk">Eliminate allocation from sbrk()</h2><p>非常简单，只要修改p-&gt;sz就行了，不需要进行任何其他操作。</p><h2 id="lazy-allocation">Lazy allocation</h2><p>通过这个case其实也非常简单，我感觉怎么写都能过，我在写这个的时候错了好几个地方，都成功通过echo hi了。</p><p>原理也比较简单，就是在缺页中断的时候调用一下mappages。</p><h2 id="lazytests-and-usertests">Lazytests and Usertests</h2><p>这个任务要求我们通过修改上一个任务的代码，从而通过两个非常复杂的程序。这里我遇到了两个非常坑爹的地方。</p><p>首先是lazytests里面的oom，我总是在oom之前就发生p-&gt;sz整数溢出。发现p-&gt;sz是uint64类型的，理论不应该溢出才是。</p><p>观察一下原有的sbrk的写法：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">int</span> addr;</span><br><span class="line"><span class="type">int</span> n;</span><br><span class="line"><span class="keyword">if</span>(argint(<span class="number">0</span>, &amp;n) &lt; <span class="number">0</span>)</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">addr = myproc()-&gt;sz;</span><br><span class="line"><span class="keyword">if</span>(growproc(n) &lt; <span class="number">0</span>)</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line"><span class="keyword">return</span> addr;</span><br></pre></td></tr></table></figure><p>这里，大部分人第一次都会去用myproc()-&gt;sz=addr+n来修改sbrk，然而这个位置会导致整形溢出，也就是他默认的数据类型会溢出，非常坑爹！</p><p>还有一个地方，更加恶心。就是usertests的sbrkargs，我跑了所有其他case都能通过，就这个case无法通过。</p><p>经过细致的排查，发现sbrkarg的特殊之处在于，他是在copyin这个函数里面软查找页表的时候产生缺页的，不会产生缺页中断，但是会导致写入失败。所以要在walkaddr里面也处理一下lazy alloc的情况。</p><p>也就是说，不是所有缺页都是通过缺页中断驱动的，真是令人作呕的设计啊...</p><h1 id="lab-6-copy-on-write-fork-for-xv6">Lab-6 Copy-on-Write Fork for xv6</h1><p>下面是页表的最后一个实验，实现xv6 fork的COW机制。</p><p>超级恶心，写了我整整一天，各种坑点非常多，而且难以debug。这里建议先从给页表进行引用计数下手，一个显而易见的思路是，将kmem改成如下这样：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">struct</span> &#123;</span></span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">spinlock</span> <span class="title">lock</span>;</span></span><br><span class="line">  <span class="type">int</span> ref_cnt[PHYSTOP / PGSIZE];</span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">run</span> *<span class="title">freelist</span>;</span></span><br><span class="line">&#125; kmem;</span><br></pre></td></tr></table></figure><p>可能有人会疑惑这个引用数组到底开到了RAM的什么地方？其实这个数组会在link的时候开到RAM上，并通过linker算出extern char end[]的位置，所以只要不是太大，并不会影响到整个程序的运行。这个其实就相当于一个无碰撞的hash表，如果想节省内存，还可以搞点复杂的hash表动态扩容啥的，然而嫌麻烦，就没有搞。然后这个freelist记录的就是所有引用计数为0的page了。</p><p>后面就需要注意uvmcopy以及缺页中断的情况了。在实验的提示中，他说可以利用RISC-V的RSW位，然而我并没有用这个东西，最后也成功通过了所有的case，核心函数如下：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">int</span></span><br><span class="line"><span class="title function_">fix_cow</span><span class="params">(<span class="type">pagetable_t</span> pagetable, uint64 va)</span></span><br><span class="line">&#123;</span><br><span class="line">  uint64 pa;</span><br><span class="line">  <span class="keyword">if</span> ((pa = walkaddr(pagetable, va)) == <span class="number">0</span>) &#123;</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="type">char</span> *mem;</span><br><span class="line">  <span class="keyword">if</span> ((mem = kalloc()) == <span class="number">0</span>) &#123;</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">  &#125;</span><br><span class="line">  memmove(mem, (<span class="type">char</span>*)pa, PGSIZE);</span><br><span class="line">  uvmunmap(pagetable, va, <span class="number">1</span>, <span class="number">1</span>);</span><br><span class="line">  <span class="keyword">if</span>(mappages(pagetable, va, PGSIZE, (uint64)mem, PTE_W|PTE_X|PTE_R|PTE_U) != <span class="number">0</span>)&#123;</span><br><span class="line">    kfree(mem);</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="keyword">return</span> <span class="number">0</span>;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>我最开始的时候把前两个if写反了，导致countfree怎么都通过不了，非常坚硬，研究了半天。要小心谨慎的维护page，基本所有错误都是page的alloc和free导致的。</p><h1 id="lab-7-multithreading">Lab-7 Multithreading</h1><p>教材这一部分有点难读，大量内容都在解释有关xv6内部内核线程同步的代码，有些烧脑。幸好作业部分非常简单，完成的非常顺利。</p><h2 id="uthread-switching-between-threads">Uthread: switching between threads</h2><p>这个任务是给xv6添加用户线程，其实跟内核线程的添加方式几乎完全相同，因为没有让我们实现锁机制，所以整体的思路非常简单。</p><p>最核心的地方就是切换thread运行的context，这个地方采用跟内核线程一样的方式，只要保存返回地址、栈指针、callee-saved就可以了。</p><p>通过阅读xv6教材，我的理解是，其实内核没有进程的概念，只有用户有进程的概念。只是用户进程trap进入内核态之后，有个内核线程与该进程绑定而已。</p><h2 id="using-threads">Using threads</h2><p>这个任务也非常简单，给一个最简单的开散列hash表加锁，并稍微优化一下性能。</p><p>最简单的思路就是每个bucket都加个锁，恰好能达到要求的1.25x的加速：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">static</span> </span><br><span class="line"><span class="type">void</span> <span class="title function_">put</span><span class="params">(<span class="type">int</span> key, <span class="type">int</span> value)</span></span><br><span class="line">&#123;</span><br><span class="line">  <span class="type">int</span> i = key % NBUCKET;</span><br><span class="line"></span><br><span class="line">  <span class="comment">// is the key already present?</span></span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">entry</span> *<span class="title">e</span> =</span> <span class="number">0</span>;</span><br><span class="line">  pthread_mutex_lock(&amp;lock[i]);       <span class="comment">// acquire lock</span></span><br><span class="line">  <span class="keyword">for</span> (e = table[i]; e != <span class="number">0</span>; e = e-&gt;next) &#123;</span><br><span class="line">    <span class="keyword">if</span> (e-&gt;key == key)</span><br><span class="line">      <span class="keyword">break</span>;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="keyword">if</span>(e)&#123;</span><br><span class="line">    <span class="comment">// update the existing key.</span></span><br><span class="line">    e-&gt;value = value;</span><br><span class="line">  &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">    <span class="comment">// the new is new.</span></span><br><span class="line">    insert(key, value, &amp;table[i], table[i]);</span><br><span class="line">  &#125;</span><br><span class="line">  pthread_mutex_unlock(&amp;lock[i]);     <span class="comment">// release lock</span></span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><h2 id="barrier">Barrier</h2><p>barrier指的就是所有线程都执行到此步之后，在进行后续的程序执行。实际上是对xv6的sleep和wakeup机制的复习。</p><p>整体思路也比较简单，只要加锁就好了：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">static</span> <span class="type">void</span> </span><br><span class="line"><span class="title function_">barrier</span><span class="params">()</span></span><br><span class="line">&#123;</span><br><span class="line">  pthread_mutex_lock(&amp;bstate.barrier_mutex);</span><br><span class="line">  bstate.nthread ++;</span><br><span class="line">  <span class="keyword">if</span> (bstate.nthread != nthread) &#123;</span><br><span class="line">      pthread_cond_wait(&amp;bstate.barrier_cond, &amp;bstate.barrier_mutex);</span><br><span class="line">  &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">      bstate.nthread = <span class="number">0</span>;</span><br><span class="line">      bstate.round ++;</span><br><span class="line">      pthread_cond_broadcast(&amp;bstate.barrier_cond);</span><br><span class="line">  &#125;</span><br><span class="line">  pthread_mutex_unlock(&amp;bstate.barrier_mutex);</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><h1 id="lab-8-locks">Lab-8 Locks</h1><p>教材中主要讲解了xv6系统中各种锁的设计理念。在xv6中，内部锁是由_sync_语句来保证原子性的，感觉这种思路确实相对比较简单。</p><p>因为CPU会进行乱序发射之类的指令顺序的优化，所以在很多情况下，需要显式的加上__sync_synchronize()，防止CPU进行跨越该语句的乱序发射。</p><p>两个任务具有一定的相似性，一个是优化buffer cache的锁，一个是优化kalloc的锁。我采用的方案也比较相似，在kalloc中，每个CPU弄一个独立的链表，在buffer cache中，每个blockbnum的hash值弄一个独立的链表。</p><p>在刚刚开始进行kalloc实验的时候，我走进了一个误区，以为freerange是在每个CPU上面执行的，这里保留一下学习群里的聊天记录，感谢群里的老哥们：</p><p><img src="/2020/10/13/mit-os/5-1.PNG"></p><h1 id="lab-9-file-system">Lab-9 File System</h1><p>感觉教材阅读起来比较困难，因为整个 File System 是分成 7 层层层抽象来构建的，我边参考代码边读了差不多三四遍，才完全把整个结构理解。整个层级图如下图所示：</p><p><img src="/2020/10/13/mit-os/5-2.PNG"></p><h2 id="large-files">Large files</h2><p>这个任务的难度不是太大，正常的文件 innode 结构是这样的：</p><p><img src="/2020/10/13/mit-os/5-3.PNG"></p><p>这样的二级结构使得每个文件只能支持 12+256 个 block，可以变成三级结构，使得整个文件能达到 11+256+256*256 个 block。只要小心的修改 itrunc 和 bmap 函数就可以了。</p><h2 id="symbolic-links">Symbolic links</h2><p>这个任务的难点在于要清晰的了解他想让你实现的 Symbolic links 行为到底是什么样子的。当 open 不带有 O_NOFOLLOW 标记时，需要一直向下找到 hard link 的位置。整个 symbolic link 的内容都要放在对应的 inode 上面，实际上只需要自己去将文件路径和路径长度保存在 inode 里面就可以了。</p><p>open 的修改部分： <figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span>((omode &amp; O_NOFOLLOW) == <span class="number">0</span>)&#123;</span><br><span class="line">    <span class="type">char</span> s[MAXPATH];</span><br><span class="line">    <span class="type">int</span> cnt = <span class="number">0</span>, length;</span><br><span class="line">    <span class="keyword">while</span>(cnt &lt; <span class="number">10</span> &amp;&amp; ip-&gt;type == T_SYMLINK)&#123;</span><br><span class="line">        cnt ++;</span><br><span class="line">        readi(ip, <span class="number">0</span>, (uint64)&amp;length, <span class="number">0</span>, <span class="number">4</span>);</span><br><span class="line">        readi(ip, <span class="number">0</span>, (uint64)s, <span class="number">4</span>, length);</span><br><span class="line">        s[length] = <span class="number">0</span>;</span><br><span class="line">        iunlockput(ip);</span><br><span class="line">        <span class="keyword">if</span>((ip = namei(s)) == <span class="number">0</span>)&#123;</span><br><span class="line">            end_op();</span><br><span class="line">            <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">        &#125;</span><br><span class="line">        ilock(ip);</span><br><span class="line">    &#125;</span><br><span class="line">    <span class="keyword">if</span>(cnt &gt;= <span class="number">10</span> &amp;&amp; ip-&gt;type == T_SYMLINK)&#123;</span><br><span class="line">        iunlockput(ip);</span><br><span class="line">        end_op();</span><br><span class="line">        <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">    &#125;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure></p><p>sym_link: <figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">int</span></span><br><span class="line"><span class="title function_">len</span><span class="params">(<span class="type">char</span> *s)</span></span><br><span class="line">&#123;</span><br><span class="line">    <span class="type">int</span> cnt;</span><br><span class="line">    <span class="keyword">for</span>(cnt = <span class="number">0</span>; cnt &lt; MAXPATH; cnt ++)&#123;</span><br><span class="line">        <span class="keyword">if</span>(s[cnt] == <span class="number">0</span>) <span class="keyword">break</span>;</span><br><span class="line">    &#125;</span><br><span class="line">    <span class="keyword">return</span> cnt;</span><br><span class="line">&#125;</span><br><span class="line"></span><br><span class="line">uint64</span><br><span class="line"><span class="title function_">sys_symlink</span><span class="params">(<span class="type">void</span>)</span></span><br><span class="line">&#123;</span><br><span class="line">    <span class="type">char</span> target[MAXPATH], path[MAXPATH];</span><br><span class="line">    <span class="class"><span class="keyword">struct</span> <span class="title">inode</span> *<span class="title">ip</span>;</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">if</span>(argstr(<span class="number">0</span>, target, MAXPATH) &lt; <span class="number">0</span> || argstr(<span class="number">1</span>, path, MAXPATH) &lt; <span class="number">0</span>)</span><br><span class="line">        <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line"></span><br><span class="line">    begin_op();</span><br><span class="line">    <span class="keyword">if</span>((ip = create(path, T_SYMLINK, <span class="number">0</span>, <span class="number">0</span>)) == <span class="number">0</span>)&#123;</span><br><span class="line">        end_op();</span><br><span class="line">        <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">    &#125;</span><br><span class="line">    <span class="type">int</span> length = len(target);</span><br><span class="line">    writei(ip, <span class="number">0</span>, (uint64)&amp;length, <span class="number">0</span>, <span class="number">4</span>);</span><br><span class="line">    writei(ip, <span class="number">0</span>, (uint64)target, <span class="number">4</span>, length + <span class="number">1</span>);</span><br><span class="line">    iunlockput(ip);</span><br><span class="line">    end_op();</span><br><span class="line">    <span class="keyword">return</span> <span class="number">0</span>;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure></p><h1 id="lab-10-mmap">Lab-10 Mmap</h1><p>这个实验的难点在于，mmap 要实现成 lazy alloc 的形式，所以要对页表进行大量的操作，然而做到 lab-10 可能前面页表的细节已经完全忘干净了（</p><p>不过缺点也在于测试样例过于简单，感觉覆盖的场景非常不完全。</p><p>对于这个 lazy alloc 的处理，我选择了一种比较取巧的方式，就是在开始的时候完全按照 lazy alloc 的那个方式来写，在最后判断的时候，多加一个 mmap 的判断：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">void</span></span><br><span class="line"><span class="title function_">mmaplazyalloc</span><span class="params">(<span class="type">int</span> va)</span></span><br><span class="line">&#123;</span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">proc</span> *<span class="title">p</span> =</span> myproc();</span><br><span class="line">  <span class="type">int</span> found = <span class="number">0</span>, i = <span class="number">0</span>;</span><br><span class="line">  <span class="keyword">for</span> (i = <span class="number">0</span>; i &lt; <span class="number">16</span>; i++) &#123;</span><br><span class="line">    <span class="keyword">if</span> (p-&gt;mmaps[i].used) &#123;</span><br><span class="line">      uint64 start = p-&gt;mmaps[i].addr;</span><br><span class="line">      uint64 end = p-&gt;mmaps[i].addr + p-&gt;mmaps[i].length;</span><br><span class="line">      <span class="keyword">if</span> (va &gt;= start &amp;&amp; va &lt; end) &#123;</span><br><span class="line">        found = <span class="number">1</span>;</span><br><span class="line">        <span class="keyword">break</span>;</span><br><span class="line">      &#125;</span><br><span class="line">    &#125;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="keyword">if</span> (found == <span class="number">0</span>)</span><br><span class="line">    <span class="keyword">return</span>;</span><br><span class="line">  </span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">file</span> *<span class="title">f</span> =</span> p-&gt;mmaps[i].f;</span><br><span class="line">  <span class="type">int</span> offset = p-&gt;mmaps[i].offset;</span><br><span class="line">  uint64 off = offset + va - p-&gt;mmaps[i].addr;</span><br><span class="line"></span><br><span class="line">  begin_op();</span><br><span class="line">  ilock(f-&gt;ip);</span><br><span class="line">  <span class="keyword">if</span> (readi(f-&gt;ip, <span class="number">1</span>, va, off, PGSIZE) &lt; <span class="number">0</span>) &#123;</span><br><span class="line">    panic(<span class="string">&quot;lazyalloc read.&quot;</span>);</span><br><span class="line">  &#125;</span><br><span class="line">  iunlock(f-&gt;ip);</span><br><span class="line">  end_op();</span><br><span class="line">&#125;</span><br><span class="line"></span><br><span class="line">......</span><br><span class="line">  &#125; <span class="keyword">else</span> <span class="keyword">if</span>((which_dev = devintr()) != <span class="number">0</span>)&#123;</span><br><span class="line">    <span class="comment">// ok</span></span><br><span class="line">  &#125; <span class="keyword">else</span> <span class="keyword">if</span> (r_scause() == <span class="number">13</span> || r_scause() == <span class="number">15</span>) &#123;</span><br><span class="line">    uint64 va = r_stval();</span><br><span class="line">    va = PGROUNDDOWN(va);</span><br><span class="line">    <span class="keyword">if</span> (va &gt;= p-&gt;sz || va &lt; PGROUNDUP(p-&gt;trapframe-&gt;sp)) &#123;</span><br><span class="line">      p-&gt;killed = <span class="number">1</span>;</span><br><span class="line">    &#125; <span class="keyword">else</span> <span class="keyword">if</span> (walkaddr(p-&gt;pagetable, va) == <span class="number">0</span>) &#123;</span><br><span class="line">      p-&gt;killed = <span class="number">1</span>;</span><br><span class="line">    &#125;</span><br><span class="line">    mmaplazyalloc(va);</span><br><span class="line">  &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">    <span class="built_in">printf</span>(<span class="string">&quot;usertrap(): unexpected scause %p pid=%d\n&quot;</span>, r_scause(), p-&gt;pid);</span><br><span class="line">    <span class="built_in">printf</span>(<span class="string">&quot;            sepc=%p stval=%p\n&quot;</span>, r_sepc(), r_stval());</span><br><span class="line">    p-&gt;killed = <span class="number">1</span>;</span><br><span class="line">  &#125;</span><br></pre></td></tr></table></figure><p>另一个比较麻烦的地方就是对 munmap 的处理，因为题目保证了 munmap 要么截断前面，要么截断后面，不会截断中间，所以要对这几种情况进行调整：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br></pre></td><td class="code"><pre><span class="line">uint64</span><br><span class="line"><span class="title function_">sys_mmap</span><span class="params">(<span class="type">void</span>)</span></span><br><span class="line">&#123;</span><br><span class="line">  uint64 addr;</span><br><span class="line">  <span class="type">int</span> length, prot, flags, fd, offset;</span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">file</span>* <span class="title">f</span>;</span></span><br><span class="line">  <span class="keyword">if</span> (argaddr(<span class="number">0</span>, &amp;addr) &lt; <span class="number">0</span> ||</span><br><span class="line">    argint(<span class="number">1</span>, &amp;length) &lt; <span class="number">0</span> ||</span><br><span class="line">    argint(<span class="number">2</span>, &amp;prot) &lt; <span class="number">0</span> ||</span><br><span class="line">    argint(<span class="number">3</span>, &amp;flags) &lt; <span class="number">0</span> ||</span><br><span class="line">    argfd(<span class="number">4</span>, &amp;fd, &amp;f)  &lt; <span class="number">0</span> ||</span><br><span class="line">    argint(<span class="number">5</span>, &amp;offset) &lt; <span class="number">0</span>) &#123;</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="keyword">if</span>((prot &amp; PROT_WRITE) &amp;&amp; !f-&gt;writable &amp;&amp; flags == MAP_SHARED)</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1UL</span>;</span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">proc</span> *<span class="title">p</span> =</span> myproc();</span><br><span class="line">  <span class="keyword">for</span> (<span class="type">int</span> i = <span class="number">0</span>; i &lt; <span class="number">16</span>; i ++) &#123;</span><br><span class="line">    <span class="keyword">if</span> (!p-&gt;mmaps[i].used) &#123;</span><br><span class="line">      <span class="class"><span class="keyword">struct</span> <span class="title">vma</span>* <span class="title">v</span> =</span> &amp;p-&gt;mmaps[i];</span><br><span class="line">      v-&gt;addr = p-&gt;sz;</span><br><span class="line">      v-&gt;length = length;</span><br><span class="line">      v-&gt;f = f;</span><br><span class="line">      v-&gt;prot = prot;</span><br><span class="line">      v-&gt;used = <span class="number">1</span>;</span><br><span class="line">      v-&gt;flags = flags;</span><br><span class="line">      v-&gt;offset = offset;</span><br><span class="line">      p-&gt;sz = p-&gt;sz + length;</span><br><span class="line">      filedup(f);</span><br><span class="line">      <span class="keyword">return</span> v-&gt;addr;</span><br><span class="line">    &#125;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">&#125;</span><br><span class="line"></span><br><span class="line">uint64</span><br><span class="line"><span class="title function_">sys_munmap</span><span class="params">(<span class="type">void</span>)</span></span><br><span class="line">&#123;</span><br><span class="line">  uint64 addr;</span><br><span class="line">  <span class="type">int</span> length;</span><br><span class="line">  <span class="type">int</span> found = <span class="number">0</span>, i = <span class="number">0</span>;</span><br><span class="line">  <span class="keyword">if</span> (argaddr(<span class="number">0</span>, &amp;addr) &lt; <span class="number">0</span> ||</span><br><span class="line">    argint(<span class="number">1</span>, &amp;length) &lt; <span class="number">0</span>) &#123;</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">proc</span> *<span class="title">p</span> =</span> myproc();</span><br><span class="line">  <span class="class"><span class="keyword">struct</span> <span class="title">vma</span>* <span class="title">v</span> =</span> <span class="number">0</span>;</span><br><span class="line">  <span class="keyword">for</span> (i = <span class="number">0</span>; i &lt; <span class="number">16</span>; i++) &#123;</span><br><span class="line">    <span class="keyword">if</span> (p-&gt;mmaps[i].used) &#123;</span><br><span class="line">      uint64 start = p-&gt;mmaps[i].addr;</span><br><span class="line">      uint64 end = p-&gt;mmaps[i].addr + p-&gt;mmaps[i].length;</span><br><span class="line">      <span class="keyword">if</span> (addr &gt;= start &amp;&amp; addr &lt; end) &#123;</span><br><span class="line">        found = <span class="number">1</span>;</span><br><span class="line">        v = &amp;p-&gt;mmaps[i];</span><br><span class="line">        <span class="keyword">break</span>;</span><br><span class="line">      &#125;</span><br><span class="line">    &#125;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="keyword">if</span> (found == <span class="number">0</span>)</span><br><span class="line">    <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">  </span><br><span class="line">  <span class="keyword">for</span> (<span class="type">int</span> offset = <span class="number">0</span>; offset &lt; length; offset += PGSIZE) &#123;</span><br><span class="line">    uint64 a = addr + offset;</span><br><span class="line">    <span class="keyword">if</span> (p-&gt;mmaps[i].flags &amp; MAP_SHARED) &#123;</span><br><span class="line">        <span class="type">pte_t</span> *pte = walk(p-&gt;pagetable, a, <span class="number">0</span>);</span><br><span class="line">        <span class="keyword">if</span> (pte != <span class="number">0</span> &amp;&amp; (*pte &amp; PTE_V) != <span class="number">0</span>) &#123;</span><br><span class="line">            begin_op();</span><br><span class="line">            ilock(v-&gt;f-&gt;ip);</span><br><span class="line">            writei(v-&gt;f-&gt;ip, <span class="number">1</span>, a, a - v-&gt;addr + v-&gt;offset, PGSIZE);</span><br><span class="line">            iunlock(v-&gt;f-&gt;ip);</span><br><span class="line">            end_op();</span><br><span class="line">        &#125;</span><br><span class="line">    &#125;</span><br><span class="line">    uvmunmap(p-&gt;pagetable, a, <span class="number">1</span>, <span class="number">1</span>);</span><br><span class="line">  &#125;</span><br><span class="line">  </span><br><span class="line">  <span class="keyword">if</span> (addr == v-&gt;addr) &#123;</span><br><span class="line">    <span class="keyword">if</span> (length == v-&gt;length) &#123;</span><br><span class="line">      v-&gt;used = <span class="number">0</span>;</span><br><span class="line">      fileclose(v-&gt;f);</span><br><span class="line">    &#125; <span class="keyword">else</span> <span class="keyword">if</span> (addr + length &gt; v-&gt;addr + v-&gt;length) &#123;</span><br><span class="line">      panic(<span class="string">&quot;munmap: wrong&quot;</span>);</span><br><span class="line">    &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">      v-&gt;addr = addr + length;</span><br><span class="line">      v-&gt;offset = v-&gt;offset + length;</span><br><span class="line">      v-&gt;length -= length;</span><br><span class="line">    &#125;</span><br><span class="line">  &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">    <span class="keyword">if</span> (addr + length == v-&gt;addr + v-&gt;length) &#123;</span><br><span class="line">      v-&gt;length = addr - v-&gt;addr;</span><br><span class="line">    &#125; <span class="keyword">else</span> &#123;</span><br><span class="line">      panic(<span class="string">&quot;munmap: wrong&quot;</span>);</span><br><span class="line">    &#125;</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="keyword">return</span> <span class="number">0</span>;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>需要注意的是，我们在 writei 的时候，截断的 offset 要补入 p-&gt;mmaps[i] 的 offset 中，这样才能找到正确的 offset。</p><h1 id="lab-11-networking">Lab-11 Networking</h1><p>最后一个实验了，非常简单，就是让我们实现 e1000_transmit 和 e1000_recv 这两个函数。看起来仿佛非常复杂，实际上按照 hint 模拟就好了。</p><p>但是有个地方需要注意，在 hint 中， e1000_recv 实现方式如下：</p><blockquote><p>Some hints for implementing e1000_recv: First ask the E1000 for the ring index at which the next waiting received packet (if any) is located, by fetching the E1000_RDT control register and adding one modulo RX_RING_SIZE. Then check if a new packet is available by checking for the E1000_RXD_STAT_DD bit in the status portion of the descriptor. If not, stop. Otherwise, update the mbuf's m-&gt;len to the length reported in the descriptor. Deliver the mbuf to the network stack using net_rx(). Then allocate a new mbuf using mbufalloc() to replace the one just given to net_rx(). Program its data pointer (m-&gt;head) into the descriptor. Clear the descriptor's status bits to zero. Finally, update the E1000_RDT register to be the index of the last ring descriptor processed. e1000_init() initializes the RX ring with mbufs, and you'll want to look at how it does that and perhaps borrow code. At some point the total number of packets that have ever arrived will exceed the ring size (16); make sure your code can handle that.</p></blockquote><p>但是在实际实现中，net_rx要先release锁之后才能去运行，不然就会panic：</p><figure class="highlight c"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">static</span> <span class="type">void</span></span><br><span class="line"><span class="title function_">e1000_recv</span><span class="params">(<span class="type">void</span>)</span></span><br><span class="line">&#123;</span><br><span class="line">  <span class="keyword">while</span> (<span class="number">1</span>) &#123;</span><br><span class="line">    acquire(&amp;e1000_lock);</span><br><span class="line">    uint32 idx = (regs[E1000_RDT] + <span class="number">1</span>) % RX_RING_SIZE;</span><br><span class="line">    <span class="class"><span class="keyword">struct</span> <span class="title">rx_desc</span> *<span class="title">desc</span> =</span> &amp;rx_ring[idx];</span><br><span class="line">    <span class="keyword">if</span> (!(desc-&gt;status &amp; E1000_RXD_STAT_DD)) &#123;</span><br><span class="line">        release(&amp;e1000_lock);</span><br><span class="line">        <span class="keyword">break</span>;</span><br><span class="line">    &#125;</span><br><span class="line">    rx_mbufs[idx]-&gt;len = desc-&gt;length;</span><br><span class="line">    <span class="class"><span class="keyword">struct</span> <span class="title">mbuf</span> *<span class="title">m</span> =</span> rx_mbufs[idx];</span><br><span class="line">    rx_mbufs[idx] = mbufalloc(<span class="number">0</span>);</span><br><span class="line">    desc-&gt;addr = (uint64) rx_mbufs[idx]-&gt;head;</span><br><span class="line">    desc-&gt;status = <span class="number">0</span>;</span><br><span class="line">    regs[E1000_RDT] = idx;</span><br><span class="line">    release(&amp;e1000_lock);</span><br><span class="line">    net_rx(m);</span><br><span class="line">  &#125;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure>]]></content>
    
    
    <summary type="html">&lt;p&gt;在几个月前，试图开了一下MIT 6.828的坑，但是因为各种原因，只做了lab1就搁置了。前几天突然在知乎上看到了&lt;a href=&quot;https://zhuanlan.zhihu.com/p/251366985&quot;&gt;二十八画生征友：一起来通关6.S081/6.828吧~&lt;/a&gt;，发现这个MIT 6.S081是MIT 6.828的简化版，而且梯度更加的平滑。&lt;/p&gt;</summary>
    
    
    
    <category term="操作系统" scheme="https://wyc-ruiker.github.io/categories/%E6%93%8D%E4%BD%9C%E7%B3%BB%E7%BB%9F/"/>
    
    
    <category term="OS" scheme="https://wyc-ruiker.github.io/tags/OS/"/>
    
  </entry>
  
  <entry>
    <title>GNN for Source Code Modeling（三）</title>
    <link href="https://wyc-ruiker.github.io/2020/03/09/gnn-for-source-code-modeling-3/"/>
    <id>https://wyc-ruiker.github.io/2020/03/09/gnn-for-source-code-modeling-3/</id>
    <published>2020-03-09T12:51:24.000Z</published>
    <updated>2021-12-16T11:28:06.000Z</updated>
    
    <content type="html"><![CDATA[<p>实际上 GNN 在 Source Code 上的应用和创新还有很多。之前的两篇文章都是关于 GNN 建图以及 GNN 跟其他任务相结合的工作，这篇文章就讲一下对 GNN 本身的创新。</p><span id="more"></span><p><a href="https://arxiv.org/abs/1904.12787">Graph Matching Networks for Learning the Similarity of Graph Structured Objects. ICML 2019</a> 就是一个利用 GNN 解决二进制函数相似性问题的工作。这个问题也是 Source Code Modeling 领域的经典问题之一。 二进制函数相似性是一个在信息安全领域应用很广泛的问题。因为很多软件都是不开源的，放在用户电脑上的只能是一些二进制代码。因为编译器、编译选项以及平台的不同，同一个函数的二进制代码也经常是不同的。如果一个函数被检查出了安全漏洞，那这个函数所编译出的所有二进制代码也会有安全漏洞。这些二进制代码放在成千上万用户的电脑上，造成非常大的安全隐患，所以及时找到这些代码是非常重要的。因为二进制代码本身会有一个 Control-flow-graph，所以利用 GNN 解决二进制函数相似性就成为 GNN 的一个应用。</p><p><img src="/2020/03/09/gnn-for-source-code-modeling-3/1-1.png"></p><p>论文中给出两种方式来解决这个问题。第一种是输入一个图输出一个 embedding，通过优化这个 embedding 使得两个相似的图的 embedding 会离得更近。第二种更加直接，输入两个图，输出他们的相似性。这两种方式都是在 GNN 的基础上做的。</p><p><img src="/2020/03/09/gnn-for-source-code-modeling-3/2-1.png"></p><p>第一种方法计算 embedding 的过程没啥好说的，就是输入<span class="math inline">\((G_1, G_2)\)</span>输出<span class="math inline">\((h_{G_1},h_{G_2})\)</span>。重点在于最后的 loss 设计，因为跟第二种可以共用相同的 loss，所以放到一起最后讲。 对于第二种方法，大多数步骤跟 GNN 都是一样的。主要区别在于在聚合邻居信息的时候，也要聚合另一张图的信息 <span class="math inline">\(\mu\)</span>：</p><p><img src="/2020/03/09/gnn-for-source-code-modeling-3/3-1.png"></p><p>其中<span class="math inline">\(\mu\)</span>的计算方式如下，上面的<span class="math inline">\(f_s\)</span>和下面的<span class="math inline">\(s_h\)</span>都是可以替换的向量相似度计算方式：</p><p><img src="/2020/03/09/gnn-for-source-code-modeling-3/4-1.png"></p><p>观察公式<span class="math inline">\((11)\)</span>可以发现，attention 值越大，两个点越相似。所以这个<span class="math inline">\(\Sigma\mu\)</span>其实就是该点的 embedding 跟另一张图最相似的点的 embedding 的差。如果两个图完全一样，那这个<span class="math inline">\(\Sigma\mu\)</span>就会一直是<span class="math inline">\(0\)</span>。所以这个 GMN 的优势就是通过一个图的表示可以更改另一个图的表示，从而捕捉两个图的不相似程度。 下面是两种方法的 loss。论文定义了两种 label，第一种是<span class="math inline">\((G_1,G_2,t)\)</span>，两个图相似<span class="math inline">\(t=1\)</span>，不相似<span class="math inline">\(t=-1\)</span>。第二种是<span class="math inline">\((G_1,G_2,G_3)\)</span>，其中<span class="math inline">\(G_1\)</span>跟<span class="math inline">\(G_2\)</span>更加相似。如果向量相似性用欧拉距离来度量，那可以用一种类似合页 loss 的方式来进行优化：</p><p><img src="/2020/03/09/gnn-for-source-code-modeling-3/5-1.png"></p><p><img src="/2020/03/09/gnn-for-source-code-modeling-3/6-1.png"></p><p>公式<span class="math inline">\((12)\)</span>可以看做当两个图相似的时候，距离应该小于<span class="math inline">\(1-\gamma\)</span>，当两个图不相似的时候，距离应该大于<span class="math inline">\(1+\gamma\)</span>。 公式<span class="math inline">\((13)\)</span>可以看做<span class="math inline">\(d(G_1,G_2)\lt d(G_1,G_3)-\gamma\)</span>。</p><p><img src="/2020/03/09/gnn-for-source-code-modeling-3/7-1.png"> 也可以按照上面的公式，用 Hamming 相似度来算，这样的好处是向量里面每一维度都是<span class="math inline">\([1,-1]\)</span>，容易在大型数据库中快速的查询。</p><p><img src="/2020/03/09/gnn-for-source-code-modeling-3/8-1.png"></p><p>可以看到，效果有一定的提升。这里的 baseline 是 Google 的一个二进制代码查询工具，利用手工构造的图 hash 来寻找相同的代码。</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;实际上 GNN 在 Source Code 上的应用和创新还有很多。之前的两篇文章都是关于 GNN 建图以及 GNN 跟其他任务相结合的工作，这篇文章就讲一下对 GNN 本身的创新。&lt;/p&gt;</summary>
    
    
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
    <category term="program" scheme="https://wyc-ruiker.github.io/tags/program/"/>
    
    <category term="graph" scheme="https://wyc-ruiker.github.io/tags/graph/"/>
    
  </entry>
  
  <entry>
    <title>GNN for Source Code Modeling（二）</title>
    <link href="https://wyc-ruiker.github.io/2020/03/08/gnn-for-source-code-modeling-2/"/>
    <id>https://wyc-ruiker.github.io/2020/03/08/gnn-for-source-code-modeling-2/</id>
    <published>2020-03-07T17:18:25.000Z</published>
    <updated>2021-12-16T11:26:15.000Z</updated>
    
    <content type="html"><![CDATA[<p>这篇讲一个基于<a href="https://reku1997.gitee.io/2020/03/07/gnn-for-source-code-modeling-1/">上一篇</a>的改进工作。 在 Source Code 中，因为程序员的变量命名通常来讲都比较诡异，所以存在着比较严重的 open vocabulary 问题（也叫作 Out of Vocabulary）。比如一个变量名叫做 LianlianInput，因为这个 Lianlian 不在词汇表里面，VARNAMING 的时候就根本不会输出这个 subtoken，对最后的效果有比较严重的影响。</p><span id="more"></span><p><a href="https://arxiv.org/abs/1810.08305">Open Vocabulary Learning on Source Code with a Graph-Structured Cache. ICML 2019</a> 这篇文章就致力于解决代码中的 open vocabulary 问题，而且是通过 GNN 来解决，与我们的主题十分契合。 整个建图过程跟之前的工作非常类似，只是多了 Graph-Structured Cache Node：</p><p><img src="/2020/03/08/gnn-for-source-code-modeling-2/5.png"></p><p>其实也是先分词，然后相同的词搞一个节点而已。初始的特征是节点名字特征跟Cache节点类型本身有个特征进行拼接。节点名字特征用 CharCNN 来计算。 建图层面还是比较简单的，这篇文章主要的贡献之处在于利用 GSC 节点来解决 open vocabulary 的问题。解决的方法在于修改 GNN 最后一步<span class="math inline">\(y=g(\{h_v^t\})\)</span>中<span class="math inline">\(g\)</span>的计算方式，灵感来自于 Pointer Network。为了方便之后理解，先简单叙述一下啥是 Pointer Network。 这个 Pointer Network 本来是做组合优化的，最开始用来解决凸包问题。凸包问题我们都知道，输入一个点集，输出一个凸包。但是这个输出的范围其实是跟输入相关的。Pointer Network 用了很简单的机制解决了这个问题：</p><p><img src="/2020/03/08/gnn-for-source-code-modeling-2/6.png"></p><p>在传统的 attention 中，都是 encoder 和 decoder 的 hidden layer 算个权重，然后组合一下 encoder 的所有权重输入到 decoder 中。这里的输出就直接把 attention 的最大权重作为其中一步的输出，并且输入到 encoder 中继续形成新的 hidden layer。 下面就看看如何利用 GSC 解决 VARMISUSE 的问题。在前一篇文章中，这个 VARMISUSE 问题解决起来其实还是比较复杂的。这篇文章用 Java 项目作为数据集。因为 Java 跟垃圾语言 Python 不一样，是一个要先声明后调用的语言。所以对单独语法槽的预测，都可以变成一个 Pointer Network 问题，指向现存的变量节点。对于 GGNN 来说，简单按照<span class="math inline">\(y=\sigma(f_1(h_v^t,h_v^0)\odot f_2(h_v^t))\)</span>来计算一下 attention 权重（其实我不太懂这个 attention 为什么跟槽的 embedding 无关，可能是因为 GGNN 论文里面的 readout attention 就是这么做的），挑其中最大的几个作为输出即可。</p><p><img src="/2020/03/08/gnn-for-source-code-modeling-2/7.png"></p><p>可以看到，加上 GSC 之后效果有一定的提升。 解决 VARNAMING 的思路跟 Pointer Sentinel Mixture Model 非常相似。Pointer Sentinel Mixture Model 就是把纯 attention 求出来的 open vocabulary 分布和正常 attention 求出来的 close vocabulary 分布加到一起，来预测要产生的序列：</p><p><img src="/2020/03/08/gnn-for-source-code-modeling-2/9.png"></p><p>这个 open vocabulary 和 close vocabulary 组合的权重利用一个名为 sentinel 的虚拟输入的 attention 值来计算。 对于 VARNAMING 来说，也是按照上一篇文章的方法建图，对于所有要命名的变量 embedding 取一个平均值，然后作为输入放到 GRU 中。close vocabulary 的分布按照正常的方式产生，Pointer Network 的 attention 权重就是把每个 GSC 或者 sentinel 的 embedding 输入到一个线性层，然后跟 hidden layer 点积一下接一个 softmax。最后的公式就是：<span class="math display">\[P(w|h)=P_{graph}(s|h)P_{graph}(w|h)+(1-P_{graph}(s|h))P_{vocab}(w|h)\]</span></p><p><img src="/2020/03/08/gnn-for-source-code-modeling-2/8.png"></p><p>可以看出，实验效果提升巨大。可能看起来还是不怎么样，但是要考虑到这个问题的难度，做成这样就不错了。</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;这篇讲一个基于&lt;a href=&quot;https://reku1997.gitee.io/2020/03/07/gnn-for-source-code-modeling-1/&quot;&gt;上一篇&lt;/a&gt;的改进工作。 在 Source Code 中，因为程序员的变量命名通常来讲都比较诡异，所以存在着比较严重的 open vocabulary 问题（也叫作 Out of Vocabulary）。比如一个变量名叫做 LianlianInput，因为这个 Lianlian 不在词汇表里面，VARNAMING 的时候就根本不会输出这个 subtoken，对最后的效果有比较严重的影响。&lt;/p&gt;</summary>
    
    
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
    <category term="program" scheme="https://wyc-ruiker.github.io/tags/program/"/>
    
    <category term="graph" scheme="https://wyc-ruiker.github.io/tags/graph/"/>
    
  </entry>
  
  <entry>
    <title>GNN for Source Code Modeling（一）</title>
    <link href="https://wyc-ruiker.github.io/2020/03/07/gnn-for-source-code-modeling-1/"/>
    <id>https://wyc-ruiker.github.io/2020/03/07/gnn-for-source-code-modeling-1/</id>
    <published>2020-03-07T10:10:17.000Z</published>
    <updated>2021-12-16T11:27:58.000Z</updated>
    
    <content type="html"><![CDATA[<p>关于 Source Code 的 learning 其实已经有很多工作了，每年的顶会中也有很多这个方面的文章。其实针对 Source Code 的 learning 可以算是 NLP 的一个子领域，因为 Source Code 本身就是就是程序员之间交流的一种语言。因为 Code 是一个结构化数据，存在着语义信息与语法信息，所以相比于自然语言来说，Source Code 是适合于 GNN 大显神威的领域。</p><span id="more"></span><p><a href="https://arxiv.org/abs/1711.00740">Learning to Represent Programs with Graphs. ICLR 2018</a> 是 GNN 在 Source Code 中运用的比较早的工作，质量也蛮高。 首先定义一下这篇文章要解决的两个问题： 第一个问题叫做 VARNAMING，就是在一段代码中有个匿名的变量，然后要通过这个变量在代码中的行为来对这个变量命名。 第二个问题叫做 VARMISUSE，类似于程序填空，预测一个程序中空缺的 token 是什么。通过这个任务，可以发现代码中一些 misuse 性质的 bug，举例如下：</p><p><img src="/2020/03/07/gnn-for-source-code-modeling-1/1.png"></p><p>当然，对于第二个任务来说，合法的解可能有好几个，就跟考试的程序填空中有好多正确答案一样。 根据文章的标题，可以想象文章的内容就是把 program 建成一个图，然后在图上面跑 GNN。文章中所采取的的 GNN 是 GGNN (Gated Graph Neural Networks)，可能是因为这个 GGNN 比较适合 program 建的图，所以很多后续工作采用的也是这个方法。 图定义为<span class="math inline">\(G=(V,E,X)\)</span>，其中<span class="math inline">\(V\)</span>是节点、<span class="math inline">\(X\)</span>是特征，<span class="math inline">\(E=(E_1,...,E_k)\)</span>是边的集合，边有<span class="math inline">\(k\)</span>种。 每个节点状态为<span class="math inline">\(h^{(v)}\)</span>，初始状态就是<span class="math inline">\(x^{v}\)</span>。跟传统的 GNN 一样，每个节点会向外发送<span class="math inline">\(k\)</span>的类型的消息<span class="math inline">\(m_k^{(v)}=f_k(h^{v})\)</span>，然后每个节点聚合邻居的消息<span class="math inline">\(\widetilde{m}^{(v)}=g(\{m_k^{(u)}\}\mid{(u,v,k)\in E})\)</span>。 特殊之处在于更新<span class="math inline">\(h^{(v)}\)</span>的方式，GGNN 用的是 GRU 单元来更新特征向量。<span class="math inline">\(h^{(v)}=GRU(\widetilde{m}^{(v)},h^{v})\)</span>。这样可以捕捉到一些较远的点对该点的影响。 在该论文中，<span class="math inline">\(f_k\)</span>是一个线性函数，<span class="math inline">\(g\)</span>是个简单的求和。 这篇论文中比较炫酷的部分是建图。首先一个程序，自然对应一个 AST，这个 AST 的叶子节点是程序中的 token，中间节点对应着 BNF 的中间节点。因为 GNN 捕捉不了树的儿子的顺序，所以要加一个 NextToken 边来把 token 都串起来：</p><p><img src="/2020/03/07/gnn-for-source-code-modeling-1/2.png"></p><p>另一方面，我们要捕捉程序的 Data-flow 信息。对于每个变量来说，上一次 read 的位置，连接一个 LastRead 边，因为有分支结构存在，这种边可能有多条。同样，上一次 write 的位置，连接一个 LastWrite 边。有赋值语句存在的时候，左右两边的语句要连接一个 ComputedFrom 边：</p><p><img src="/2020/03/07/gnn-for-source-code-modeling-1/3.png"></p><p>作者还加了很多乱七八糟的边，比如用 LastLexicalUse 来把所有同一个变量的调用都串起来。return 后面接的变量也会通过 ReturnTo 边连接到方法的声明上。对于形如 Foo(bar) 和 Foo(InputStream Stream) 这样的方法调用与声明，bar 也会连接到 stream 上。最后，对于 if(x&gt;y){...x...}else{...y...} 这样的语句，x 向条件节点连一个 GuardedBy 边，y 向条件节点连接一个 GuardedByNegation 的边。 最后再把所有反向边都加入，也就是形成一个无向图。这个图就处理好了。 这篇文章做实验用的是 C# 的一些项目。跟 Python 这种垃圾语言不一样，C# 每个变量是有固定的类型信息的，这种类型信息显然是可以运用的。作为一个高贵的 OOP 语言，C# 的类型还是有层级的。对于一个类型<span class="math inline">\(\tau\)</span>，有一个 embedding 函数 <span class="math inline">\(r(\tau)\)</span>。因为这个类型具有层次结构，所以可以搞一个集合<span class="math inline">\(\tau^{*}(v)\)</span>，里面具有<span class="math inline">\(v\)</span>本身的类型和<span class="math inline">\(v\)</span>所有的父类型的 embedding，然后对这个集合每一维度取个最大值，作为变量<span class="math inline">\(v\)</span>的类型特征。可以用类似 dropout 的方法来进行优化。 变量<span class="math inline">\(v\)</span>还具有变量名，可以把变量<span class="math inline">\(v\)</span>的变量名进行分词，分出来一堆 subtoken。这些 subtoken 的表示取一个平均作为变量名的特征。把变量名特征跟类型特征连接起来，得到每个节点的最初表示。 整个网络的结构有了，下面就是用这个网络来解决 VARNAMING 跟 VARMISUSE 了。 对于 VARNAMING 来说，可以把要命名的变量名替换为 SLOT token。然后跑一遍 GNN，对于每一处变量的表示取个平均值。这个值作为一个 GRU 的输入，来生成一堆 subtoken 作为变量名。这样就转变为一个 GraphToSeq 的问题。 VARMISUSE 的问题会相对复杂一点。首先也是把目标位置<span class="math inline">\(c(t)\)</span>替换为一个匿名的 SLOT 变量。正常连边的时候，变量相关的变量应该都不会连接到 SLOT 变量上。然后将候选集中的每个变量<span class="math inline">\(v_{t,v}\)</span>加入到这个图中，并连接好变量相关的边。在这个图上跑 GNN，可以得到<span class="math inline">\(h^{SLOT}\)</span>和<span class="math inline">\(h^{v_{t,v}}\)</span>。最后通过<span class="math inline">\(argmax_v W[h^{SLOT}, h^{v_{t,v}}]\)</span>来找到正确的变量。</p><p><img src="/2020/03/07/gnn-for-source-code-modeling-1/4.png"></p><p>相比一些简单的 baseline，这个方法效果提升巨大。值得注意的是，加入的各种边和节点的 embedding 也非常的 make sense，对最后的效果很有帮助。</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;关于 Source Code 的 learning 其实已经有很多工作了，每年的顶会中也有很多这个方面的文章。其实针对 Source Code 的 learning 可以算是 NLP 的一个子领域，因为 Source Code 本身就是就是程序员之间交流的一种语言。因为 Code 是一个结构化数据，存在着语义信息与语法信息，所以相比于自然语言来说，Source Code 是适合于 GNN 大显神威的领域。&lt;/p&gt;</summary>
    
    
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
    <category term="GNN" scheme="https://wyc-ruiker.github.io/tags/GNN/"/>
    
    <category term="program" scheme="https://wyc-ruiker.github.io/tags/program/"/>
    
  </entry>
  
  <entry>
    <title>AutoTVM 探秘 （三）</title>
    <link href="https://wyc-ruiker.github.io/2020/01/02/autotvm-3/"/>
    <id>https://wyc-ruiker.github.io/2020/01/02/autotvm-3/</id>
    <published>2020-01-02T08:48:01.000Z</published>
    <updated>2020-07-20T08:26:40.000Z</updated>
    
    <content type="html"><![CDATA[<p>对于目前的优化算法来说，依然存在着许多问题。但是后续的工作并不是特别多。首先可以看一下 <a href="https://arxiv.org/abs/1905.12799">Reinforcement Learning and Adaptive Sampling for Optimized DNN Compilation, ICML 2019 Workshop RL4RealLife</a> 这篇文章主要谈到的问题是两点：1. 开发一个更有效的搜索算法（相对于 AutoTVM 的模拟退火） 2. 减少硬件测试的时间。从这篇文章的实验结果来看，第一个目标基本上没有达成，第二个目标完成的还不错。这个第二个目标也是我认为的之后优化的核心问题。 文章的整个框架如下图，主要贡献是两个蓝色的部分——基于强化学习的搜索和自适应采样。</p><span id="more"></span><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-上午11.50.38.png"></p><p>整个过程不是很复杂，这个强化学习其实就是用来代替模拟退火的。迭代数次 Policy Network，输入是一个当前算子的 config，也就是之前说的 schedule，输出是对 config 的上下调整。强化学习要从环境里面获取代价，这个代价其实就是从 cost model 里面预测出来的每个 config 的运行时间，再用 PPO 的方式去训练 Policy Network，这个强化学习套路感觉非常的强行，最后效果也一般般。 然后对所有的 config 进行自适应采样，只对采样出来的 config 在硬件上测试实际的运行时间，然后将采样出来的 config 用于 cost model 的训练。这个所谓的自适应采样其实非常简单，就是对所有 config 做一个 k-means，然后采样每个centroid。</p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午12.01.27.png"></p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午12.01.34.png"></p><p>从上面两个图可以看出，第二步自适应采样的动机还是比较强的。因为对于 AutoTVM 来说，大量的时间都用于在硬件上测试算子的运行时间，而且相似的算子 config 确实很多，所以通过聚类然后采样的想法确实比较直接。最后的加速效果也不错：</p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午12.04.10.png"></p><p>通过结果可以看出来，虽然最后的加速效果很不错，但是对于结果的优化程度几乎没什么变化。说明第二步自适应采样很有效，但是第一步强化学习其实没什么用。讽刺的是这个进的还是 RL4RealLife Workshop... 对于 AutoTVM 来说，目前最主要的问题还是 tuning 的时间过慢。所以 AutoTVM 只能用于 inference，不能用于 training。因为你 tuning 的时间很有可能就比 training 的时间长了...</p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午12.07.50.png"></p><p>从上面的图中可以看到，tuning 一次 MobileNet，在 V100 上面都要花差不多 19 个小时，非常缓慢。在我们实验室这种显卡上面，大概就要两到三天了。</p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午4.16.54.png"></p><p>在 2019 年 12 月 5 日结束的<a href="https://sampl.cs.washington.edu/tvmconf/">第二届TVM与深度学习编译器会议</a>上面，也有一个思路类似的 talk。是来自 AWS 的 <a href="https://sampl.cs.washington.edu/tvmconf/slides/2019/E07-Cody-Yu.pdf">Improving AutoTVM Efficiency by Schedule Sharing</a>。跟上面的那篇文章非常类似，也是用聚类去优化 AutoTVM。 从上面的图中可以看出，对于每个从模型中抽取的 task，都要进行 turning。这个工作的动机是，如果一个 schedule 在一个 conv2d 上面效果良好，那他在另一个 conv2d 上面的效果应该也还不错。这意味着可以利用一些有代表性的任务来 turning，然后把该任务的 schedule 直接迁移到相似的任务上面去。这里的距离计算方式是 turning space 的重叠比率，然后利用这个距离来聚类。</p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午3.46.39.png"></p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午3.46.54.png"></p><p>从上面这个图中可以看出，Schedule Sharing 可以在平均 28% 的调整时间里获得 84% 的加速效果。相关讨论和 PR 在 <a href="https://github.com/apache/incubator-tvm/issues/4188">github issue</a> 上面可以找到。 新的搜索方法，还有一篇 <a href="https://arxiv.org/abs/1909.10616">Compiler-Level Matrix Multiplication Optimization for Deep Learning, arXiv</a>。这篇文章只把问题限制在了调 MM 的 tile size 上面，整个搜索空间与问题范围缩小了很多，而且去掉了 cost model，直接用强化学习来指导搜索，其实就相当于有一个很慢但是很准确的 cost model，跟前面讲的第一篇非常相似。（虽然这个 cost model 可能比 XGB 慢了差不多 1000 倍吧 XD） 总结一下上面的几个工作，几乎都是采取了一些很简单的做法，就对 AutoTVM 的整个 turning 时间起到了巨大的提升。说明这方面研究的潜力还是非常大的，如果能压缩到五分钟 turning 完一个网络，说不定就可以用 AutoTVM 来帮助 training。（当然现在来看还都是空谈，因为最好的工作也就是四到五倍的压缩效率，从两天变成半天）</p><p>说完了对搜索方法和训练采样方法的一些魔改方法之后，下面应该要说一些对于 AutoTVM 的核心 cost model 的魔改方法了。 目前对于 cost model 的研究集中于 GNN 上面，<a href="https://sampl.cs.washington.edu/tvmconf/">第二届TVM与深度学习编译器会议</a>上面有一个 UW 的 Talk，题目是 <a href="https://sampl.cs.washington.edu/tvmconf/slides/2019/E01-Eddie-Yan.pdf">Graph Convolutional Cost Models for TVM</a> ，还有一篇 <a href="https://arxiv.org/abs/1904.11876">Simulating Execution Time of Tensor Programs using Graph Neural Networks，ICLR 2019 workshop at Representation Learning on Graphs and Manifolds</a>。这两个工作基本是一样的，都是用 GCN 去优化 cost model。然而怪异的是两个工作的结果都是跟 XGB 在一个类似于估计运行时间的数据集上面的对比，没有最后 end-to-end 的效果提升。 两个工作都是用 AST 建图，大概长这样：</p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午4.19.55.png"></p><p>然后用 GCN 求一下 embedding，然后把所有 embedding 都平均一下，然后接个 MLP...</p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午4.33.14.png"></p><p>然后没了。</p><p><img src="/2020/01/02/autotvm-3/屏幕快照-2020-01-02-下午4.35.55.png"></p><p>这效果看起来其实也就那样，而且还没测 end-to-end 的性能，估计是实在不能看。</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;对于目前的优化算法来说，依然存在着许多问题。但是后续的工作并不是特别多。首先可以看一下 &lt;a href=&quot;https://arxiv.org/abs/1905.12799&quot;&gt;Reinforcement Learning and Adaptive Sampling for Optimized DNN Compilation, ICML 2019 Workshop RL4RealLife&lt;/a&gt; 这篇文章主要谈到的问题是两点：1. 开发一个更有效的搜索算法（相对于 AutoTVM 的模拟退火） 2. 减少硬件测试的时间。从这篇文章的实验结果来看，第一个目标基本上没有达成，第二个目标完成的还不错。这个第二个目标也是我认为的之后优化的核心问题。 文章的整个框架如下图，主要贡献是两个蓝色的部分——基于强化学习的搜索和自适应采样。&lt;/p&gt;</summary>
    
    
    
    <category term="system" scheme="https://wyc-ruiker.github.io/categories/system/"/>
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/system/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
    <category term="AutoTVM" scheme="https://wyc-ruiker.github.io/tags/AutoTVM/"/>
    
    <category term="TVM" scheme="https://wyc-ruiker.github.io/tags/TVM/"/>
    
  </entry>
  
  <entry>
    <title>AutoTVM 探秘（二）</title>
    <link href="https://wyc-ruiker.github.io/2019/12/31/autotvm-2/"/>
    <id>https://wyc-ruiker.github.io/2019/12/31/autotvm-2/</id>
    <published>2019-12-31T06:02:40.000Z</published>
    <updated>2021-12-16T12:08:27.000Z</updated>
    
    <content type="html"><![CDATA[<p>好了，本篇开始进入正题！内容基本都来自于：<a href="https://arxiv.org/abs/1805.08166">Learning to Optimize Tensor Programs. NeurlPS`18</a></p><span id="more"></span><h2 id="问题定义">问题定义</h2><p><a href="https://reku1997.gitee.io/2019/12/30/autotvm-1/">上一篇文章</a>讲了 AutoTVM 的大致问题，现在给出数学上面的描述。 首先有一个 <span class="math inline">\(\mathcal{E}\)</span> 代表所有可能的计算，<span class="math inline">\(e\in\mathcal{E}\)</span> 就是我们要去优化的计算。对于这个 <span class="math inline">\(e\)</span> 来说，有一个合法的 schedule space 叫 <span class="math inline">\(\mathcal{S}_e\)</span>，其中每个合法的 schedule 就叫 <span class="math inline">\(s\in\mathcal{S}_e\)</span>。<span class="math inline">\(x=g(e,s)\)</span> 是 <span class="math inline">\(e\)</span> 跟 <span class="math inline">\(s\)</span> 通过编译器 <span class="math inline">\(g\)</span> 生成的 low-level code，<span class="math inline">\(f(x)\)</span> 为这个 <span class="math inline">\(x\)</span> 在硬件上面实际运行的时间，那么该问题的定义则变成：<span class="math display">\[\underset{s\in\mathcal{S}_e}{\operatorname{argmin}}f(g(e,s))\]</span></p><p>这个问题跟现在很多人研究的 hyper-parameter optimization 非常相似，然而 paper 的作者认为，该问题跟传统的 hyper-parameter optimization 有以下几个区别： 第一个区别就是 tensor optimization 比传统的 hyper-parameter optimization 要快很多。因为超参数搜索优化的目标是神经网络的效果，所以训练一次其实是非常慢的，所以超参数搜索可以尝试很多很复杂的方法来进行优化。（比如 GP-UCB 这种，一次 GP kernel regression 要对一个协方差矩阵求逆，实际上是非常慢的，在<a href="https://www.zhihu.com/question/33711002">为什么基于贝叶斯优化的自动调参没有大范围使用？</a>上面有一定的讨论）而 tensor optimization 这个问题，其实你把真的 tensor 程序放到机器上面跑，最慢其实也就几秒。结果你搞了个复杂的方法，要好久才能预测出来结果，那我还不如真的把程序直接放在机器上面跑呢。不过这样的好处是我们可以获得比超参数搜索多得多的数据。 第二个重大的区别是，对于 hyper-parameter optimization 来说，神经网络就是个黑盒子，我们只能根据一些概率的理论去乱调。而对于 tensor optimization 来说，我们有 AST，这是一个非常有力的信息，因为一切运算的秘密其实都隐藏在 AST 里面。 第三个区别是，tensor optimization 的任务之间其实都是相似的，可以进行 transfer learning。 在上一篇文章中我们看到，其实可能去调整的参数还是很多的，这些参数乘起来会变成非常巨大的搜索空间 <span class="math inline">\(\mathcal{S}_e\)</span>。我们的目的就是在这个巨大的搜索空间中找到最好的 <span class="math inline">\(s\)</span>。</p><h2 id="搜索框架">搜索框架</h2><p>paper 作者提出的框架是，先搞一个 cost model <span class="math inline">\(\hat{f}(x)\)</span>，然后用这个 <span class="math inline">\(\hat{f}(x)\)</span> 去指导搜索出 <span class="math inline">\(s_i\)</span>，再把 <span class="math inline">\((e_i, s_i)\)</span> 放到机器上面跑造出 <span class="math inline">\(c_i\)</span>，然后再去更新 <span class="math inline">\(\hat{f}(x)\)</span>。</p><p><img src="/2019/12/31/autotvm-2/屏幕快照-2019-12-30-下午9.28.57.png"></p><p>对于这个 cost model，作者搞了两种实现。第一种是陈天奇的传统艺能——XGBoost，第二种是 TreeGRU。（顾名思义，以前的 GRU 或者 LSTM 都是一个输入，这个有多个输入，然后公式小变了一下，其实都大差不差）因为在实际的运用中，TreeGRU 实在是速度有点不行，所以根本都没有 merge 到 master 上面去，在 github 上面版本其实只有 XGBoost。 很显然，这个 cost model 的输出应该是一个预测该程序在硬件上运行的时间。对于最后的 loss，作者也实现了两种方式，第一种是传统的 regression loss：<span class="math display">\[\sum_i(\hat{f}(x_i)-c_i)^2\]</span></p><p>第二种则是只考虑他们的相对快慢：<span class="math display">\[\sum_{i,j}log(1+e^{-sign(c_i-c_j)(\hat{f}(x_i)-\hat{f}(x_j))})\]</span></p><p>实验表明，两种 loss 效果差不多。 因为这个搜索空间 <span class="math inline">\(\mathcal{S}_e\)</span> 很大，我们不能枚举整个空间。这里作者使用的方法是，先在 cost model 的指导下通过模拟退火搞出一个候选集，然后再选出来一个相对比较优的集合在硬件上面进行测试，最后更新 cost model。 那么要如何定义一个相对比较优的集合呢？这个集合要同时兼顾 quality 和 diversity。作者给出的最大化式子是这样的：<span class="math display">\[L(S)=-\sum_{s\in\mathcal{S}}\hat{f}(g(e,s))+\alpha\sum_{j=1}^m\left|{\cup_{s\in\mathcal{S}}\{s_j\}}\right|\]</span></p><p>这个东西如果不看代码，其实很难知道他到底在干什么东西。看过代码就知道这个 <span class="math inline">\(s\)</span> 其实已经经过特征抽取，被平铺为一个向量了，然后 <span class="math inline">\(m\)</span> 就是把这个平铺的向量切成 <span class="math inline">\(m\)</span> 段。这个式子的意义就是，使得每个子段都尽可能的不一样，并且运行速度还要尽量小。 为什么这个式子要设计成这个样子，看起来不是非常奇怪吗？而且这个式子要怎么优化呢？其实原因就在于如何优化这个式子上面，这个式子是一个 submodular function，可以通过贪心求一个还算凑合的近似解，具体可以看<a href="https://www.zhihu.com/question/34720027">怎么理解次模函数 submodular function？</a> 整个算法的大致过程如下：</p><p><img src="/2019/12/31/autotvm-2/屏幕快照-2019-12-30-下午9.44.16.png"></p><h2 id="贝叶斯优化">贝叶斯优化</h2><p>对于超参数搜索来说，最广为应用的方式就是贝叶斯优化。那么这个问题能不能套贝叶斯优化呢？在文章中，作者通过 bootstrap 搞出好几个 GBDT，然后通过在多个 GBDT 上面输出的预测值来采取 EI 或者 UCB 等函数，再通过上面的那种过程搜索函数的最值。虽然看起来有些奇怪（GP-UCB 那种其实是可以求出 UCB 的解析解），但总体来说其实还是符合贝叶斯优化的精神的。但是实验效果表明，用不用贝叶斯优化，效果其实都差不太多。</p><h2 id="迁移学习">迁移学习</h2><p>他这个迁移学习模块，我觉得写的其实有点奇怪。后面做的实验也不过是对于不同 size 的卷积进行了迁移学习。然而实际上卷积运算的形式是固定的，cost model 测试的也是同一台机器上面的运算速度，所以其实相当于用的就是同一个 cost model，这个东西本身就是通用的。从这个角度看，把 AST 抽取成一个固定长度的特征就是自然而然的。</p><p><img src="/2019/12/31/autotvm-2/屏幕快照-2019-12-31-下午2.00.14.png"></p><p>对于 XGB 来说，就是简单的抽取 AST 中循环变量所代表的 touched memory 和 outer loop length，在文章中这个叫做 context feature。然而问题在于不同的算子循环变量的数量都有可能是不同的，于是在 transfer learning 中，在代码中使用了一种叫 curve sample 的技术，实际上就是采样 context features 变成一个长度固定的 context relation feature。然而为什么这样就可以提高 transfer learning 的效果，其实我也搞的不太清楚。 在 TreeGRU 中，采取的方式是将循环变量的 context feature 通过 TreeGRU fold 起来，从而得到整个 AST 的 embedding。</p><h2 id="实验效果">实验效果</h2><p>中间那些不同方式的对比实验就不拿出来贴了，反正最后 state of art 是采用 rank loss 的 XGB。这里贴一个端到端的结果：</p><p><img src="/2019/12/31/autotvm-2/屏幕快照-2019-12-30-下午10.57.23.png"></p><p>从结果可以看出，优化效果非常强劲，而且越是那种非 benchmark 的网络（如 DQN），优化效果越好。 当然，这个方法并不是完美的，下一篇文章将陈述一些该方法的问题，并讲解几个 AutoTVM 方向最新的文章与成果。</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;好了，本篇开始进入正题！内容基本都来自于：&lt;a href=&quot;https://arxiv.org/abs/1805.08166&quot;&gt;Learning to Optimize Tensor Programs. NeurlPS`18&lt;/a&gt;&lt;/p&gt;</summary>
    
    
    
    <category term="system" scheme="https://wyc-ruiker.github.io/categories/system/"/>
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/system/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
    <category term="AutoTVM" scheme="https://wyc-ruiker.github.io/tags/AutoTVM/"/>
    
    <category term="TVM" scheme="https://wyc-ruiker.github.io/tags/TVM/"/>
    
  </entry>
  
  <entry>
    <title>AutoTVM 探秘（一）</title>
    <link href="https://wyc-ruiker.github.io/2019/12/30/autotvm-1/"/>
    <id>https://wyc-ruiker.github.io/2019/12/30/autotvm-1/</id>
    <published>2019-12-30T12:34:03.000Z</published>
    <updated>2020-07-20T08:34:18.000Z</updated>
    
    <content type="html"><![CDATA[<p>周末要在实验室搞个类似讲座之类的东西，先在这里写一下讲座内容，理清思路。也是对最近一个月的学习内容做一个总结。</p><span id="more"></span><h2 id="从-tvm-开始">从 TVM 开始</h2><p><a href="https://arxiv.org/abs/1802.04799">TVM: an automated end-to-end optimizing compiler for deep learning. OSDI`18</a> AutoTVM 其实是 TVM 的一个组件，那么先要搞清楚 TVM 是个啥。</p><blockquote><p>Apache TVM (incubating) is a compiler stack for deep learning systems. It is designed to close the gap between the productivity-focused deep learning frameworks, and the performance- and efficiency-focused hardware backends. TVM works with deep learning frameworks to provide end to end compilation to different backends.</p></blockquote><p>简单来说，这是一个深度学习编译器。输入是 high-level DL program (Pytorch TensorFlow etc.) 输出是 low-level optimized code。</p><p><img src="/2019/12/30/autotvm-1/屏幕快照-2019-12-30-下午4.08.45.png"></p><p>本文章的主题其实就是图里面蓝色的那个 Machine Learning-Based Automated Optimizer。 不过在进入主题之前先谈一谈这个 TVM 的意义吧，通过 TVM 的意义其实我们就可以自然的了解到为什么我们需要 AutoTVM。这些内容其实在之前的两篇文章里面都有谈过。 在之前很多厂商都搞过深度学习编译器，比如 TensorFlow XLA、NVIDIA TensorRT 等等。之前的搞法通常都是先把这些乱七八糟深度学习框架前端统一成一个 Graph IR，再对这个 Graph IR 进行一些例如 Operator Fusion 和 Constant Folding 之类的优化，然后将 Graph IR 映射到 XLA 算子或者 cuDNN 中，这些算子很多是由专业的工程师进行手工优化，效果拔群，通过这个过程实现神经网络的高效。 问题在于，你的模型需要在一大坨设备上面跑（比如手机、树莓派、GPU、CPU...）这些设备的运算能力和优化方式都有所不同，那么就需要每个设备都搞一个编译框架，然后由很多很多工程师去实现很多高效的算子用来映射。一个更夸张的发展趋势是，很多 AI 芯片厂商会把一些常用算子（如卷积层）直接设计一个硬件模块去加速，这样会导致只要出一个牛逼网络，那 AI 芯片就会多做一个模块去对网络的某些公共运算进行加速，然后工程师也会设计相关的算子，不停加班，永不失业。 还有个问题就是，比如 Operator Fusion 这种优化，有一些算子（如卷积+池化+relu）的融合模块已经在 cuDNN 中写好了，那么 Operator Fusion 的时候就可以直接对应过去。但是随着 DL 的发展，越来越多算子都可能进行融合，但是因为底层的实现还没做好，导致在图级别的优化会出现捉襟见肘的情况。很多时候优化会倾向于使用成熟的算子，避免那些还没有优化很好的融合方式。 以及一个在 learning 领域广泛出现的问题——长尾分布。对于那些通用的优化来说，优化一下可以产生很大的性能提升，但是对于那些长尾的优化来说，优化一次的代价过高，产生的利益也没有那么丰厚。 显然，解决问题的核心就在于如何对不同的硬件和不同的算子进行一波通用的优化。</p><h2 id="autotvm-初探">AutoTVM 初探</h2><p>对于上面这个问题，TVM 给我们的答案是 AutoTVM，一个 Automating Optimization。 在谈论这个问题之前，我们还要再复习一下体系结构的内容。其实这个在前两篇文章中也讲过很多。</p><p><img src="/2019/12/30/autotvm-1/屏幕快照-2019-12-30-下午4.44.18.png"></p><p>一个简单的 CPU 架构可以概括为上面这样，这个 CPU 有两个核心，每个核心都有自己 L12 cache，然后也支持 SIMD，也就是 fetch 一个指令可以在两个 PU 上面运算。当然现在很多处理器都支持超线程，也就是说一个核心有两个硬件线程，每个核心在操作系统中其实是两个核心。然后现在最厉害的 SIMD 指令叫做 AVX-512，可以在每个 cycle 同时对 16 个 float32 进行运算。 所以对于 CPU 而言，最为常用的优化其实就是三种：Parallelization（多核并行）、Vectorization（SIMD）、Cache。还有一些诸如是否进行循环展开之类的优化。 下面用 TVM 实现一个最简单的矩阵乘法，程序来自于 <a href="https://docs.tvm.ai/tutorials/autotvm/tune_simple_template.html#sphx-glr-tutorials-autotvm-tune-simple-template-py">AutoTVM 教程</a></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">matmul_v0</span>(<span class="params">N, L, M, dtype</span>):</span><br><span class="line">    A = tvm.placeholder((N, L), name=<span class="string">&#x27;A&#x27;</span>, dtype=dtype)</span><br><span class="line">    B = tvm.placeholder((L, M), name=<span class="string">&#x27;B&#x27;</span>, dtype=dtype)</span><br><span class="line"></span><br><span class="line">    k = tvm.reduce_axis((<span class="number">0</span>, L), name=<span class="string">&#x27;k&#x27;</span>)</span><br><span class="line">    C = tvm.compute((N, M), <span class="keyword">lambda</span> i, j: tvm.<span class="built_in">sum</span>(A[i, k] * B[k, j], axis=k), name=<span class="string">&#x27;C&#x27;</span>)</span><br><span class="line">    s = tvm.create_schedule(C.op)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># schedule</span></span><br><span class="line">    y, x = s[C].op.axis</span><br><span class="line">    k = s[C].op.reduce_axis[<span class="number">0</span>]</span><br><span class="line"></span><br><span class="line">    yo, yi = s[C].split(y, <span class="number">8</span>)</span><br><span class="line">    xo, xi = s[C].split(x, <span class="number">8</span>)</span><br><span class="line"></span><br><span class="line">    s[C].reorder(yo, xo, k, yi, xi)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">return</span> s, [A, B, C]</span><br></pre></td></tr></table></figure><p>上面的程序只包括了 Cache 优化，方式就是常见的矩阵乘法循环变量 reorder 和矩阵分块。注意，这里矩阵分块的 magic number 是 8, 也就是说把这个矩阵分成 8*8 的小块，使得 cache 的 hit rate 更高。 但是对于这样的 magic number，没有经验的人是很难找到最优的数值的。而且这个数值跟很多硬件因素都有关系，很多时候我们不能对硬件的所有因素都产生全面的了解，这个时候就需要 AutoTVM 的帮助了。 用起来也很简单，其实就是指名一下哪些参数需要搜索。比如下面的程序就是指明要搜索 tile size：</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">@autotvm.template</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">matmul</span>(<span class="params">N, L, M, dtype</span>):</span><br><span class="line">    A = tvm.placeholder((N, L), name=<span class="string">&#x27;A&#x27;</span>, dtype=dtype)</span><br><span class="line">    B = tvm.placeholder((L, M), name=<span class="string">&#x27;B&#x27;</span>, dtype=dtype)</span><br><span class="line"></span><br><span class="line">    k = tvm.reduce_axis((<span class="number">0</span>, L), name=<span class="string">&#x27;k&#x27;</span>)</span><br><span class="line">    C = tvm.compute((N, M), <span class="keyword">lambda</span> i, j: tvm.<span class="built_in">sum</span>(A[i, k] * B[k, j], axis=k), name=<span class="string">&#x27;C&#x27;</span>)</span><br><span class="line">    s = tvm.create_schedule(C.op)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># schedule</span></span><br><span class="line">    y, x = s[C].op.axis</span><br><span class="line">    k = s[C].op.reduce_axis[<span class="number">0</span>]</span><br><span class="line"></span><br><span class="line">    <span class="comment">##### define space begin #####</span></span><br><span class="line">    cfg = autotvm.get_config()</span><br><span class="line">    cfg.define_split(<span class="string">&quot;tile_y&quot;</span>, y, num_outputs=<span class="number">2</span>)</span><br><span class="line">    cfg.define_split(<span class="string">&quot;tile_x&quot;</span>, x, num_outputs=<span class="number">2</span>)</span><br><span class="line">    <span class="comment">##### define space end #####</span></span><br><span class="line"></span><br><span class="line">    <span class="comment"># schedule according to config</span></span><br><span class="line">    yo, yi = cfg[<span class="string">&quot;tile_y&quot;</span>].apply(s, C, y)</span><br><span class="line">    xo, xi = cfg[<span class="string">&quot;tile_x&quot;</span>].apply(s, C, x)</span><br><span class="line"></span><br><span class="line">    s[C].reorder(yo, xo, k, yi, xi)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">return</span> s, [A, B, C]</span><br></pre></td></tr></table></figure><p>这其实也是系统设计的艺术，首先 TVM 把运算与 schedule 进行解耦，然后一部分 schedule 由用户进行实现，一部分需要精细调整的内容由一个 ML 算法进行搜索，从而达到一个易用性和性能的 trade-off。相对应的是 Facebook 做的 Tensor Comprehension，要解决的问题跟 TVM 是类似的，但是选择的是利用 polyhedra model 进行一个类似端到端的优化过程，但是优化的空间其实比 TVM 这种 schedule space 模型要差一些，所以效果也会打些折扣。一些相关的讨论可以在<a href="https://www.zhihu.com/question/267167829">如何看待Tensor Comprehensions？与TVM有何异同？</a>上面看到。 对于 GPU 来说，由于架构跟 CPU 存在区别，所以优化的方式也不太一样：</p><p><img src="/2019/12/30/autotvm-1/屏幕快照-2019-12-30-下午8.07.07.png"></p><p><img src="/2019/12/30/autotvm-1/屏幕快照-2019-12-30-下午8.27.33.png"></p><p>可以看到，相对 CPU 来说，GPU 多了很多可以向量化的计算单元，甚至还有 Tensor Core 可以对计算进行张量化。而且 L1 cache 可以由程序员来进行主动的控制，作为线程之间的缓存，提供了很大的自由性。 在 GPU 里面还有线程与线程块的概念。几个 thread 会统一放到一个 block 中。同一个 block 中的线程会共享同一个 L1 cache 或者 shared memory，合理的分配 shared memory 会显著减少读写时间。 在 GPU 上面优化矩阵乘法，我们可以这样写，代码来自 <a href="http://tvm.d2l.ai/chapter_gpu_schedules/matmul.html">Dive in DL Compiler</a>：</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">matmul_gpu</span>(<span class="params">n</span>):</span><br><span class="line">    A, B, C = d2ltvm.matmul(n, n, n)</span><br><span class="line">    s = tvm.create_schedule(C.op)</span><br><span class="line">    <span class="comment"># Create caches</span></span><br><span class="line">    A_shared = s.cache_read(A, <span class="string">&quot;shared&quot;</span>, [C])</span><br><span class="line">    A_local  = s.cache_read(A_shared, <span class="string">&quot;local&quot;</span>, [C])</span><br><span class="line">    B_shared = s.cache_read(B, <span class="string">&quot;shared&quot;</span>, [C])</span><br><span class="line">    B_local  = s.cache_read(B_shared, <span class="string">&quot;local&quot;</span>, [C])</span><br><span class="line">    C_local = s.cache_write(C, <span class="string">&quot;local&quot;</span>)</span><br><span class="line">    <span class="comment"># Split each axis into block axis, thread axis, and inner axis</span></span><br><span class="line">    x, y = s[C].op.axis</span><br><span class="line">    xb, xo, xi = split(s[C], x, (block_size, tx))</span><br><span class="line">    yb, yo, yi = split(s[C], y, (block_size, ty))</span><br><span class="line">    s[C].reorder(xb, yb, xo, yo, xi, yi)</span><br><span class="line">    <span class="comment"># Note that we bind yb to blockIdx.x instead of blockIdx.y</span></span><br><span class="line">    bind_thread(s[C], (yb, xb, yo, xo),</span><br><span class="line">                (<span class="string">&quot;blockIdx.x&quot;</span>, <span class="string">&quot;blockIdx.y&quot;</span>, <span class="string">&quot;threadIdx.x&quot;</span>, <span class="string">&quot;threadIdx.y&quot;</span>))</span><br><span class="line">    <span class="comment"># Schedule C_local</span></span><br><span class="line">    s[C_local].compute_at(s[C], yo)</span><br><span class="line">    yi, xi = s[C_local].op.axis</span><br><span class="line">    k, = s[C_local].op.reduce_axis</span><br><span class="line">    ko, ki = s[C_local].split(k, tk)</span><br><span class="line">    s[C_local].reorder(ko, ki, yi, xi)</span><br><span class="line">    <span class="comment"># Optimize read caches of A and B with cooperative fetching</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">optimize_read_cache</span>(<span class="params">shared, local</span>):</span><br><span class="line">        s[shared].compute_at(s[C_local], ko)</span><br><span class="line">        s[local].compute_at(s[C_local], ki)</span><br><span class="line">        y, x = s[shared].op.axis</span><br><span class="line">        <span class="comment"># Note that we must split into block_size parts to reuse</span></span><br><span class="line">        <span class="comment"># the previous axis threads</span></span><br><span class="line">        yo, yi = s[shared].split(y, nparts=block_size)</span><br><span class="line">        xo, xi = s[shared].split(x, nparts=block_size)</span><br><span class="line">        s[shared].reorder(yo, xo, yi, xi)</span><br><span class="line">        bind_thread(s[shared], (yo, xo), (<span class="string">&quot;threadIdx.y&quot;</span>, <span class="string">&quot;threadIdx.x&quot;</span>))</span><br><span class="line">    optimize_read_cache(A_shared, A_local)</span><br><span class="line">    optimize_read_cache(B_shared, B_local)</span><br><span class="line">    <span class="keyword">return</span> s, (A, B, C)</span><br></pre></td></tr></table></figure><p>看起来有点复杂，其实就是 shared memory 的一些分配。从代码中可以看到，有很多 split 的操作，事实上对于缺乏经验的工程师来说，确定这些 split size 是非常困难的。 在 <a href="https://docs.tvm.ai/tutorials/autotvm/tune_conv2d_cuda.html">AutoTVM 教程</a>中我们可以找到一个相对通用的模板：</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">@autotvm.template</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">conv2d_no_batching</span>(<span class="params">N, H, W, CO, CI, KH, KW, stride, padding</span>):</span><br><span class="line">    <span class="keyword">assert</span> N == <span class="number">1</span>, <span class="string">&quot;Only consider batch_size = 1 in this template&quot;</span></span><br><span class="line"></span><br><span class="line">    data = tvm.placeholder((N, CI, H, W), name=<span class="string">&#x27;data&#x27;</span>)</span><br><span class="line">    kernel = tvm.placeholder((CO, CI, KH, KW), name=<span class="string">&#x27;kernel&#x27;</span>)</span><br><span class="line">    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=<span class="number">1</span>, out_dtype=<span class="string">&#x27;float32&#x27;</span>)</span><br><span class="line">    s = tvm.create_schedule([conv.op])</span><br><span class="line"></span><br><span class="line">    <span class="comment">##### space definition begin #####</span></span><br><span class="line">    n, f, y, x = s[conv].op.axis</span><br><span class="line">    rc, ry, rx = s[conv].op.reduce_axis</span><br><span class="line"></span><br><span class="line">    cfg = autotvm.get_config()</span><br><span class="line">    cfg.define_split(<span class="string">&quot;tile_f&quot;</span>, f, num_outputs=<span class="number">4</span>)</span><br><span class="line">    cfg.define_split(<span class="string">&quot;tile_y&quot;</span>, y, num_outputs=<span class="number">4</span>)</span><br><span class="line">    cfg.define_split(<span class="string">&quot;tile_x&quot;</span>, x, num_outputs=<span class="number">4</span>)</span><br><span class="line">    cfg.define_split(<span class="string">&quot;tile_rc&quot;</span>, rc, num_outputs=<span class="number">3</span>)</span><br><span class="line">    cfg.define_split(<span class="string">&quot;tile_ry&quot;</span>, ry, num_outputs=<span class="number">3</span>)</span><br><span class="line">    cfg.define_split(<span class="string">&quot;tile_rx&quot;</span>, rx, num_outputs=<span class="number">3</span>)</span><br><span class="line">    cfg.define_knob(<span class="string">&quot;auto_unroll_max_step&quot;</span>, [<span class="number">0</span>, <span class="number">512</span>, <span class="number">1500</span>])</span><br><span class="line">    cfg.define_knob(<span class="string">&quot;unroll_explicit&quot;</span>, [<span class="number">0</span>, <span class="number">1</span>])</span><br><span class="line">    <span class="comment">##### space definition end #####</span></span><br><span class="line"></span><br><span class="line">    <span class="comment"># inline padding</span></span><br><span class="line">    pad_data = s[conv].op.input_tensors[<span class="number">0</span>]</span><br><span class="line">    s[pad_data].compute_inline()</span><br><span class="line">    data, raw_data = pad_data, data</span><br><span class="line"></span><br><span class="line">    output = conv</span><br><span class="line">    OL = s.cache_write(conv, <span class="string">&#x27;local&#x27;</span>)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># create cache stage</span></span><br><span class="line">    AA = s.cache_read(data, <span class="string">&#x27;shared&#x27;</span>, [OL])</span><br><span class="line">    WW = s.cache_read(kernel, <span class="string">&#x27;shared&#x27;</span>, [OL])</span><br><span class="line">    AL = s.cache_read(AA, <span class="string">&#x27;local&#x27;</span>, [OL])</span><br><span class="line">    WL = s.cache_read(WW, <span class="string">&#x27;local&#x27;</span>, [OL])</span><br><span class="line"></span><br><span class="line">    <span class="comment"># tile and bind spatial axes</span></span><br><span class="line">    n, f, y, x = s[output].op.axis</span><br><span class="line">    bf, vf, tf, fi = cfg[<span class="string">&quot;tile_f&quot;</span>].apply(s, output, f)</span><br><span class="line">    by, vy, ty, yi = cfg[<span class="string">&quot;tile_y&quot;</span>].apply(s, output, y)</span><br><span class="line">    bx, vx, tx, xi = cfg[<span class="string">&quot;tile_x&quot;</span>].apply(s, output, x)</span><br><span class="line">    kernel_scope = n  <span class="comment"># this is the scope to attach global config inside this kernel</span></span><br><span class="line"></span><br><span class="line">    s[output].bind(bf, tvm.thread_axis(<span class="string">&quot;blockIdx.z&quot;</span>))</span><br><span class="line">    s[output].bind(by, tvm.thread_axis(<span class="string">&quot;blockIdx.y&quot;</span>))</span><br><span class="line">    s[output].bind(bx, tvm.thread_axis(<span class="string">&quot;blockIdx.x&quot;</span>))</span><br><span class="line">    s[output].bind(vf, tvm.thread_axis(<span class="string">&quot;vthread&quot;</span>))</span><br><span class="line">    s[output].bind(vy, tvm.thread_axis(<span class="string">&quot;vthread&quot;</span>))</span><br><span class="line">    s[output].bind(vx, tvm.thread_axis(<span class="string">&quot;vthread&quot;</span>))</span><br><span class="line">    s[output].bind(tf, tvm.thread_axis(<span class="string">&quot;threadIdx.z&quot;</span>))</span><br><span class="line">    s[output].bind(ty, tvm.thread_axis(<span class="string">&quot;threadIdx.y&quot;</span>))</span><br><span class="line">    s[output].bind(tx, tvm.thread_axis(<span class="string">&quot;threadIdx.x&quot;</span>))</span><br><span class="line">    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)</span><br><span class="line">    s[OL].compute_at(s[output], tx)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># tile reduction axes</span></span><br><span class="line">    n, f, y, x = s[OL].op.axis</span><br><span class="line">    rc, ry, rx = s[OL].op.reduce_axis</span><br><span class="line">    rco, rcm, rci = cfg[<span class="string">&#x27;tile_rc&#x27;</span>].apply(s, OL, rc)</span><br><span class="line">    ryo, rym, ryi = cfg[<span class="string">&#x27;tile_rx&#x27;</span>].apply(s, OL, ry)</span><br><span class="line">    rxo, rxm, rxi = cfg[<span class="string">&#x27;tile_ry&#x27;</span>].apply(s, OL, rx)</span><br><span class="line">    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)</span><br><span class="line"></span><br><span class="line">    s[AA].compute_at(s[OL], rxo)</span><br><span class="line">    s[WW].compute_at(s[OL], rxo)</span><br><span class="line">    s[AL].compute_at(s[OL], rxm)</span><br><span class="line">    s[WL].compute_at(s[OL], rxm)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># cooperative fetching</span></span><br><span class="line">    <span class="keyword">for</span> load <span class="keyword">in</span> [AA, WW]:</span><br><span class="line">        n, f, y, x = s[load].op.axis</span><br><span class="line">        fused = s[load].fuse(n, f, y, x)</span><br><span class="line">        tz, fused = s[load].split(fused, nparts=cfg[<span class="string">&quot;tile_f&quot;</span>].size[<span class="number">2</span>])</span><br><span class="line">        ty, fused = s[load].split(fused, nparts=cfg[<span class="string">&quot;tile_y&quot;</span>].size[<span class="number">2</span>])</span><br><span class="line">        tx, fused = s[load].split(fused, nparts=cfg[<span class="string">&quot;tile_x&quot;</span>].size[<span class="number">2</span>])</span><br><span class="line">        s[load].bind(tz, tvm.thread_axis(<span class="string">&quot;threadIdx.z&quot;</span>))</span><br><span class="line">        s[load].bind(ty, tvm.thread_axis(<span class="string">&quot;threadIdx.y&quot;</span>))</span><br><span class="line">        s[load].bind(tx, tvm.thread_axis(<span class="string">&quot;threadIdx.x&quot;</span>))</span><br><span class="line"></span><br><span class="line">    <span class="comment"># tune unroll</span></span><br><span class="line">    s[output].pragma(kernel_scope, <span class="string">&#x27;auto_unroll_max_step&#x27;</span>, cfg[<span class="string">&#x27;auto_unroll_max_step&#x27;</span>].val)</span><br><span class="line">    s[output].pragma(kernel_scope, <span class="string">&#x27;unroll_explicit&#x27;</span>, cfg[<span class="string">&#x27;unroll_explicit&#x27;</span>].val)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">return</span> s, [raw_data, kernel, conv]</span><br></pre></td></tr></table></figure><p>对于这些 knob，有个简单进行解释的图表： <img src="/2019/12/30/autotvm-1/屏幕快照-2020-01-02-上午11.37.17.png"></p><p>好了，现在对 AutoTVM 已经有了一些感性的理解了。不过这个开头写的有点多，以上内容先算一篇，下一篇我们讲 AutoTVM 的具体实现。</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;周末要在实验室搞个类似讲座之类的东西，先在这里写一下讲座内容，理清思路。也是对最近一个月的学习内容做一个总结。&lt;/p&gt;</summary>
    
    
    
    <category term="system" scheme="https://wyc-ruiker.github.io/categories/system/"/>
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/system/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
    <category term="AutoTVM" scheme="https://wyc-ruiker.github.io/tags/AutoTVM/"/>
    
    <category term="TVM" scheme="https://wyc-ruiker.github.io/tags/TVM/"/>
    
  </entry>
  
  <entry>
    <title>CSE 599W: SYSTEMS FOR ML 课程笔记 7-12</title>
    <link href="https://wyc-ruiker.github.io/2019/11/21/cse-599w-systems-for-ml-7-12/"/>
    <id>https://wyc-ruiker.github.io/2019/11/21/cse-599w-systems-for-ml-7-12/</id>
    <published>2019-11-21T07:04:48.000Z</published>
    <updated>2021-12-16T11:33:10.000Z</updated>
    
    <content type="html"><![CDATA[<p>各种课程资料请参考<a href="https://reku1997.gitee.io/2019/11/08/cse-599w-systems-for-ml-1-6/">上一篇文章</a></p><span id="more"></span><h2 id="lecture-7-automatic-code-generation---tvm-stack">Lecture 7: Automatic Code Generation - TVM Stack</h2><p>现在深度学习的框架非常多，而这些乱七八糟框架写出来的代码通常又跑在乱七八糟的设备上。这其中最为关键的问题就是：如何让深度学习代码在不同的设备上都跑出最好的效果。 众所周知，如果软件架构出现了什么难以解决的问题，那就加个中间层，看看能不能解决。如果还不能解决，那就再加个中间层（ 目前各家（比如 TensorFlow XLA、NVIDIA TensorRT 等）采用的设计思路就是将各个框架写出来的网络转换成一种统一的表示形式，也就是所谓的 Graph IR。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/1.png"></p><p>当然了，最后 IR 如果想要运行，那你还是要把 IR 变成机器码才可以。对于不同的硬件平台、数据格式、精度、线程结构都要写一堆不同的代码生成规则和优化规则。 这个问题是 TVM 的技术背景了。前面说了，如果还解决不了问题，那就再加个中间层。TVM 加的中间层就是所谓的 Tensor Expression Language 表示方法。这个 idea 来自于 Halide，核心在于把代码的计算和调度分开。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/2.png"></p><p>举个具体的例子，最简单的一个向量相加，用 TVM 实现起来长这样：</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">C = tvm.compute((n,), <span class="keyword">lambda</span> i: A[i] + B[i])</span><br><span class="line">s = tvm.create_schedule(C.op)</span><br></pre></td></tr></table></figure><p>得到的 C 代码是：</p><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> (<span class="type">int</span> i = <span class="number">0</span>; i &lt; n; ++i)</span><br><span class="line">&#123;</span><br><span class="line">    C[i] = A[i] + B[i];</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>加上一些额外的循环控制（就是上篇文章中讲的 cache 优化）</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">C = tvm.compute((n,), <span class="keyword">lambda</span> i: A[i] + B[i])</span><br><span class="line">s = tvm.create_schedule(C.op)</span><br><span class="line">xo, xi = s[C].split(s[C].axis[<span class="number">0</span>], factor=<span class="number">32</span>)</span><br></pre></td></tr></table></figure><p>生成的代码就变成了：</p><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> (<span class="type">int</span> xo = <span class="number">0</span>; xo &lt; <span class="built_in">ceil</span>(n / <span class="number">32</span>); ++xo)</span><br><span class="line">&#123;</span><br><span class="line">    <span class="keyword">for</span> (<span class="type">int</span> xi = <span class="number">0</span>; xi &lt; <span class="number">32</span>; ++xi)</span><br><span class="line">    &#123;</span><br><span class="line">        <span class="type">int</span> i = xo * <span class="number">32</span> + xi;</span><br><span class="line">        <span class="keyword">if</span> (i &lt; n)</span><br><span class="line">        &#123;</span><br><span class="line">            C[i] = A[i] + B[i];</span><br><span class="line">        &#125;</span><br><span class="line">    &#125;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>甚至可以绑定特定的变量：</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">C = tvm.compute((n,), <span class="keyword">lambda</span> i: A[i] + B[i])</span><br><span class="line">s = tvm.create_schedule(C.op)</span><br><span class="line">xo, xi = s[C].split(s[C].axis[<span class="number">0</span>], factor=<span class="number">32</span>)</span><br><span class="line">s[C].recorder(xi, xo)</span><br><span class="line">s[C].bind(xo, tvm.thread_axis(“blockIdx.x”)</span><br><span class="line">s[C].bind(xi, tvm.thread_axis(“threadIdx.x”)</span><br></pre></td></tr></table></figure><p>这样就出来一个 CUDA kernel 代码：</p><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">int</span> i = threadIdx.x * <span class="number">32</span> + blockIdx.x;</span><br><span class="line"><span class="keyword">if</span> (i &lt; n)</span><br><span class="line">&#123;</span><br><span class="line">    C[i] = A[i] + B[i];</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>TVM 的核心就是一些调度原语，比如下图这些：</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/3.png"></p><p>还有 TVM 最有趣的部分 AutoTVM，可以用 learning 的方式对代码进行自动优化，不过这个东西在课程中并不是重点，所以只是简单的提及了一下。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/4.png"></p><p>在我理解，TVM 就是一个 Graph IR 的优化框架，或者说是底层算子的高阶表示。</p><h2 id="lecture-8-hardware-specialization-in-deep-learning">Lecture 8: Hardware Specialization in Deep Learning</h2><p>这一讲主要介绍 TVM 技术栈中的重要部分 VTA。对于硬件我不是特别熟悉，也不是特别感兴趣，所以这章就随便看看了。 <a href="https://zhuanlan.zhihu.com/p/39635145">VTA: 开源AI芯片栈</a> Tianqi 的这篇文章概括性的讲述了 VTA 的意义。 总体来说，就是根据 RISC 的思想，设计了一套硬件架构。TVM 充当编译器，VTA 充当底层计算硬件，两者配合达到良好的效果。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/5.png"></p><h2 id="lecture-9-memory-optimization">Lecture 9: Memory Optimization</h2><p>这一讲我觉得还蛮有意思的，讲的是深度学习的内存优化。 主要内容来自于这篇论文：<a href="https://arxiv.org/abs/1604.06174">Training Deep Nets with Sublinear Memory Cost</a> 这一讲的核心问题在于，为什么我们要用 autograd 来形成计算图，而不是直接使用 BP？</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/1-1.png"></p><p>这个问题的答案其实就本节课的标题：Memory Optimization。我们要优化内存的使用，所以要用计算图。 假如只有前向传播的话，我们其实可以通过计算图的层次结构来复用内存：</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/2-1.png"></p><p>这个内存复用有两个原则：Inplace 和 Normal Sharing。 所谓 Inplace 就是输出的内容直接存到输入内存的地方。Normal Sharing 则是重复使用那些不被需要的内存。 显然 Inplace 也有一些失效的时候，当输入的内容还被其他输出依赖时，我们之前的输出就不能使用这块内存。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/3-1.png"></p><p>回到我们最开始的问题，为什么我们要用 autograd 的计算图来代替 BP？</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/4-1.png"></p><p>这个问题其实没有在 slides 中说清楚。看了看其他人的理解，都是觉得 BP 要存中间结果，所以比较费内存。但是计算图也要存中间结果啊？所以有点费解。我的理解是，计算图可以更好的捕获变量之间的依赖性，就像上面这个图一样。如果用 BP，每一步中间结果都要存起来，但是如果用 autograd，就可以全程 Inplace。 后面讲了一些利用两个原则优化内存的例子，都不是很难理解。但是问题在于，不管我们怎么复用，整体的空间复杂度仍然是线性增长的。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/5-1.png"></p><p>这时候有个牛逼做法就是保存 <span class="math inline">\(\sqrt{N}\)</span> 个中间节点，每次需要反向传播的时候，就从这个中间节点往下算 <span class="math inline">\(\sqrt{N}\)</span> 步。这样内存就是 <span class="math inline">\(O(\sqrt{N})\)</span> 级别了，是不是跟 ICPC 竞赛里面常见的分块算法一模一样？ 感觉内存优化这部分还是有些东西不太清楚，需要回头看一下论文。看完论文补充一下这部分的内容。 接下来是 assignment 2。 这次作业感觉工作量还是很大的，首先是要满足作业的第一个要求，把 test_tvm_op.py 里面的所有测试都跑通：</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/1-2.png"></p><p>这一步主要要参考 <a href="https://docs.tvm.ai/index.html">TVM Documentation</a> 的写法，把这些算子都补齐。其中一些算子的作用还是比较迷惑的，要看测试才知道是什么意思。TVM 的写法有些地方也比较古怪，要自己慢慢尝试。 然后是满足第二个要求，跑通 mnist 的两个方法：</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/2-2.png"></p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/3-2.png"></p><p>这一步是要把 autodiff 的内容补全，其实很简单，但是有些地方可能不太清楚他到底想让你写什么东西，仔细看看调用代码，然后把函数补全就好了。 从上面的截图可以看出，我们自己写的算子效果非常垃圾，跑一遍 mlp 要 80s，下面就是如何优化矩阵乘法算子了：</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/4-2.png"></p><p>主要思路还是来自于 <a href="http://dlsys.cs.washington.edu/pdf/lecture6.pdf">Lecture 6</a>，只要简单的把矩阵分一下块，调整一下循环顺序，然后再并行一下，就可以得到 20x 的优化效果。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">s = tvm.create_schedule(C.op)</span><br><span class="line">xo, yo, xi, yi = s[C].tile(C.op.axis[<span class="number">0</span>], C.op.axis[<span class="number">1</span>], x_factor=<span class="number">32</span>, y_factor=<span class="number">64</span>)</span><br><span class="line">xk, yk = s[C].split(k, factor=<span class="number">8</span>)</span><br><span class="line">s[C].reorder(xo, yo, xk, xi, yi, yk)</span><br><span class="line">s[C].parallel(xo)</span><br><span class="line">s[C].unroll(yk)</span><br><span class="line">f = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name=func_name)</span><br></pre></td></tr></table></figure><p>更多精细的优化在 <a href="https://zhuanlan.zhihu.com/p/75203171">如何利用TVM快速实现超越Numpy的GEMM</a> 这篇文章中可以找到，非常牛逼。</p><h2 id="lecture-10-parallel-scheduling">Lecture 10: Parallel Scheduling</h2><p>在内存分配之后，第10讲讲的是并行调度的问题。 首先讲了一下 Model Parallel 和 Data Parallel 的模式，比较基础。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/1-3.png"></p><p>因为现实的运算中存在着各种复杂的同步关系，所以我们需要一个自动调度器。 这个自动调度器需要调度各种各样的资源，包括数据、随机数生成器、网络通信等等；也要调度各种各样的操作。 对于深度学习来说，基于计算图的调度就非常自然。因为计算图是一个 DAG，各种依赖复用关系都非常明显。但是基于计算图的这种调度对于一些什么写后读、读后写的问题可能不是特别敏感，所以还增加需要 mutation aware 这种调度方式。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/2-3.png"></p><p>后面给了一个很简单的队列调度的例子，看看就能懂。</p><h2 id="lecture-11-distributed-training-and-communication-protocols">Lecture 11: Distributed Training and Communication Protocols</h2><p>这一讲讲的主要是参数的 Synchronization 问题。 首先介绍了一种叫做 Allreduce 的操作，感觉就是在中间把分布式程序拦了一下：</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/3-3.png"></p><p>然后讲了几种网络拓扑对 Allreduce 操作的影响，看起来比较简单。 之后介绍了 Parameter Server，这个在实习的时候接触的蛮多，其实就是个 KV，用 PS 去更新和获取参数，他们内部甚至搞了个无锁 hash 表...</p><h2 id="lecture-12-model-serving">Lecture 12: Model Serving</h2><p>终于到了最后一讲，这一讲主要讲的是模型部署在现实应用中的问题。</p><p><img src="/2019/11/21/cse-599w-systems-for-ml-7-12/4-3.png"></p><p>主要分为模型压缩和服务系统两个部分来进行讲解。 模型压缩的第一个部分是矩阵/向量分解，这个矩阵分解套路很简单很常见，向量分解没见过，也没太看懂，有空可以仔细研究一下。大概是把一个 cnn 分解成几个 cnn，但是能达到相似的效果。 然后网络剪枝，其中一个思路是 prune the connections。想法很简单，每次增加前向传播的 theshold，减少连通性。 第二个思路是 weight sharing，首先对参数进行聚类，相似的参数就看做一个参数，再对他们进行统一的更新。 模型的低比特量化也是一个很有趣的思路，通过降低模型数据存储的精度，来压缩模型并且尽量保持精度。 还有知识蒸馏，用一个大模型去训练一个小模型。上周末实验室讲座有个大哥讲的就是这个内容，但是我因为家里有急事所以没听到，不过获得了 PPT，之后自己补一补吧。 第二个部分就是服务系统了，服务系统的目标是编写程序的灵活性、GPU 上面的高效率以及延迟满足 SLA（service-level agreement）。 然后讲了一个叫 Nexus 的系统，看起来有些难以理解，先略过吧~</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;各种课程资料请参考&lt;a href=&quot;https://reku1997.gitee.io/2019/11/08/cse-599w-systems-for-ml-1-6/&quot;&gt;上一篇文章&lt;/a&gt;&lt;/p&gt;</summary>
    
    
    
    <category term="system" scheme="https://wyc-ruiker.github.io/categories/system/"/>
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/system/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
    <category term="learning" scheme="https://wyc-ruiker.github.io/tags/learning/"/>
    
    <category term="system" scheme="https://wyc-ruiker.github.io/tags/system/"/>
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/tags/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
  </entry>
  
  <entry>
    <title>CSE 599W: Systems for ML 课程笔记 1-6</title>
    <link href="https://wyc-ruiker.github.io/2019/11/08/cse-599w-systems-for-ml-1-6/"/>
    <id>https://wyc-ruiker.github.io/2019/11/08/cse-599w-systems-for-ml-1-6/</id>
    <published>2019-11-08T07:31:51.000Z</published>
    <updated>2021-12-16T11:33:12.000Z</updated>
    
    <content type="html"><![CDATA[<p><a href="http://dlsys.cs.washington.edu/">课程网站</a> 在头条 AML 实习的时候就觉得这个 AI system 方向非常有趣，但是苦于实验室不是搞这一套的，自己拖延症也非常严重，所以一直在入门的边缘徘徊。但是在今天——研一秋学期考试周的前一周，我决定开始学习 AI system 方向最著名的必学课程，Tianqi Chen 在 UW 开设的 CSE599W。 这个课程其实资料并不是很完善，只有 github 上面的几个 repo 和课程网站上面的 slide，缺乏讲课的视频资源。而且在开始学习之前就听说很多地方 slide 写的非常简陋，只能通过 tvm 和 tinyflow 代码慢慢学习。我在学习之前也找了一些 blog 资源，开个坑，希望可以努力坚持下来！ 本人的作业也开源到 <a href="https://github.com/wyc-ruiker/CSE-599W-2018">github</a> 上面了，希望大家多多指导。</p><span id="more"></span><p><a href="http://jcf94.com/2018/10/04/2018-10-04-cse559w/">Chenfan Blog——CSE 599W： Systems for ML</a></p><p><a href="https://zhuanlan.zhihu.com/c_186688444">知乎专栏：SysML/DL机器学习系统</a></p><p><a href="https://zhuanlan.zhihu.com/p/50529704">手把手带你遨游TVM</a></p><p>不过上面这两个参考资料，第一个有点过于简略，第二个虽然写的很好但是已经五个月没有更新了...可能是坑了，所以后面的内容还是要靠自己慢慢研究了... 第三篇是蓝色大大的 TVM 入门文章，写的非常赞，可以清晰的理解 TVM 存在的意义与解决的问题。</p><h2 id="lecture-1-2">Lecture 1-2</h2><p>介绍了一波深度学习的基本概念，常见的各种 CNN、RNN、激活函数、BatchNormal、梯度消失与梯度爆炸都用一节课介绍了一通。因为我之前学过 Ng 在 coursera 上面的 Deep learing 专项，所以对这些都很了解了，而且这些内容的资源满大街都是，就不在这里继续介绍了。 第二讲是一个实验课，用 mxnet 搭建一个基础的网络，虽然我之前系统看过沐神的《动手深度学习》，但是我看的是 github 上面的 pytorch 魔改版，mxnet 只在上半年的华为软挑的时候用过一下，不是特别了解。但是这个 gluon api 似乎跟 pytorch 大同小异，这里也不多废话了，有兴趣自己看看原版《动手深度学习》就好了。</p><h2 id="lecture-3-overview-of-deep-learning-system">Lecture 3: Overview of Deep Learning System</h2><p>学习过什么是 Deep Learning 了，那么啥是 Deep Learning System 呢？ 在我的理解其实就是从调包到真正计算出结果的全过程，就是 Deep Learning System，也就是通常所说的算法的真正落地。 在这个课程中，Deep Learning System 从高层到底层分成了三个部分：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-1.png"></p><p>第一部分就是调包的封装 API，第二部分表示调完包后对调包代码的优化与 Scheduling，第三部分就是最下面的一些高效的 GPU kernel 实现、不同硬件后端的部署等等。 那么一个 Deep Learning 框架在 API 层需要包括什么内容呢？为什么大家要用框架而不是自己从头写呢？ 下面这个图就回答了这个问题，现在的模型越来越大，实现起来需要注意的内容也就越来越多，如果每步都由我们自己来进行链式求导算梯度的话，可能就没有这么多转行深度学习的大哥了（</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-2.png"></p><p>计算图是一个 Deep Learning 框架的基本概念，节点表示运算操作，边表示数据依赖，这里展示了一个最简单的 Logistic Regression 计算图实例： 首先是计算 loss 之前的前向传播，其实就是一个最简单的矩阵乘法加一个 softmax：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-3.png"></p><p>然后是将 softmax 输出的东西搞一个交叉熵作为 loss，相当于最大化 liklihood:</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-4.png"></p><p>然后是自动微分：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-5.png"></p><p>最后通过 SGD 来更新梯度：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-6.png"></p><p>结合上面的所有步骤，我们就得到了一个计算图：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-7.png"></p><p>最上层的 API 做了最简单的介绍，下面就是中间的 System 部分，讲的是对计算图的优化和对计算的调度。 计算优化最简单的一种就是 memory 优化，比如增加 cache 利用率之类的。因为我们的代码通常跑在多个线程甚至多个计算设备上面，所以并行调度也是非常重要的。最简单的一种并行调度如下图，这是一个 mxnet 代码，因为计算 C 和计算 B 是完全独立的，所以这两个部分可以并行化。</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-8.png"></p><p>然后简单介绍一下最底层的情况，我们训练完了 model，要部署到不同的设备上。想必大家也都听过各种各样的厂商搞出过各种各样的部署方法，为了产生更好的模型性能：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-9.png"></p><p>TVM 就是为了解决这样的问题而诞生的，只要我们都搞成中间代码，全栈自动编译优化，这样 model 就可以非常简单的部署到不同的设备上了：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/3-10.png"></p><p>这门课将在接下来详细介绍 Deep Learning System 的三个部分。</p><h2 id="lecture-4-backpropagation-and-automatic-differentiation">Lecture 4: Backpropagation and Automatic Differentiation</h2><p>第四节课讲的是 Deep Learning 中的求导方式——Auto-Diff. 首先我们要了解现代计算机实现求导通常有哪些方式。 第一种叫做 Symbolic Differentiation，如图所示：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/4-1.png"></p><p>通过程序来维护整个求导的式子，然后把变量带入就得到最终的导数。这种做法的缺点在于表达式是一个很难维护的东西，最后要维护的东西就会越来越多。正常 Deep Learning 框架显然不应该选择这种求导方式... 第二种叫做 Numerical Differentiation，这种求导形式非常简单，看起来很适合计算机：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/4-2.png"></p><p>但是有个很关键的问题是，这种求导方法要进行两次正向传播，跑起来很慢，而且误差会比较大。但是因为其实现方式的简单，所以这是一个非常好的 grad check 工具。Ng 在 DL 专项里面也是用这种方式去进行 grad check 的。 第三种叫做 Backpropagation，现代 Deep Learning 的核心：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/4-3.png"></p><p>虽然这种做法很适合计算机，效率也不错。但有一个关键的问题，我们在做正向传播的时候要记录中间结果，这样才能在后面进行反向传播，内存耗费很大。而且难以形成计算图，无法进行通用的并行化。 这时候就需要掏出第四种方法，叫做 Automatic Differentiation：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/4-5.png"></p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/4-4.png"></p><p>其实思路非常简单，看伪代码就看的出来。这种方法就是通过反拓扑序去生成一个反向的计算图，因为是计算图所以不用保存任何中间变量；也因为是计算图，所以可以进行通用的并行化与代码优化，两全其美！ 接下来就是 Auto-diff 的具体实现，也就是 assignment 1。这次作业的难度不是太大，前面就是一些算子的简单实现，后面有两段重要的代码：</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> node <span class="keyword">in</span> topo_order:</span><br><span class="line">    <span class="keyword">if</span> <span class="built_in">isinstance</span>(node.op, PlaceholderOp):</span><br><span class="line">        <span class="keyword">continue</span></span><br><span class="line">    input_vals = [node_to_val_map[x] <span class="keyword">for</span> x <span class="keyword">in</span> node.inputs]</span><br><span class="line">    res = node.op.compute(node, input_vals)</span><br><span class="line">    <span class="keyword">if</span> <span class="built_in">isinstance</span>(res, np.ndarray) == <span class="literal">False</span>:</span><br><span class="line">        res = np.array(res)</span><br><span class="line">    node_to_val_map[node] = res</span><br></pre></td></tr></table></figure><p>这段是前向传播的部分，记录每个节点计算出来的结果。前向传播需要正向拓扑序。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> node <span class="keyword">in</span> reverse_topo_order:</span><br><span class="line">    grad = sum_node_list(node_to_output_grads_list[node])</span><br><span class="line">    node_to_output_grad[node] = grad</span><br><span class="line">    input_grads = node.op.gradient(node, grad)</span><br><span class="line">    <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(node.inputs)):</span><br><span class="line">        node_to_output_grads_list[node.inputs[i]] = node_to_output_grads_list.get(node.inputs[i], [])</span><br><span class="line">        node_to_output_grads_list[node.inputs[i]].append(input_grads[i])</span><br></pre></td></tr></table></figure><p>这段是反向形成 Auto-diff 计算图的过程，跟伪代码很像。进一步解释一下：按照反拓扑序来遍历节点，计算到这个节点就代表着所有相关的梯度都计算完了。现在需要把相关的梯度都加起来，然后加起来的梯度作为这个节点向后面传播的梯度。整个过程很容易理解。</p><h2 id="lecture-5-gpu-programming">Lecture 5: GPU Programming</h2><p>这一讲涉及到的部分比较底层，虽然浙大本科的硬件三连质量非常高，但是对于 GPU 架构接触还是非常少的，又没有视频，只能尽力看一看了... 首先讲的是 CPU 的架构，很容易理解，CPU 跟内存有关的操作都很慢，所以才有各种指令预测、cache 优化算法之类的折磨着一代有一代的计算机学子：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-1.png"></p><p>而 GPU 则是给 CPU 加了一大堆计算资源，比普通的指令向量化更加强劲：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-2.png"></p><p>从这个图可以看出来，GPU 的寄存器数量非常大，所以他可以开很多线程去并行计算，但是从这个图也可以看出来，GPU 的 cache 小的可怜，所以 GPU 适用于那种轻内存读写、重计算的任务，也就是大量的并行计算任务。</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-3.png"></p><p>然后讲了 CUDA Programming Model，因为我完全不懂 CUDA 编程，所以后面的东西都是我瞎理解的，不一定对。 这个叫做 SIMT 的 Model 就是把几个 thread group 成一个 block，再把几个 block group 成一个 grid，block 可以以任何顺序在 GPU 上面调度。</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-3.5.png"></p><p>然后是个最简单的 cuda 程序——vector add。过程非常简单，跟操作系统里面查页表差不多，这里通过 block 下标和 thread 下标可以得到向量加法的下标，然后每个线程都执行一样的代码就可以了。</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-3.8.png"></p><p>后面讲了一个窗口 sum 的例子，因为他图画的不太好，所以有点难以理解，我尽力理解了一下：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-4.png"></p><p>最简单的实现方式是这样的：</p><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">define</span> RADIUS 3</span></span><br><span class="line"><span class="function">__global__ <span class="type">void</span> <span class="title">windowSumNaiveKernel</span><span class="params">(<span class="type">const</span> <span class="type">float</span>* A, <span class="type">float</span>* B, <span class="type">int</span> n)</span> </span></span><br><span class="line"><span class="function"></span>&#123;    </span><br><span class="line">    <span class="type">int</span> out_index = blockDim.x * blockIdx.x + threadIdx.x;    </span><br><span class="line">    <span class="type">int</span> in_index = out_index + RADIUS;    </span><br><span class="line">    <span class="keyword">if</span> (out_index &lt; n) &#123;</span><br><span class="line">        <span class="type">float</span> sum = <span class="number">0.</span>;</span><br><span class="line">        <span class="keyword">for</span> (<span class="type">int</span> i = -RADIUS; i &lt;= RADIUS; ++i) &#123;</span><br><span class="line">            sum += A[in_index + i];        </span><br><span class="line">        &#125;</span><br><span class="line">        B[out_index] = sum;</span><br><span class="line">    &#125;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>为什么说他图画的不太好，因为他的图跟代码的下标其实是对应不上的。B[0] 的结果应该是从 A[0] 到 A[6] 的和才对，这样下标就和代码完全对应上了。 这个实现非常简单，所以有很大的优化空间。其中最脑残的地方就是每个线程都读了 7 次 A，这样对 GPU 这种小 cache 来说是非常残忍的。</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-5.png"></p><p>所以一个非常简单的优化就是从一个线程读 7 次 A，变成一个 block (假设有四个线程) 读 9 次 A。</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-6.png"></p><p>那么下面的实现就很好理解了：</p><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="function">__global__ <span class="type">void</span> <span class="title">windowSumKernel</span><span class="params">(<span class="type">const</span> <span class="type">float</span>* A, <span class="type">float</span>* B, <span class="type">int</span> n)</span> </span>&#123;</span><br><span class="line">    __shared__ <span class="type">float</span> temp[THREADS_PER_BLOCK + <span class="number">2</span> * RADIUS];</span><br><span class="line">    <span class="type">int</span> out_index = blockDim.x * blockIdx.x + threadIdx.x;</span><br><span class="line">    <span class="type">int</span> in_index = out_index + RADIUS;</span><br><span class="line">    <span class="type">int</span> local_index = threadIdx.x + RADIUS;</span><br><span class="line">    <span class="keyword">if</span> (out_index &lt; n) &#123;</span><br><span class="line">        temp[local_index] = A[in_index];</span><br><span class="line">        <span class="keyword">if</span> (threadIdx.x &lt; RADIUS) &#123;</span><br><span class="line">            temp[local_index - RADIUS] = A[in_index - RADIUS];</span><br><span class="line">            temp[local_index + THREADS_PER_BLOCK] = A[in_index+THREADS_PER_BLOCK];</span><br><span class="line">        &#125;</span><br><span class="line">        __syncthreads();</span><br><span class="line">        <span class="type">float</span> sum = <span class="number">0.</span>;</span><br><span class="line">        <span class="keyword">for</span> (<span class="type">int</span> i = -RADIUS; i &lt;= RADIUS; ++i) &#123;</span><br><span class="line">            sum += temp[local_index + i];</span><br><span class="line">        &#125;        </span><br><span class="line">        B[out_index] = sum;    </span><br><span class="line">    &#125;</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>其中比较奥妙的地方是 if (threadIdx.x &lt; RADIUS) 这个语法块，其实也不难理解，就是把两边的 RADIUS 部分充满。最后是一个简单的矩阵乘法程序：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/5-7.png"></p><h2 id="lecture-6-optimize-for-hardware-backends">Lecture 6: Optimize for Hardware Backends</h2><p>这一讲处于计算图和硬件之间。 首先讲了一些体系结构的基本知识，比如多级 cache 之类的，然后讲了矩阵乘法的 cache 优化，非常基础。 最简单的矩阵乘法如下图所示：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/6-1.png"></p><p>像上面这样直接做矩阵乘法，其实对 cache 的利用率并不高，更好的方式是把矩阵分成很多个小块：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/6-2.png"></p><p>分块的方法可以缩小从下级 cache 到寄存器的 cost，有人想那我直接把 v1、v2 高的大一点，这样不就可以优化更多了吗？ 但是通常来说，寄存器非常小，很难存下一整列或者一整行，所幸 CPU 的 cache 通常比较大，也许可以存下：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/6-3.png"></p><p>把对寄存器的优化和 L1 cache 的优化相结合，就得到了下面这种比较复杂的优化方式：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/6-4.png"></p><p>对于 CPU 来说，代码优化可能集中于 Reuse memory 上面，而 GPU 的优化则是集中于 Reuse among threads 上。（这段其实不太懂，我理解就是跟前一讲的那个 window sum 一样，优化复用不同线程的内存）</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/6-5.png"></p><p>对于计算图的代码优化，其实套路千变万化：</p><p><img src="/2019/11/08/cse-599w-systems-for-ml-1-6/6-6.png"></p><p>为了解决这样的优化问题，我们就需要本门课的核心内容——TVM 了！</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;&lt;a href=&quot;http://dlsys.cs.washington.edu/&quot;&gt;课程网站&lt;/a&gt; 在头条 AML 实习的时候就觉得这个 AI system 方向非常有趣，但是苦于实验室不是搞这一套的，自己拖延症也非常严重，所以一直在入门的边缘徘徊。但是在今天——研一秋学期考试周的前一周，我决定开始学习 AI system 方向最著名的必学课程，Tianqi Chen 在 UW 开设的 CSE599W。 这个课程其实资料并不是很完善，只有 github 上面的几个 repo 和课程网站上面的 slide，缺乏讲课的视频资源。而且在开始学习之前就听说很多地方 slide 写的非常简陋，只能通过 tvm 和 tinyflow 代码慢慢学习。我在学习之前也找了一些 blog 资源，开个坑，希望可以努力坚持下来！ 本人的作业也开源到 &lt;a href=&quot;https://github.com/wyc-ruiker/CSE-599W-2018&quot;&gt;github&lt;/a&gt; 上面了，希望大家多多指导。&lt;/p&gt;</summary>
    
    
    
    <category term="system" scheme="https://wyc-ruiker.github.io/categories/system/"/>
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/system/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
    <category term="learning" scheme="https://wyc-ruiker.github.io/tags/learning/"/>
    
    <category term="system" scheme="https://wyc-ruiker.github.io/tags/system/"/>
    
    <category term="课程笔记" scheme="https://wyc-ruiker.github.io/tags/%E8%AF%BE%E7%A8%8B%E7%AC%94%E8%AE%B0/"/>
    
  </entry>
  
  <entry>
    <title>从 AdaBoost 到 GBDT</title>
    <link href="https://wyc-ruiker.github.io/2019/07/26/adaboost/"/>
    <id>https://wyc-ruiker.github.io/2019/07/26/adaboost/</id>
    <published>2019-07-26T08:07:10.000Z</published>
    <updated>2021-12-16T11:24:07.000Z</updated>
    
    <content type="html"><![CDATA[<p>集成学习顾名思义，就是把一堆垃圾方法集成起来变成一个牛逼的方法。集成学习主要分为两种思路：Bagging 和 Boosting。Bagging 的话就是一堆独立的垃圾方法，比如随机森林，就是通过不同的采样和不同的特征抽取方法产生一堆独立的决策树，然后把他们的决策进行投票。而 Boosting 则是通过一个垃圾方法来产生下一个垃圾方法，最知名的方法就是 AdaBoost 了。</p><span id="more"></span><h2 id="看上去傻傻的-adaboost">看上去傻傻的 AdaBoost</h2><p>对于普通的集成学习来说，其实最后的结果就是不同方法的线性组合，以二分类问题（1，-1）为例，最后的结果在数学上可以表示为： <span class="math display">\[G=sign(\sum_{t=1}^T\alpha_t g_t(x_n))\]</span></p><p>上面的 <span class="math inline">\(g_t\)</span> 就代表不同的学习方法，而 <span class="math inline">\(\alpha_t\)</span> 代表着每个学习方法的权重。 在 AdaBoost 中，最关键的一点就是对于错误函数的修改。AdaBoost 使用的是带权重的错误函数，<span class="math inline">\(u_n\)</span> 代表着每个样本点犯错误的权重： <span class="math display">\[E_{in}^u(h)=\frac1N\sum_{n=1}^N u_n\cdot err(y_n,h(x_n))\]</span></p><p>我们就是要使用不同的 <span class="math inline">\(u_n\)</span> 来的得到不同的方法 <span class="math inline">\(g_t\)</span>。显然，全都差不多的 <span class="math inline">\(g_t\)</span> 最后集成学习出来的效果肯定很垃圾。从参数上面思考可以看出，模型 <span class="math inline">\(g_t\)</span> 是通过参数 <span class="math inline">\(u_n^t\)</span> 生成的，而模型 <span class="math inline">\(g_{t+1}\)</span> 则是通过参数 <span class="math inline">\(u_n^{t+1}\)</span> 来生成的。我们想让 <span class="math inline">\(g_t\)</span> 和 <span class="math inline">\(g_{t+1}\)</span> 产生足够大的差距，其实就是让 <span class="math inline">\(u_n^{t+1}\)</span> 对应的 error 在 <span class="math inline">\(g_t\)</span> 上表现很差就好了。 什么叫表现的差呢？对于分类预测来说，最差的结果其实就是扔硬币，也就是说达到百分之 <span class="math inline">\(50\)</span> 的错误率，这个结果已经是最差的了。我们就是要通过构造 <span class="math inline">\(u_n^{t+1}\)</span> 使得 <span class="math inline">\(g_t\)</span> 在这个 error 上的错误率接近 <span class="math inline">\(0.5\)</span>。</p><p>AdaBoost 的做法很简单，如果 <span class="math inline">\(g_t\)</span> 的错误率为 <span class="math inline">\(\epsilon_t\)</span> 的话，那么构造一个尺度因子： <span class="math display">\[\diamond t=\sqrt{\frac{1-\epsilon_t}{\epsilon_t}}\]</span></p><p>对于正确的 <span class="math inline">\(u_n^t\)</span> 可以乘上 <span class="math inline">\(\diamond t\)</span>，对于错误的 <span class="math inline">\(u_n^t\)</span>，可以除上 <span class="math inline">\(\diamond t\)</span>。这样就会使得正确的和错误的参数 <span class="math inline">\(u_{n+1}^t\)</span> 达到平衡，从而达到放大错误、缩小正确的目的，使得 <span class="math inline">\(g_t\)</span> 在新 error 上面的错误率等于 <span class="math inline">\(0.5\)</span>。 现在有了生成每个 <span class="math inline">\(g_t\)</span> 的方法，那每个方法的参数 <span class="math inline">\(\alpha_t\)</span> 应该是什么呢？一个直观的想法是：生成完方法之后，再对这些方法进行线性组合的优化，然后来求出 <span class="math inline">\(\alpha_t\)</span>。这样固然是可以的，但是其实我们可以通过数学推导，来直接算出这个 <span class="math inline">\(\alpha_t\)</span>。这里先下结论，一会儿再进行数学推导： <span class="math display">\[\alpha_t=ln(\diamond t)\]</span></p><p>最简单的一种 AdaBoost 方法叫做 AdaBoost-Stump。在 AdaBoost-Stump 中，每个小方法都只能在某个维度上面画直线来对训练集进行分类。但就是这样简单的方法，可以通过 Boost 组合的方式达到非常好的效果，这就是 AdaBoost 的神奇之处。 现在我们想利用 AdaBoost 把决策树来组合起来。但是问题来了，首先是 error 函数如何参与决策树的分支操作，这个参与起来会比较麻烦。比较简单的做法是直接对决策树屏蔽了 <span class="math inline">\(u_n\)</span> 参数。然后通过 <span class="math inline">\(u_n\)</span> 参数来对训练集进行采样，来达到不同 error 函数的效果。这样决策树就是原来的决策树，不需要进行任何的修改。 另一个问题就是，很多决策树都是直接把叶子剖分到单个结点或者单个类别。这样的做法必然会导致这个决策树在测试集上面的错误率 <span class="math inline">\(\epsilon_t\)</span> 等于 <span class="math inline">\(0\)</span>。这样我们就没有办法通过 <span class="math inline">\(\diamond t\)</span> 来产生下一个垃圾方法了。当然解决这个问题的方法也很简单，对决策树进行剪枝和采样数据不全部采样都是可以尝试的做法。这个方法叫做 AdaBoost-DTree。值得注意的一点是：如果我们的决策树的高度限制为 <span class="math inline">\(1\)</span>，也就是说只能做一次划分，那么这个方法就跟 AdaBoost-Stump 没有任何区别啦。也就是说 AdaBoost-Stump 是 AdaBoost-DTree 的退化。</p><h2 id="奥妙重重的-adaboost">奥妙重重的 AdaBoost</h2><p>还记得上面的 <span class="math inline">\(\alpha_t\)</span> 吗？下面就是通过推导 <span class="math inline">\(\alpha_t\)</span> 来发现 AdaBoost 的奥妙之处了。当然过程中其实涉及到一些高深的数学知识，但是因为我都不会，所以很多地方可能会讲的很“民科”。 上面说过 <span class="math inline">\(u_n^{t+1}\)</span> 的求法：对于正确的 <span class="math inline">\(u_n^t\)</span> 可以乘上 <span class="math inline">\(\diamond t\)</span>，对于错误的 <span class="math inline">\(u_n^t\)</span>，可以除上 <span class="math inline">\(\diamond t\)</span>。通过数学式子其实可以将他们归纳成一个规律： <span class="math display">\[u_n^{(t+1)}=u_n^{(t)}\cdot \diamond_t^{-y_ng_t(x_n)}=u_n^{(t)}\cdot exp(-y_n\alpha_tg_t(x_n))\]</span></p><p>如果初始值 <span class="math inline">\(u_n^{(1)}=\frac1N\)</span> 的话，就有： <span class="math display">\[u_n^{(T+1)}=u_n^{(1)}\cdot \prod_{t=1}^Texp(-y_n\alpha_tg_t(x_n))=\frac1N\cdot exp(-y_n\sum_{t=1}^T\alpha_tg_t(x_n))\]</span></p><p>通过上面的式子可以发现一个很显然的事情：<span class="math inline">\(u_n^{(T+1)}\)</span> 与 <span class="math inline">\(exp(-y_n\sum_{t=1}^T\alpha_tg_t(x_n))\)</span> 成正比。 将 <span class="math inline">\(\sum_{t=1}^T\alpha_tg_t(x_n)\)</span> 从另外一个角度来看，可以发现这个其实是一个对 <span class="math inline">\(x_n\)</span> 特征转换的线性组合，跟 SVM 中的那个完全一致。其实这个式子就是没有进行正则化的分界距离，跟 <span class="math inline">\(y_n\)</span> 相乘的话，可以感性的理解成跟 SVM 一样，这个距离越大越好。 使得 <span class="math inline">\(y_n\sum_{t=1}^T\alpha_tg_t(x_n)\)</span> 越大越好，那么显然 <span class="math inline">\(exp(-y_n\sum_{t=1}^T\alpha_tg_t(x_n))\)</span> 越小越好，于是 <span class="math inline">\(u_n^{(T+1)}\)</span> 也就越小越好。 我们的目标就变成了最小化： <span class="math display">\[\sum_{n=1}^Nu_n^{(T+1)}=\frac1N\sum_{n=1}^Nexp(-y_n\sum_{t=1}^T\alpha_tg_t(x_n))\]</span></p><p>因为 <span class="math inline">\(\sum_{t=1}^T\alpha_tg_t(x_n)\)</span> 是我们的预测项，所以这个最小化其实也是另一种错误函数，而且这个错误函数显然是 0/1 error 的上界，这个错误函数一般叫做 <span class="math inline">\(\hat{err}_{ADA}\)</span>。 有个优化函数，下面的问题就是如何求出这个优化函数 <span class="math inline">\(\sum_{n=1}^Nu_n^{(T+1)}\)</span> 的最小值了。 思考梯度下降的过程，我们通过泰勒展开发现梯度的反方向是要求下降的最好方向，当然梯度下降的方向是一个向量。但是我们这里的“梯度”其实是一个函数，函数跟向量其实并没有多大的区别，只是一个下标是连续的一个下标是离散的（因为数学不好，只能这么理解了）。 接下来对要最小化的式子进行推导： <span class="math display">\[\frac1N\sum_{n=1}^Nexp(-y_n(\sum_{t=1}^T\alpha_tg_t(x_n)+\eta h(x_n)))=\sum_{n=1}^Nu_n^Texp(-y_n\eta h(x_n))\]</span></p><p>对这个式子进行最简单的一阶泰勒展开可以得到：<span class="math display">\[\sum_{n=1}^Nu_n^Texp(-y_n\eta h(x_n))=\sum_{n=1}^Nu_n^t(1-y_n\eta h(x_n))=\sum_{n=1}^Nu_n^t-\eta\sum_{n=1}^Nu_n^ty_nh(x_n)\]</span></p><p>先忽略掉步长 <span class="math inline">\(\eta\)</span>，我们的目标就变成了找到一个好的 <span class="math inline">\(h(x_n)\)</span> 来最小化 <span class="math inline">\(\sum_{n=1}^Nu_n^{(t)}(-y_nh(x_n))\)</span>。 对于二分类问题，<span class="math inline">\(-y_nh(x_n)\)</span> 的值要么是 <span class="math inline">\(-1\)</span> 要么是 <span class="math inline">\(1\)</span>。当 <span class="math inline">\(y_n = h(x_n)\)</span> 时，<span class="math inline">\(\sum_{n=1}^Nu_n^{(t)}(-y_nh(x_n)) = -\sum_{n=1}^Nu_n^{(t)}\)</span>，当 <span class="math inline">\(y_n \neq h(x_n)\)</span> 时，<span class="math inline">\(\sum_{n=1}^Nu_n^{(t)}(-y_nh(x_n)) = \sum_{n=1}^Nu_n^{(t)}\)</span>。将这个结果稍微平移并且统一一下，可以发现一件神奇的事情：<span class="math inline">\(\sum_{n=1}^Nu_n^{(t)}(-y_nh(x_n)) = -\sum_{n=1}^Nu_n^{(t)}+2E_{in}^{u_n^{(t)}}\cdot N\)</span>。 太有趣了。让原来的 <span class="math inline">\(\sum_{n=1}^Nu_n^{(t)}(-y_nh(x_n))\)</span> 竟然最优的 <span class="math inline">\(h(x_n)\)</span> 就是让 <span class="math inline">\(E_{in}^{u_n^{(t)}}\)</span> 最小的 <span class="math inline">\(h(x_n)\)</span>，也就是我们的 AdaBoost 过程中求出的 <span class="math inline">\(g_t\)</span>！ 在之前梯度下降的时候，我们选择的是自己随便设置一个步长，但是在 AdaBoost 里面，因为我们要组合各个方法以达到最好的效果，所以这个步长其实就是最后预测式中的 <span class="math inline">\(\alpha_t\)</span>。这个步长也就是在最佳方向上的最大步进长度，先把要求最佳步长的表达式写下来： <span class="math display">\[\check{E}_{ADA}=\sum_{n=1}^Nu_n^{(t)}exp(-y_n\eta g_t(x_n))\]</span></p><p>有两种情况需要我们考虑，分别是预测正确：<span class="math inline">\(u_n^{(t)}exp(-\eta)\)</span>，和预测错误：<span class="math inline">\(u_n^{(t)}exp(+\eta)\)</span>。 之前我们将错误率以符号 <span class="math inline">\(\epsilon_t\)</span> 来表示，经过简单推导统一，可以得到： <span class="math display">\[\check{E}_{ADA}=(\sum_{n=1}^Nu_n^{(t)})\cdot ((1-\epsilon_t)exp(-\eta)+\epsilon_t\ exp(+\eta))\]</span></p><p>求导，<span class="math inline">\(\frac{\partial \check{E}_{ADA}}{\partial \eta}=0\)</span> 得到： <span class="math display">\[\eta_t=ln\sqrt{\frac{1-\epsilon_t}{\epsilon_t}}=\alpha_t\]</span></p><p>这就是 <span class="math inline">\(\alpha_t = ln\sqrt{\frac{1-\epsilon_t}{\epsilon_t}}\)</span> 的原因啦！</p><h2 id="从-adaboost-到-gradient-boosting">从 AdaBoost 到 Gradient Boosting</h2><p>总结一下之前 AdaBoost 的求解过程，其实就是去优化以下式子： <span class="math display">\[min_{\eta}min_h\frac1N\sum_{n=1}^Nexp(-y_n(\sum_{t=1}^T\alpha_tg_t(x_n)+\eta h(x_n)))\]</span></p><p>之前说了，这个 <span class="math inline">\(exp\)</span> 函数其实只是错误函数的一种形式，我们也可以换成其他类型的 error 函数，比如这样： <span class="math display">\[min_{\eta}min_h\frac1N\sum_{n=1}^Nerr(\sum_{t=1}^T\alpha_tg_t(x_n)+\eta h(x_n), y_n)\]</span></p><p>这个公式就是一种通用的 Gradient Boosting 啦！ 接下来，我们就使用普通的 regression 错误（<span class="math inline">\(err(s,y)=(s-y)^2\)</span>）来看看 regression 情况下的 Gradient Boosting 是怎么回事吧~ 我们将 <span class="math inline">\(\sum_{t=1}^T\alpha_tg_t(x_n)\)</span> 看做 <span class="math inline">\(s_n\)</span>，那么原式通过泰勒展开之后就等于： <span class="math display">\[min_h\frac1N\sum_{n=1}^Nerr(s_n,y_n)+\frac1N\sum_{n=1}^N\eta h(x_n)\frac{\partial err(s,y_n)}{\partial s}\]</span></p><p>其中的一阶导数<span class="math inline">\(\frac{\partial err(s,y_n)}{\partial s}=2(s_n-y_n)\)</span>。 去除一堆常数项和常数因子，其实我们只需要最小化<span class="math inline">\(h(x_n)\cdot 2(s_n-y_n)\)</span>就好了。所以只要令<span class="math inline">\(h(x_n)\)</span>是梯度<span class="math inline">\(2(s_n-y_n)\)</span>的反方向就好了，但是这个反方向究竟要取多大呢？回想一下我们之前的梯度下降，其实梯度只代表一个方向，多大并没有什么关系。为了防止这个梯度取到无穷大，我们需要对梯度的大小进行一下限制。参考之前的正则化思路，我们也可以通过加上一个惩罚项<span class="math inline">\(h^2(x_n)\)</span>来得到新的优化式子，经过添加常数进行整理，可以得到我们最后关心的优化式子： <span class="math display">\[min\sum_{n=1}^N((h(x_n)-(y_n-s_n))^2)\]</span></p><p>也就是说，我们利用我们的基础 regression 方法使得 <span class="math inline">\(h(x_n)\)</span> 更加接近 <span class="math inline">\(y_n-s_n\)</span> 就可以了。简单的来说就是对所有 <span class="math inline">\(N\)</span> 个点 <span class="math inline">\((x_n, y_n-s_n)\)</span> 做 regression，得到的回归方程就是我们要求的 <span class="math inline">\(g_t(x_n)\)</span> 啦！ 其中一个很重要的概念就是 <span class="math inline">\(y_n-s_n\)</span>，很多博客会把这个东西叫做残差。也就是这些残差来决定 <span class="math inline">\(g_t\)</span>。 现在需要求出步长 <span class="math inline">\(\eta\)</span>。这个步骤非常简单，把我们之前求出来的 <span class="math inline">\(g_t\)</span> 代回到原来的优化式子中们可以得到： <span class="math display">\[min_{\eta}\frac1N\sum_{n=1}^N(s_n+\eta g_t(x_n)-y_n)^2 = \frac1N\sum_{n=1}^N((y_n-s_n)-\eta g_t(x_n))^2\]</span></p><p>可以发现，这里又是对残差进行拟合。不过这个拟合只有一个变量，非常简单，只需要简单求个导数就可以得到我们需要的 <span class="math inline">\(\eta\)</span> 啦！ 这个 Gradient Boosting 最出名的利用方法就是我们标题中提到的大名鼎鼎的 Gradient Boosted Decision Tree(GBDT) 啦！不过我们整个推导过程中好像并没有用过决策树啊？这个决策树要在哪里使用呢？ 很简单，就是我们通过决策树来做每一步的 regression 就好啦！有一个细节是，因为我们在求第一棵决策树的时候并没有残差这个东西，所以直接对各个 <span class="math inline">\((x_n, y_n)\)</span> 做拟合就好了，从第二棵决策树开始对残差做拟合。</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;集成学习顾名思义，就是把一堆垃圾方法集成起来变成一个牛逼的方法。集成学习主要分为两种思路：Bagging 和 Boosting。Bagging 的话就是一堆独立的垃圾方法，比如随机森林，就是通过不同的采样和不同的特征抽取方法产生一堆独立的决策树，然后把他们的决策进行投票。而 Boosting 则是通过一个垃圾方法来产生下一个垃圾方法，最知名的方法就是 AdaBoost 了。&lt;/p&gt;</summary>
    
    
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
  </entry>
  
  <entry>
    <title>支持向量机的清晰推导</title>
    <link href="https://wyc-ruiker.github.io/2019/07/06/svm/"/>
    <id>https://wyc-ruiker.github.io/2019/07/06/svm/</id>
    <published>2019-07-06T06:12:16.000Z</published>
    <updated>2021-12-16T11:32:34.000Z</updated>
    
    <content type="html"><![CDATA[<p>其实是标题党啦... SVM 推导作为最臭名昭著的机器学习面试题，其实我在去年的这个时候准备头条实习面试的期间就已经“背诵”过了，但完全没有理解自己推导的是个什么东西。最近看台大的《机器学习技法》的课程视频，感觉这个 SVM 推导过程讲的非常清晰。最关键的是每一步的 motivation 都讲的非常清楚，正好博客也好久好久好久没有更新了，这里简单重复一下，以证明自己学会了 SVM 的推导吧~</p><span id="more"></span><h2 id="从线性分割说起"><strong>从线性分割说起</strong></h2><p>在《机器学习基石》课程中讲了一个基于迭代的简单线性分类算法 PLA，但是对一个二分类的数据集，合法的分类线普遍是无数条。对于这无数条线我们应该选择哪条线作为我们最终分类的结果才比较好呢？</p><p><img src="/2019/07/06/svm/a.png"></p><p>从上图来看，显然最右边这一列是比较合理的分割方法。也就说，离分类线的最近的点要尽可能的远，这样对各种噪音误差的兼容性是最好的。 用数学的定义来说就是： <span class="math display">\[max_{w}min_{n=1...N}distance(x_n, w)\]</span></p><p>对于每个点来说都必须存在<span class="math inline">\(y_nw^Tx_n&gt;0\)</span>（如果一定要每个点都没有分割错误的话）。 在计算距离的时候，bias 显然是不应该参与进来的，所以需要把 bias 从矩阵中拆出来，变成<span class="math inline">\(h(x)=sign(w^Tx+b)\)</span>的形式，距离就从<span class="math inline">\(distance(x,w)\)</span>变成了<span class="math inline">\(distance(x,b,w)\)</span>。 下面要计算点<span class="math inline">\(x\)</span>到超平面<span class="math inline">\(w^Tx+b\)</span>的距离，这个就是一个简单的投影，得到 <span class="math display">\[distance(x,b,w)=\frac{|w^Tx+b|}{||w||}\]</span></p><p>因为上面有对于每个点来说都必须存在<span class="math inline">\(y_n(w^Tx_n+b)&gt;0\)</span>的条件，所以分子里面这个麻烦的绝对值式子可以直接用<span class="math inline">\(y_n(w^Tx_n+b)\)</span>来代替（反正<span class="math inline">\(y_n\)</span>要么是<span class="math inline">\(1\)</span>要么是<span class="math inline">\(-1\)</span>，不会对结果有任何影响），所以得到以下式子： <span class="math display">\[distance(x,b,w)=\frac{y_n(w^Tx_n+b)}{||w||}\]</span></p><p>又因为<span class="math inline">\(w^Tx_n+b\)</span>这个东西是可以加常数放缩的，也就是说<span class="math inline">\(w^Tx_n+b\)</span>跟<span class="math inline">\(3w^Tx_n+3b\)</span>所代表的平面是完全一致的，所以我们完全可以通过放缩使得最小的平面求出来的值等于<span class="math inline">\(1\)</span>，这样的好处是，最小的 distance 一定是 <span class="math inline">\(\frac{1}{||w||}\)</span>。这样做我们要求的数学式子就变为： <span class="math display">\[max_{w}\frac{1}{||w||}min_{n=1...N}y_n(w^Tx_n+b)=1\]</span></p><p>但是接下来有个奇怪的问题，如果我们限制的条件是所有的<span class="math inline">\(y_n(w^Tx_n+b)\ge 1\)</span>，那一定会有一个<span class="math inline">\(w\)</span>使得<span class="math inline">\(y_n(w^Tx_n+b)=1\)</span>吗？ 这个证明其实很简单，如果最优的<span class="math inline">\(w\)</span>使得<span class="math inline">\(y_n(w^Tx_n+b)=1.126\)</span>，那么我们只要取<span class="math inline">\((b/1.126, w/1.126)\)</span>这样一组解，就会比之前的<span class="math inline">\(w\)</span>更优了。 现在优化目标已经很明显了，为了之后陈述方便，我们将最大化<span class="math inline">\(\frac{1}{||w||}\)</span>变成最小化<span class="math inline">\(w^Tw\)</span>。为了后续的推导方便，加了一个<span class="math inline">\(\frac{1}{2}\)</span>的放缩。于是我们的优化函数就变成了下面这样： <span class="math display">\[min_{w}\frac{1}{2}w^Tw\]</span><span class="math display">\[min_{n=1...N}y_n(w^Tx_n+b)\ge 1\]</span></p><p>现在有了优化函数，如何求解呢？可以很简单的发现，这是一个非常经典的二次规划问题！给出一堆<span class="math inline">\(w\)</span>的限制，求解一个<span class="math inline">\(w\)</span>二次项的最优解！ 对于非线性的情况，我们只要把线性的映射到非线性情况就行了，非常的 easy，真是太爽了！</p><h2 id="为什么我们需要对偶问题"><strong>为什么我们需要对偶问题</strong></h2><p>不是都爽了吗？还要搞啥呢？ 真的爽吗？我们想一下，线性的情况确实蛮爽的，但是非线性的情况，当我们要将低维的特征映射的高维的时候，将<span class="math inline">\(x_n\)</span>映射到<span class="math inline">\(z_n\)</span>，这个代价是非常巨大的。比如一个<span class="math inline">\(x_n=(x1, x2)\)</span>特征，映射到二维就需要<span class="math inline">\(z_n=(1, x1, x2, x1^2, x2^2, x1x2)\)</span>这么多东西，更高的维度需要的代价会更大。 为了解决这个问题，我们首先需要了解一个叫做 Lagrange Mulitipliers 的东西。这个东西在正则化的时候，我们其实已经了解过了，大致是这个形式： <span class="math display">\[min_wE_{in}(w), w^Tw\le C \cong min_w(E_{aug}(w)=E_{in}(w)+\frac{\lambda}{N}w^Tw)\]</span></p><p>既然<span class="math inline">\(w^Tw\le C\)</span>我们可以通过在后面加一个<span class="math inline">\(\frac{\lambda}{N}w^Tw\)</span>来获得等价问题，那么<span class="math inline">\(y_n(w^Tz_n+b)\ge 1\)</span>这个条件我们显然也可以如法炮制，大致是这个形式： <span class="math display">\[L(b,w,a)=\frac{1}{2}w^Tw+\sum_{n=1}^N\alpha_n(1-y_n(w^Tz_n+b))\]</span><span class="math display">\[\min_{b,w}(max_{\alpha_n\ge 0}L(b,w,\alpha))\]</span>我们保证<span class="math inline">\(\alpha_n\ge 0\)</span>这样的条件，这样就可以使得所有<span class="math inline">\(1-y_n(w^Tz_n+b)\)</span>都符合不超过<span class="math inline">\(0\)</span>的要求。因为如果有一个<span class="math inline">\(1-y_n(w^Tz_n+b)ge 0\)</span>，那就可以让对应的<span class="math inline">\(\alpha_n\)</span>为无穷，优化的 max 就是无穷大了。 但是这个东西我们还是不会优化，这时候就需要引入一个叫做拉格朗日对偶问题的东西，这个东西通过一系列的证明，表示在我们这个优化形式下，这两个问题是完全等价的： <span class="math display">\[min_{b,w}(max_{a_n\ge 0}L(b,w,\alpha))\cong max_{a_n\ge0}(min_{b,w}L(b,w,\alpha))\]</span></p><p>展开来写的话就是： <span class="math display">\[max_{a_n\ge0}(min_{b,w}\frac{1}{2}w^Tw+\sum_{n=1}^N\alpha_n(1-y_n(w^Tz_n+b)))\]</span></p><p>下面开始魔法时刻~</p><p>如果里面的那个最小值优化是一个最值的话，那么对于每一个参数的偏导都应该是<span class="math inline">\(0\)</span>。基于这一点，我们就可以做一些有趣的事情了。 首先对里面的<span class="math inline">\(L(b,w,\alpha)\)</span>求<span class="math inline">\(b\)</span>的偏导，可以得到如下式子： <span class="math display">\[\frac{\partial L(b,w,\alpha)}{\partial b}=-\sum_{n=1}^N\alpha_ny_n=0\]</span></p><p>这个结论其实就有点神奇了，但是更神奇的地方在于，我们把<span class="math inline">\(L(b,w,\alpha)\)</span>展开，发现<span class="math inline">\(b\)</span>的系数恰好就是这个<span class="math inline">\(-\sum_{n=1}^N\alpha_ny_n\)</span>！也就说<span class="math inline">\(b\)</span>这一项直接就可以消掉！</p><p>继续我们的魔法之旅，下面对<span class="math inline">\(L(b,w,\alpha)\)</span>求<span class="math inline">\(w_i\)</span>的偏导： <span class="math display">\[\frac{\partial L(b,w,\alpha)}{\partial w_i}=w_i-\sum_{n=1}^N\alpha_ny_nz_{n,i}=0\]</span></p><p>这样就得到了一个有趣的结论，<span class="math inline">\(w=\sum_{n=1}^N\alpha_ny_nz_n\)</span>。 将这个结论带入回原来的式子，可以得到： <span class="math display">\[max(\frac{1}{2}||\sum_{n=1}^N\alpha_ny_nz_n||+\sum_{n=1}^N\alpha_n)\\ a_n\ge0, \sum_{n=1}^N\alpha_ny_n=0, w=\sum_{n=1}^N\alpha_ny_nz_n\]</span></p><p>所谓的 KKT 条件，其实也就是上面式子的几个条件了，还有一个原来定义中的条件<span class="math inline">\(y_n(w^Tx_n+b)\ge1\)</span>，当然还有一个最有趣的 KKT 条件就是<span class="math inline">\(\alpha_n(1-y_n(w^Tz_n+b))=0\)</span>，因为在优化的过程中，我们要求最后结果的最大值，而两个非负数相乘最后想要结果最大的话，必须其中一个为<span class="math inline">\(0\)</span>，那么上面这个条件就必须成立。 通过取负数将优化最大值变成优化最小值，然后展开可以得到： <span class="math display">\[min_{\alpha}\frac{1}{2}\sum_{n=1}^N\sum_{m=1}^N\alpha_n\alpha_my_ny_mz_n^Tz_m-\sum_{n=1}^N\alpha_n\\ a_n\ge0, \sum_{n=1}^N\alpha_ny_n=0\]</span></p><p>诶，这个东西的形式是不是有点眼熟？没错，就是二次规划，太爽了，对偶问题我们现在也会求解了！ 求解出<span class="math inline">\(\alpha\)</span>之后，就可以通过<span class="math inline">\(w=\sum_{n=1}^N\alpha_ny_nz_n\)</span>求出<span class="math inline">\(w\)</span>。那我们的 bias 呢？还记得<span class="math inline">\(\alpha_n(1-y_n(w^Tz_n+b))=0\)</span>这个 KKT 条件吗？在<span class="math inline">\(\alpha_n&gt;0\)</span>的时候，我们就可以求出对应的 bias 了，而那些<span class="math inline">\(\alpha_n&gt;0\)</span>所对应的<span class="math inline">\(n\)</span>，其实就是所谓的支持向量！ 当然了，这个时候聪明的读者就会发现，你这个不还是有一个<span class="math inline">\(z_n^Tz_m\)</span>吗？复杂度并没有降低啊？</p><h2 id="久等了核方法"><strong>久等了，核方法！</strong></h2><p>核方法其实跟 SVM 的推导并没有什么直接关系，但是在 SVM 的使用中却是大放异彩的。 所谓的核方法，其实就是一些神奇的<span class="math inline">\(z_n^Tz_m\)</span>求法，这些核方法可以完成很好的空间变换功能，甚至将原来的特征空间变换到无限维，正是核方法的出现，才使得 SVM 在某一个历史时期“一统江湖”的。 比如我们有一个垃圾特征空间变换，将<span class="math inline">\(x_n=(x_1, x_2, x_3, ..., x_d)\)</span>变换到<span class="math inline">\(\phi_n=(1, x_1, x_2, ..., x_d, x_1x_2, ..., x_1x_d, x_2x_1, x_2^2, ..., x_2x_d, ..., x_d^2)\)</span>。 现在我们想知道<span class="math inline">\(\phi_n(x)^T\phi_n(x^*)\)</span>，直接乘起来的话复杂度是<span class="math inline">\(O(d^2)\)</span>。我们通过观察式子，进行一定程度的优化可以发现<span class="math inline">\(\phi_n(x)^T\phi_n(x)=1+x^Tx^*+(x^Tx^*)(x^Tx^*)\)</span>。这样复杂度就是<span class="math inline">\(O(d)\)</span>了，是不是很棒？ 当然了，还有很多例如 RBF 之类的比较复杂的核方法，这里就不加赘述了，在实际应用中通常将<span class="math inline">\(z_n^Tz_m\)</span>换成<span class="math inline">\(K(x_n, x_m)\)</span>就代表核方法啦。</p><h2 id="虚假的-svm-和真正的-svm"><strong>虚假的 SVM 和真正的 SVM</strong></h2><p>其实上面讲的 SVM 跟实际使用的 SVM 并不完全一致，主要问题出现在大前提上，就是<span class="math inline">\(y_n(w^Tx_n+b)\ge1\)</span>这个定义。这个定义要求我们每个点都不能分割错误，过于严格，因为有一些噪音的结果是不需要考虑的。我们的 SVM 需要容忍一些分割错误。 我们将原来的 SVM 扩展一下，形式如下： <span class="math display">\[min_{b,w,\xi}\frac{1}{2}w^Tw+C\sum_{n=1}^N\xi_n\\ y_n(w^Tz_n+b)\ge1-\xi_n, \xi_n\ge0\]</span></p><p>其中参数<span class="math inline">\(C\)</span>就是用来 trade-off 容错率和分割线的鲁棒性的，参数<span class="math inline">\(\xi_n\)</span>就是表示这个点到底犯了多少错。 用 Lagrange Mulitipliers 重写这个优化式子，形式是这样： <span class="math display">\[L(b,w,\xi,\alpha,\beta)=\frac{1}{2}w^Tw+C\sum_{n=1}^N\xi_n+\sum_{n=1}^N\alpha_n(1-\xi_n-y_n(w^Tz_n+b))+\sum_{n=1}^N\beta_n(-\xi_n)\]</span></p><p><span class="math display">\[max_{\alpha_n\ge0, \beta_n\ge0}(min_{b,w,\xi}L(b,w,\xi,\alpha,\beta))\]</span></p><p>跟原来一样，通过求<span class="math inline">\(\xi_n\)</span>的偏导可以发现： <span class="math display">\[\frac{\partial L(b,w,\xi,\alpha,\beta)}{\partial \xi_n}=C-\alpha_n-\beta_n=0\]</span> 跟之前的 bias 一样，我们将式子展开之后，发现<span class="math inline">\(\xi_n\)</span>的系数就是上面这个等于<span class="math inline">\(0\)</span>的式子，所以<span class="math inline">\(\xi_n\)</span>这一项就没了。但是需要多一个<span class="math inline">\(0\le\alpha_n\le C\)</span>的条件。很神奇，这一项没了之后，我们的式子又开始变得眼熟起来： <span class="math display">\[max_{0\le\alpha_n\le C}(min_{b,w}\frac{1}{2}w^Tw+\sum_{n=1}^N\alpha_n(1-y_n(w^Tz_n+b)))\]</span></p><p>后面的事情就很显而易见了，前面都推导过了，这里也就不加赘述了。</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;其实是标题党啦... SVM 推导作为最臭名昭著的机器学习面试题，其实我在去年的这个时候准备头条实习面试的期间就已经“背诵”过了，但完全没有理解自己推导的是个什么东西。最近看台大的《机器学习技法》的课程视频，感觉这个 SVM 推导过程讲的非常清晰。最关键的是每一步的 motivation 都讲的非常清楚，正好博客也好久好久好久没有更新了，这里简单重复一下，以证明自己学会了 SVM 的推导吧~&lt;/p&gt;</summary>
    
    
    
    <category term="机器学习" scheme="https://wyc-ruiker.github.io/categories/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/"/>
    
    
  </entry>
  
  <entry>
    <title>GoodBye ICPC</title>
    <link href="https://wyc-ruiker.github.io/2017/12/20/goodbye-icpc/"/>
    <id>https://wyc-ruiker.github.io/2017/12/20/goodbye-icpc/</id>
    <published>2017-12-20T01:03:17.000Z</published>
    <updated>2021-12-16T11:33:00.000Z</updated>
    
    <content type="html"><![CDATA[<h2 id="心灵之约观光团">2016 心灵之约观光团</h2><ul><li>2016 浙江大学程序设计竞赛 一等奖</li><li>2016 浙江省程序设计竞赛 三等奖</li></ul><h2 id="rkmxtxwd-热裤暮夏天下无敌">2016 rkmxtxwd 热裤暮夏天下无敌</h2><ul><li>2016 ACM/ICPC 大连 银牌  36th place.</li><li>2016 ACM/ICPC 北京 银牌  29th place.</li></ul><h2 id="leatherclub-广东老乡">2017 LeatherClub 广东老乡</h2><ul><li>2017 ACM/ICPC 西安 金牌  17th place.</li><li>2017 CCPC 杭州 金牌  10th place.</li><li>2017 ACM/ICPC 南宁 金牌  9th place.</li><li>2017 ACM/ICPC 上海 ECL-final 铜牌  128th place.</li></ul><p>虽然很想继续...但是明年还要找工作（or 保研？？？），还要为了生计奔波，感觉自己除了竞赛什么都不会...其实竞赛也始终没有变成很强的选手...所以只能说再见了吧...</p><p>GoodBye ICPC</p>]]></content>
    
    
      
      
    <summary type="html">&lt;h2 id=&quot;心灵之约观光团&quot;&gt;2016 心灵之约观光团&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;2016 浙江大学程序设计竞赛 一等奖&lt;/li&gt;
&lt;li&gt;2016 浙江省程序设计竞赛 三等奖&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id=&quot;rkmxtxwd-热裤暮夏天下无敌&quot;&gt;2016 rkmxt</summary>
      
    
    
    
    <category term="acm" scheme="https://wyc-ruiker.github.io/categories/acm/"/>
    
    
  </entry>
  
</feed>
