设计一个近现代的条件分支预测器:从计数器到神经元

造机大实验的小笔记

本文所有的硬件实现均使用 SpinalHDL(基于 Scala)描述。

引子 #

在本学期《计算机组成原理》课的大实验——在 FPGA 上设计一个五级流水线 RISC-V 处理器——的扩展要求中,有一项是要求我们为处理器实现一个动态分支预测器。

课堂上,老师已经草草带过了最经典的两位饱和计数器分支预测器,即 Bimodal。Bimodal 的状态转换相当于只考虑每个分支最后几次跳转的情况,基于此进行预测。从直觉上,这样利用的分支历史信息十分有限,可能难以捕捉到一些跳转模式的特征。

Local & Global #

要利用更多的分支历史,有两种想法:

  • 对于每个分支,单独考虑其历史跳转,称作 Local Branch Prediction
  • 考虑全局发生的历史跳转,称作 Global Branch Prediction

读者可能会疑惑以上两种想法和 Bimodal 的区别。这里简单做说明。

Bimodal 对每一处分支指令,维护一个两位饱和计数器,决定该指令是否预测跳转。在实现上,一般要对每条分支指令的地址进行哈希,并将哈希后的结果映射到一张表 BTB,表的每项是一个两位饱和计数器。简单模拟即可发现,这种预测器只能捕捉 $(1110)^n$ 或 $(0001)^n$ 这种简单重复的跳转模式。一旦出现 $(0101)^n$ 这种规律并非“寻常情况占优”的模式,就难以预测正确。

Local Branch Prediction 使用两张表。第一张表的每项是一个 $n$ 位寄存器,记录了最近 $n$ 次分支指令的跳转结果(跳转为 1,不跳转为 0)。每当遇到分支指令时,该寄存器会左移一位,并将当前分支的实际结果写入最低位。跳转指令的地址经哈希后先查找这一张表,得到一个历史 pattern,再使用这个 pattern 作为索引,查找第二张表:饱和寄存器表。简单模拟即可发现,对于长度为 $n$ 的分支历史,全部的 $2^k$ 种 pattern 都会被映射到不同的 BTB 表项。这样有助于对复杂的跳转模式进行预测。

Global Branch Prediction 则不考虑分支指令的 PC。这种策略在全局维护一个 $n$ 位寄存器 GHR,记录最近 $n$ 次分支指令的跳转结果。查询 BTB 时,将 GHR 的值哈希并映射到 BTB。从直觉上看,在 $n$ 较小时,Global Branch Prediction 的准确率是比较差的,因为不同的分支指令可能拥有相同的局部跳转 pattern(例如 $(11110101)$ 和 $(11110000)$ 仅看前 4 位是完全相同的),这会导致大量不同的分支指令被映射到同一个 BTB 表项。当这些分支指令的表现变得不同时,预测器就会大量出错。因此,只有 $n$ 足够大,考虑的全局历史足够长时,才能区分开不同的分支,保证精度。

可以考虑将 Local 和 Bimodal 两种策略与经典 Bimodal 相结合,达到更好的预测效果。在大实验中,我将 Global 信息与 Bimodal 结合,实现了 Gshare/Gselect 预测器。

实现 Gshare/Gselect 预测器 #

Gshare/Gselect 预测器是将 Global Branch Prediction 与 Bimodal 预测相结合的产物。我们维护一个与 Global Branch Prediction 中相同的 GHR 作为全局历史。在查找 BTB 时,将“对分支指令的地址进行哈希”,改为“将分支指令的地址和全局跳转历史共同进行哈希”,从而计算出 BTB 索引,这样就将全局跳转历史纳入了预测考虑。有两种最简单的共同哈希方法:

  • BTB 索引 = 分支地址 XOR 全局历史。称作 Gshare
  • BTB 索引 = 分支地址 拼接上 全局历史。称作 Gselect

还可以使用不同的 Gselect 函数,例如 {branch_address[3:0], global_history[3:0]} 挑出了低 4 位进行拼接。下表展示了两条分支语句在不同条件下的 4 种情况,和使用如上的 Gselect 函数时,Gselect 和 Gshare 计算出的 BTB 索引。

Branch Address Global History Gselect Gshare
0000_0000 0000_0001 0000_0001 0000_0001
0000_0000 0000_0000 0000_0000 0000_0000
1111_1111 0000_0000 1111_0001 1111_1111
1111_1111 1000_0000 1111_0001 0111_1111

可以发现,只有在考虑 GHR 的情况下,才有可能对同一分支在不同条件下的情况进行区分,并且只有 Gshare 成功区分出了 4 种不同的情况。一般情况下,Gshare 的预测精度要比 Gselect 好。

那么,如何在大实验中实现这样的 Gshare 分支预测器呢?

在一切开始之前,我们需要先考虑,在经典 RISC-V 五级流水线架构下实现分支预测器,需要哪些控制信号。在 IF 阶段看来,分支预测器相当于一个函数,输入当前 PC 值,输出预测的下一个 PC 值。而在计算出分支指令真实目标的阶段看来,分支预测器相当于一个只进不出的黑盒,只需要输入当前 PC 值和分支指令的真实跳转情况,预测器内部会自己更新状态。假定我们的所有分支指令在 EXE 阶段计算出目标地址,那么就需要在 EXE 阶段更新预测器。从而,我们有分支预测器的 IO 信号:

 1val io = new Bundle {
 2    // connect to IF stage
 3    val if_pc = in UInt(32 bits)
 4    val predicted_pc = out UInt(32 bits)
 5
 6    // connect to EXE stage
 7    val exe_pc = in UInt(32 bits)
 8    val real_taken = in Bool()
 9    val real_target = in UInt(32 bits)
10    val is_branch = in Bool()
11}

这里需要说明的是,尽管前面一直在说“BTB 表的每项是一个饱和计数器”,但在实现上,为了工程的便利,我们将查表与查计数器分开,即将课本上的 BTB 一分为二,分为一个分支目标缓冲区(Branch Target Buffer, BTB)和一个模式历史表(Pattern History Table, PHT)。BTB 采用 1-way direct-mapped 组织,表项结构与查找方式和 Cache 类似,需要初始化 valid 字段为 false。此外,还需要一个单独的寄存器 Branch History Register, BHR(即前文的 GHR)存储分支历史信息。

 1case class BranchPredictorConfig(
 2    index_width: Int = 7,
 3    counter_width: Int = 2, // PHT counter width
 4    target_width: Int = 32,
 5    history_width: Int = 7, // BHR width
 6) {
 7    assert(history_width == index_width,
 8        "For GShare, history_width must be equal to index_width")
 9    def entry_num = 1 << index_width
10    def tag_width = 32 - index_width - 2
11
12    /* pc[index_width+1 : 2] used as index, then compare 
13     * pc[31 : index_width+2] with tag to check whether it's a hit.
14     * Lower 2 bits of pc are ignored since instructions are
15     * aligned to 4 bytes.
16     */
17}
18case class BTBEntry() extends Bundle {
19    val valid = Bool()
20    val tag = UInt(tag_width bits)
21    val target = UInt(target_width bits)
22}
23val btb = Vec.fill(entry_num)(Reg(BTBEntry())) simPublic()
24btb.foreach(_.valid init(False))
25val bhr = Reg(UInt(history_width bits)) init(0) simPublic()

PHT 则是真正的 2-bit 饱和计数器表,表项数目与 BTB 相同。BTB 命中后,使用相同的索引访问 PHT,根据 PHT 中对应的计数器值来决定分支的预测结果。

1object BPState {
2    def STRONGLY_NOT_TAKEN = U"00"
3    def WEAKLY_NOT_TAKEN = U"01"
4    def WEAKLY_TAKEN = U"10"
5    def STRONGLY_TAKEN = U"11"
6}
7val pht = Vec(Reg(UInt(counter_width bits)) init(STRONGLY_NOT_TAKEN), entry_num) simPublic()

这样,Gshare 预测器所需的数据结构就定义完成了。

定义一些查表的辅助函数:

1def get_index(pc: UInt): UInt = pc(index_width + 1 downto 2)
2def get_tag(pc: UInt): UInt = pc(31 downto index_width + 2)
3
4def get_pht_index(pc: UInt, history: UInt): UInt = {
5    (get_index(pc) ^ history).resize(index_width bits)
6}

get_pht_index 内部的异或说明了这是一个 Gshare 预测器。可以更换该函数,发明不同的 Gxx 预测器。

接下来只需要实现查表和更新表的逻辑。在工程上,一种便于维护的实现思路是,在预测器内部分出两个区域,一个区域是组合逻辑,组合查找 BTB 并组合返回;另一个区域是时序逻辑,用于在获取分支指令真实跳转结果后更新 BTB。

 1// combinational logic for IF stage
 2val if_area = new Area {
 3    val pc = io.if_pc
 4    val index = get_index(pc)
 5    val tag = get_tag(pc)
 6
 7    val pht_index = get_pht_index(pc, bhr)
 8
 9    val btb_entry = btb(index)
10    val pht_counter = pht(pht_index)
11
12    val btb_hit = btb_entry.valid && (btb_entry.tag === tag)
13    val pht_taken = pht_counter.msb
14    val predict_taken = btb_hit && pht_taken
15
16    io.predicted_pc := predict_taken ? btb_entry.target | (pc + 4)
17}
18
19// sequential logic for EXE stage
20val exe_area = new Area {
21    val pc = io.exe_pc
22    val index = get_index(pc)
23    val tag = get_tag(pc)
24    val real_taken = io.real_taken
25
26    val pht_index = get_pht_index(pc, bhr)
27
28    when(io.is_branch) {
29        // update PHT
30        val old_pht_counter = pht(pht_index)
31        val new_pht_counter = UInt(counter_width bits)
32        switch(old_pht_counter) {
33            is(STRONGLY_NOT_TAKEN) {
34                new_pht_counter := real_taken ? WEAKLY_NOT_TAKEN | STRONGLY_NOT_TAKEN
35            }
36            is(WEAKLY_NOT_TAKEN) {
37                new_pht_counter := real_taken ? WEAKLY_TAKEN | STRONGLY_NOT_TAKEN
38            }
39            is(WEAKLY_TAKEN) {
40                new_pht_counter := real_taken ? STRONGLY_TAKEN | WEAKLY_NOT_TAKEN
41            }
42            is(STRONGLY_TAKEN) {
43                new_pht_counter := real_taken ? STRONGLY_TAKEN | WEAKLY_TAKEN
44            }
45        }
46        pht(pht_index) := new_pht_counter
47
48        // update BTB
49        when(real_taken) {
50            btb(index).valid := True
51            btb(index).tag := tag
52            btb(index).target := io.real_target
53        }
54
55        bhr := (bhr |<< 1) | (io.real_taken.asUInt).resize(history_width bits)
56    }
57}

至此,我们实现了 Gshare 分支预测器。Gselect 稍作修改即可得到。

What’s Next? #

我们已经有了一个在大实验中表现足够优秀的分支预测器(也没必要卷了),但现代 CPU 的分支预测远远不止于此。

观察我们前面所有的基于查找表的分支预测器,可以发现它们都是通过建立跳转历史的 pattern 到分支结果的映射,来预测分支结果。也就是说,分支预测器是在尝试找到 pattern 和分支结果之间的映射规律。我们完全可以考虑使用机器学习算法来学习这一规律。

考虑机器学习中常见的感知机 Perceptron 模型,也就是神经网络中所谓的神经元。其计算 $$ f = w_0 + \sum_{i=0}^nw_ih_i $$ 的结果。其中,$w$ 为可变的权重,$h$ 为分支历史(Taken or Not Taken, 1 or -1)。当 $f \geq 0$ 时,预测跳转,否则预测不跳转。

下面推导如何更新权重。记真实跳转情况为 $t = 1 \ \text{or} \ -1$;为了简化,记权重向量为 $\mathbf{w}$,输入向量(分支历史,第 0 项为常数项 1)为 $\mathbf{x}$,预测结果表示为 $f = \mathbf{w} \cdot \mathbf{x}$。更新权重时,有两种情况:

  • 预测错误,即 $t \cdot f < 0$
  • 预测正确但不够自信,即 $f < \theta$,$\theta$ 是一个事先设定的值

这两种情况可以合并为:只要 $t \cdot f < \theta$,就需要更新。这相当于有损失函数:

$$ L(\mathbf{w}) = \max(0, \theta - t \cdot f) $$

容易验证预测正确且自信时 $L=0$,而其他情况 $L = \theta - t \cdot (\mathbf{w} \cdot \mathbf{x})$,此时需要更新权重以最小化损失函数。

使用梯度下降法:

$$ \frac{\partial L}{\partial \mathbf{w}} = \frac{\partial (\theta - t \cdot \mathbf{w} \cdot \mathbf{x})}{\partial \mathbf{w}} = -t \cdot \mathbf{x} $$

设学习率 $\eta$ 为 1,代入上面的梯度:

$$ \mathbf{w}_\text{new} = \mathbf{w}_\text{old} - \eta \cdot (-t \cdot \mathbf{x}) $$ $$ \mathbf{w}_\text{new} = \mathbf{w}_\text{old} + t \cdot \mathbf{x} $$

回顾 $\mathbf{x}$ 的含义,可以得出更新权重的方法:

  • 对于 $w_0$,如果实际结果是 Taken,则 $+1$,否则 $-1$
  • 对于其余 $w_i$,如果 $h_i$ 与当前分支结果相同则 $+1$,否则 $-1$

在实际使用时,可以制作一张表,索引为分支指令的地址,每个表内是一个 Perceptron。这实际上(大致)是 2019 年 AMD Zen2 仍在使用的预测器的结构,距离我们当前的时代相当近。

通过使用 Perceptron,我们将原先基于查找表的“每一个不同的 pattern 对应一个计数器”,转变为“每一个 pattern 长度对应一个 Perceptron”,使得表增长的速度由指数降低为线性,从而可以在使用相同资源(寄存器数量,RAM 等)的情况下,使用更长的分支历史信息进行分支预测,达到更好的精度。

目前,主流的现代处理器都使用一种集大成的分支预测器:TAGE 预测器。其主要思想是使用多张查找表,每张使用不同长度的分支历史信息,且长度几何增长。对每个表项维护一个 useful 计数器,反映不同长度历史信息发挥的作用。具体实现和维护方法较为复杂,在预测中也会使用到 Perceptron,不在本文中描述。

结语 #

尽管受限于实验平台资源数量和时间原因,我在大实验中只实现了略显古老的 Gshare 预测器,没有对 Perceptron 乃至 TAGE 做更多尝试(尝试了也没啥用),但通过了解分支预测器的发展历史和思想,我体验到了实现硬件机器学习的独特魅力。这一学期玩了一些很具代表性的可编程硬件,包括 FPGA 和 Tofino P4,也都尝试在上面实现了某种“预测”和“学习”算法,初步体会了硬件编程“因地制宜”的思想。之后如果有时间,可能会再做一些记录。

本文大量参考了知乎专栏:现代分支预测设计-早期分支预测器,该系列文章(2 篇)写得非常好,可作为扩展阅读。