设计一个近现代的条件分支预测器:从计数器到神经元
造机大实验的小笔记
本文所有的硬件实现均使用 SpinalHDL(基于 Scala)描述。
引子 #
在本学期《计算机组成原理》课的大实验——在 FPGA 上设计一个五级流水线 RISC-V 处理器——的扩展要求中,有一项是要求我们为处理器实现一个动态分支预测器。
课堂上,老师已经草草带过了最经典的两位饱和计数器分支预测器,即 Bimodal。Bimodal 的状态转换相当于只考虑每个分支最后几次跳转的情况,基于此进行预测。从直觉上,这样利用的分支历史信息十分有限,可能难以捕捉到一些跳转模式的特征。
Local & Global #
要利用更多的分支历史,有两种想法:
- 对于每个分支,单独考虑其历史跳转,称作 Local Branch Prediction
- 考虑全局发生的历史跳转,称作 Global Branch Prediction
读者可能会疑惑以上两种想法和 Bimodal 的区别。这里简单做说明。
Bimodal 对每一处分支指令,维护一个两位饱和计数器,决定该指令是否预测跳转。在实现上,一般要对每条分支指令的地址进行哈希,并将哈希后的结果映射到一张表 BTB,表的每项是一个两位饱和计数器。简单模拟即可发现,这种预测器只能捕捉 $(1110)^n$ 或 $(0001)^n$ 这种简单重复的跳转模式。一旦出现 $(0101)^n$ 这种规律并非“寻常情况占优”的模式,就难以预测正确。
Local Branch Prediction 使用两张表。第一张表的每项是一个 $n$ 位寄存器,记录了最近 $n$ 次分支指令的跳转结果(跳转为 1,不跳转为 0)。每当遇到分支指令时,该寄存器会左移一位,并将当前分支的实际结果写入最低位。跳转指令的地址经哈希后先查找这一张表,得到一个历史 pattern,再使用这个 pattern 作为索引,查找第二张表:饱和寄存器表。简单模拟即可发现,对于长度为 $n$ 的分支历史,全部的 $2^k$ 种 pattern 都会被映射到不同的 BTB 表项。这样有助于对复杂的跳转模式进行预测。
Global Branch Prediction 则不考虑分支指令的 PC。这种策略在全局维护一个 $n$ 位寄存器 GHR,记录最近 $n$ 次分支指令的跳转结果。查询 BTB 时,将 GHR 的值哈希并映射到 BTB。从直觉上看,在 $n$ 较小时,Global Branch Prediction 的准确率是比较差的,因为不同的分支指令可能拥有相同的局部跳转 pattern(例如 $(11110101)$ 和 $(11110000)$ 仅看前 4 位是完全相同的),这会导致大量不同的分支指令被映射到同一个 BTB 表项。当这些分支指令的表现变得不同时,预测器就会大量出错。因此,只有 $n$ 足够大,考虑的全局历史足够长时,才能区分开不同的分支,保证精度。
可以考虑将 Local 和 Bimodal 两种策略与经典 Bimodal 相结合,达到更好的预测效果。在大实验中,我将 Global 信息与 Bimodal 结合,实现了 Gshare/Gselect 预测器。
实现 Gshare/Gselect 预测器 #
Gshare/Gselect 预测器是将 Global Branch Prediction 与 Bimodal 预测相结合的产物。我们维护一个与 Global Branch Prediction 中相同的 GHR 作为全局历史。在查找 BTB 时,将“对分支指令的地址进行哈希”,改为“将分支指令的地址和全局跳转历史共同进行哈希”,从而计算出 BTB 索引,这样就将全局跳转历史纳入了预测考虑。有两种最简单的共同哈希方法:
- BTB 索引 = 分支地址 XOR 全局历史。称作 Gshare
- BTB 索引 = 分支地址 拼接上 全局历史。称作 Gselect
还可以使用不同的 Gselect 函数,例如 {branch_address[3:0], global_history[3:0]} 挑出了低 4 位进行拼接。下表展示了两条分支语句在不同条件下的 4 种情况,和使用如上的 Gselect 函数时,Gselect 和 Gshare 计算出的 BTB 索引。
| Branch Address | Global History | Gselect | Gshare |
|---|---|---|---|
| 0000_0000 | 0000_0001 | 0000_0001 | 0000_0001 |
| 0000_0000 | 0000_0000 | 0000_0000 | 0000_0000 |
| 1111_1111 | 0000_0000 | 1111_0001 | 1111_1111 |
| 1111_1111 | 1000_0000 | 1111_0001 | 0111_1111 |
可以发现,只有在考虑 GHR 的情况下,才有可能对同一分支在不同条件下的情况进行区分,并且只有 Gshare 成功区分出了 4 种不同的情况。一般情况下,Gshare 的预测精度要比 Gselect 好。
那么,如何在大实验中实现这样的 Gshare 分支预测器呢?
在一切开始之前,我们需要先考虑,在经典 RISC-V 五级流水线架构下实现分支预测器,需要哪些控制信号。在 IF 阶段看来,分支预测器相当于一个函数,输入当前 PC 值,输出预测的下一个 PC 值。而在计算出分支指令真实目标的阶段看来,分支预测器相当于一个只进不出的黑盒,只需要输入当前 PC 值和分支指令的真实跳转情况,预测器内部会自己更新状态。假定我们的所有分支指令在 EXE 阶段计算出目标地址,那么就需要在 EXE 阶段更新预测器。从而,我们有分支预测器的 IO 信号:
1val io = new Bundle {
2 // connect to IF stage
3 val if_pc = in UInt(32 bits)
4 val predicted_pc = out UInt(32 bits)
5
6 // connect to EXE stage
7 val exe_pc = in UInt(32 bits)
8 val real_taken = in Bool()
9 val real_target = in UInt(32 bits)
10 val is_branch = in Bool()
11}这里需要说明的是,尽管前面一直在说“BTB 表的每项是一个饱和计数器”,但在实现上,为了工程的便利,我们将查表与查计数器分开,即将课本上的 BTB 一分为二,分为一个分支目标缓冲区(Branch Target Buffer, BTB)和一个模式历史表(Pattern History Table, PHT)。BTB 采用 1-way direct-mapped 组织,表项结构与查找方式和 Cache 类似,需要初始化 valid 字段为 false。此外,还需要一个单独的寄存器 Branch History Register, BHR(即前文的 GHR)存储分支历史信息。
1case class BranchPredictorConfig(
2 index_width: Int = 7,
3 counter_width: Int = 2, // PHT counter width
4 target_width: Int = 32,
5 history_width: Int = 7, // BHR width
6) {
7 assert(history_width == index_width,
8 "For GShare, history_width must be equal to index_width")
9 def entry_num = 1 << index_width
10 def tag_width = 32 - index_width - 2
11
12 /* pc[index_width+1 : 2] used as index, then compare
13 * pc[31 : index_width+2] with tag to check whether it's a hit.
14 * Lower 2 bits of pc are ignored since instructions are
15 * aligned to 4 bytes.
16 */
17}
18case class BTBEntry() extends Bundle {
19 val valid = Bool()
20 val tag = UInt(tag_width bits)
21 val target = UInt(target_width bits)
22}
23val btb = Vec.fill(entry_num)(Reg(BTBEntry())) simPublic()
24btb.foreach(_.valid init(False))
25val bhr = Reg(UInt(history_width bits)) init(0) simPublic()PHT 则是真正的 2-bit 饱和计数器表,表项数目与 BTB 相同。BTB 命中后,使用相同的索引访问 PHT,根据 PHT 中对应的计数器值来决定分支的预测结果。
1object BPState {
2 def STRONGLY_NOT_TAKEN = U"00"
3 def WEAKLY_NOT_TAKEN = U"01"
4 def WEAKLY_TAKEN = U"10"
5 def STRONGLY_TAKEN = U"11"
6}
7val pht = Vec(Reg(UInt(counter_width bits)) init(STRONGLY_NOT_TAKEN), entry_num) simPublic()这样,Gshare 预测器所需的数据结构就定义完成了。
定义一些查表的辅助函数:
1def get_index(pc: UInt): UInt = pc(index_width + 1 downto 2)
2def get_tag(pc: UInt): UInt = pc(31 downto index_width + 2)
3
4def get_pht_index(pc: UInt, history: UInt): UInt = {
5 (get_index(pc) ^ history).resize(index_width bits)
6}get_pht_index 内部的异或说明了这是一个 Gshare 预测器。可以更换该函数,发明不同的 Gxx 预测器。
接下来只需要实现查表和更新表的逻辑。在工程上,一种便于维护的实现思路是,在预测器内部分出两个区域,一个区域是组合逻辑,组合查找 BTB 并组合返回;另一个区域是时序逻辑,用于在获取分支指令真实跳转结果后更新 BTB。
1// combinational logic for IF stage
2val if_area = new Area {
3 val pc = io.if_pc
4 val index = get_index(pc)
5 val tag = get_tag(pc)
6
7 val pht_index = get_pht_index(pc, bhr)
8
9 val btb_entry = btb(index)
10 val pht_counter = pht(pht_index)
11
12 val btb_hit = btb_entry.valid && (btb_entry.tag === tag)
13 val pht_taken = pht_counter.msb
14 val predict_taken = btb_hit && pht_taken
15
16 io.predicted_pc := predict_taken ? btb_entry.target | (pc + 4)
17}
18
19// sequential logic for EXE stage
20val exe_area = new Area {
21 val pc = io.exe_pc
22 val index = get_index(pc)
23 val tag = get_tag(pc)
24 val real_taken = io.real_taken
25
26 val pht_index = get_pht_index(pc, bhr)
27
28 when(io.is_branch) {
29 // update PHT
30 val old_pht_counter = pht(pht_index)
31 val new_pht_counter = UInt(counter_width bits)
32 switch(old_pht_counter) {
33 is(STRONGLY_NOT_TAKEN) {
34 new_pht_counter := real_taken ? WEAKLY_NOT_TAKEN | STRONGLY_NOT_TAKEN
35 }
36 is(WEAKLY_NOT_TAKEN) {
37 new_pht_counter := real_taken ? WEAKLY_TAKEN | STRONGLY_NOT_TAKEN
38 }
39 is(WEAKLY_TAKEN) {
40 new_pht_counter := real_taken ? STRONGLY_TAKEN | WEAKLY_NOT_TAKEN
41 }
42 is(STRONGLY_TAKEN) {
43 new_pht_counter := real_taken ? STRONGLY_TAKEN | WEAKLY_TAKEN
44 }
45 }
46 pht(pht_index) := new_pht_counter
47
48 // update BTB
49 when(real_taken) {
50 btb(index).valid := True
51 btb(index).tag := tag
52 btb(index).target := io.real_target
53 }
54
55 bhr := (bhr |<< 1) | (io.real_taken.asUInt).resize(history_width bits)
56 }
57}至此,我们实现了 Gshare 分支预测器。Gselect 稍作修改即可得到。
What’s Next? #
我们已经有了一个在大实验中表现足够优秀的分支预测器(也没必要卷了),但现代 CPU 的分支预测远远不止于此。
观察我们前面所有的基于查找表的分支预测器,可以发现它们都是通过建立跳转历史的 pattern 到分支结果的映射,来预测分支结果。也就是说,分支预测器是在尝试找到 pattern 和分支结果之间的映射规律。我们完全可以考虑使用机器学习算法来学习这一规律。
考虑机器学习中常见的感知机 Perceptron 模型,也就是神经网络中所谓的神经元。其计算 $$ f = w_0 + \sum_{i=0}^nw_ih_i $$ 的结果。其中,$w$ 为可变的权重,$h$ 为分支历史(Taken or Not Taken, 1 or -1)。当 $f \geq 0$ 时,预测跳转,否则预测不跳转。
下面推导如何更新权重。记真实跳转情况为 $t = 1 \ \text{or} \ -1$;为了简化,记权重向量为 $\mathbf{w}$,输入向量(分支历史,第 0 项为常数项 1)为 $\mathbf{x}$,预测结果表示为 $f = \mathbf{w} \cdot \mathbf{x}$。更新权重时,有两种情况:
- 预测错误,即 $t \cdot f < 0$
- 预测正确但不够自信,即 $f < \theta$,$\theta$ 是一个事先设定的值
这两种情况可以合并为:只要 $t \cdot f < \theta$,就需要更新。这相当于有损失函数:
$$ L(\mathbf{w}) = \max(0, \theta - t \cdot f) $$
容易验证预测正确且自信时 $L=0$,而其他情况 $L = \theta - t \cdot (\mathbf{w} \cdot \mathbf{x})$,此时需要更新权重以最小化损失函数。
使用梯度下降法:
$$ \frac{\partial L}{\partial \mathbf{w}} = \frac{\partial (\theta - t \cdot \mathbf{w} \cdot \mathbf{x})}{\partial \mathbf{w}} = -t \cdot \mathbf{x} $$
设学习率 $\eta$ 为 1,代入上面的梯度:
$$ \mathbf{w}_\text{new} = \mathbf{w}_\text{old} - \eta \cdot (-t \cdot \mathbf{x}) $$ $$ \mathbf{w}_\text{new} = \mathbf{w}_\text{old} + t \cdot \mathbf{x} $$回顾 $\mathbf{x}$ 的含义,可以得出更新权重的方法:
- 对于 $w_0$,如果实际结果是 Taken,则 $+1$,否则 $-1$
- 对于其余 $w_i$,如果 $h_i$ 与当前分支结果相同则 $+1$,否则 $-1$
在实际使用时,可以制作一张表,索引为分支指令的地址,每个表内是一个 Perceptron。这实际上(大致)是 2019 年 AMD Zen2 仍在使用的预测器的结构,距离我们当前的时代相当近。
通过使用 Perceptron,我们将原先基于查找表的“每一个不同的 pattern 对应一个计数器”,转变为“每一个 pattern 长度对应一个 Perceptron”,使得表增长的速度由指数降低为线性,从而可以在使用相同资源(寄存器数量,RAM 等)的情况下,使用更长的分支历史信息进行分支预测,达到更好的精度。
目前,主流的现代处理器都使用一种集大成的分支预测器:TAGE 预测器。其主要思想是使用多张查找表,每张使用不同长度的分支历史信息,且长度几何增长。对每个表项维护一个 useful 计数器,反映不同长度历史信息发挥的作用。具体实现和维护方法较为复杂,在预测中也会使用到 Perceptron,不在本文中描述。
结语 #
尽管受限于实验平台资源数量和时间原因,我在大实验中只实现了略显古老的 Gshare 预测器,没有对 Perceptron 乃至 TAGE 做更多尝试(尝试了也没啥用),但通过了解分支预测器的发展历史和思想,我体验到了实现硬件机器学习的独特魅力。这一学期玩了一些很具代表性的可编程硬件,包括 FPGA 和 Tofino P4,也都尝试在上面实现了某种“预测”和“学习”算法,初步体会了硬件编程“因地制宜”的思想。之后如果有时间,可能会再做一些记录。
本文大量参考了知乎专栏:现代分支预测设计-早期分支预测器,该系列文章(2 篇)写得非常好,可作为扩展阅读。