跳到主要内容

PUCT 公式详解

在前一篇文章中,我们概述了 MCTS 与神经网络的结合。现在,让我们深入探讨 MCTS 选择阶段的核心——PUCT 公式

这个公式看似简单,却是 AlphaGo 成功的关键之一。它优雅地平衡了「利用已知的好走法」和「探索可能更好的走法」这两个看似矛盾的目标。


PUCT 公式

公式定义

AlphaGo 使用的 PUCT(Predictor Upper Confidence Trees) 公式:

a=argmaxa[Q(s,a)+U(s,a)]a^* = \arg\max_a \left[ Q(s,a) + U(s,a) \right]

其中:

U(s,a)=cpuctP(s,a)N(s)1+N(s,a)U(s,a) = c_{\text{puct}} \cdot P(s,a) \cdot \frac{\sqrt{N(s)}}{1 + N(s,a)}

完整展开:

a=argmaxa[Q(s,a)+cpuctP(s,a)N(s)1+N(s,a)]a^* = \arg\max_a \left[ Q(s,a) + c_{\text{puct}} \cdot P(s,a) \cdot \frac{\sqrt{N(s)}}{1 + N(s,a)} \right]

符号说明

符号意义来源
Q(s,a)Q(s,a)动作 aa 的平均价值MCTS 统计
P(s,a)P(s,a)动作 aa 的先验概率Policy Network
N(s)N(s)父节点的访问次数MCTS 统计
N(s,a)N(s,a)子节点的访问次数MCTS 统计
cpuctc_{\text{puct}}探索常数超参数

直观理解

PUCT 公式可以分为两部分:

总分数 = Q(s,a)        + U(s,a)
↓ ↓
利用项 探索项
「这步棋多好?」 「这步棋值得再探索吗?」

利用项 Q(s,a)

  • 这步棋过去的平均表现
  • 访问越多,估计越准确
  • 鼓励选择「已知好」的走法

探索项 U(s,a)

  • 这步棋还有多少探索价值
  • 访问越少,探索奖励越高
  • 鼓励尝试「可能好」的走法

各项的意义

Q(s,a):平均价值

Q(s,a)Q(s,a) 是从节点 (s,a)(s,a) 出发的所有模拟的平均结果:

Q(s,a)=W(s,a)N(s,a)=iziN(s,a)Q(s,a) = \frac{W(s,a)}{N(s,a)} = \frac{\sum_i z_i}{N(s,a)}

其中 zi{1,+1}z_i \in \{-1, +1\} 是第 ii 次模拟的结果。

特性

  • 范围:[1,+1][-1, +1]
  • 初始值:未定义(需要至少一次访问)
  • 随着访问次数增加趋于稳定

解读

  • Q=0.6Q = 0.6:这步棋的胜率约 80%(因为 Q=2×胜率1Q = 2 \times \text{胜率} - 1
  • Q=0Q = 0:胜负各半
  • Q=0.3Q = -0.3:这步棋的胜率约 35%

P(s,a):先验概率

P(s,a)P(s,a) 来自 Policy Network 的输出:

P(s,a)=πθ(as)P(s,a) = \pi_\theta(a|s)

特性

  • 范围:[0,1][0, 1],且 aP(s,a)=1\sum_a P(s,a) = 1
  • 在节点首次展开时计算
  • 反映神经网络对「这步棋有多好」的判断

作用

  • 高概率的动作会被优先探索
  • 即使访问次数为 0,也有探索动力
  • 引导搜索聚焦在「看起来合理」的走法

N(s) 与 N(s,a):访问次数

N(s)N(s) 是父节点的总访问次数:

N(s)=aN(s,a)N(s) = \sum_a N(s,a)

探索项中的角色

N(s)1+N(s,a)\frac{\sqrt{N(s)}}{1 + N(s,a)}

这个分数的行为:

  • N(s,a)=0N(s,a) = 0 时,分数 = N(s)\sqrt{N(s)}(最大探索动力)
  • 随着 N(s,a)N(s,a) 增加,分数下降
  • N(s,a)N(s)N(s,a) \gg \sqrt{N(s)} 时,分数趋近于 0

这确保了:

  1. 每个动作至少被探索一次(如果 P(s,a)>0P(s,a) > 0
  2. 探索动力随访问递减
  3. 最终选择由 QQ 值主导

c_puct:探索常数

cpuctc_{\text{puct}} 控制探索与利用的平衡:

cpuctc_{\text{puct}}效果
较小(如 0.5)更偏向利用,快速聚焦在好的走法
适中(如 1-2)平衡探索与利用
较大(如 5)更偏向探索,尝试更多可能性

AlphaGo 使用的值:cpuct=1.5c_{\text{puct}} = 1.5(根据论文)。


与 UCB1 的关系

UCB1 公式回顾

传统 MCTS 使用的 UCB1 公式:

UCB1(s,a)=Xˉs,a+clnN(s)N(s,a)\text{UCB1}(s,a) = \bar{X}_{s,a} + c \sqrt{\frac{\ln N(s)}{N(s,a)}}

其中 Xˉs,a\bar{X}_{s,a} 是平均回报。

比较

方面UCB1PUCT
利用项Xˉs,a\bar{X}_{s,a}(平均回报)Q(s,a)Q(s,a)(平均价值)
探索项lnNn\sqrt{\frac{\ln N}{n}}(置信界限)PN1+nP \cdot \frac{\sqrt{N}}{1+n}(先验引导)
先验信息使用 Policy Network
探索衰减对数衰减线性衰减

PUCT 的优势

  1. 利用先验知识:Policy Network 提供的 P(s,a)P(s,a) 让搜索从一开始就聚焦在合理的走法上

  2. 更快的收敛:线性衰减(1/(1+n)1/(1+n))比对数衰减(1/lnN/n1/\sqrt{\ln N / n})更快让搜索聚焦

  3. 可调控的探索P(s,a)P(s,a)cpuctc_{\text{puct}} 提供了更多调控探索的手段

理论背景

UCB1 有严格的理论保证(regret bound),但这些保证假设:

  • 每个臂(动作)是独立的
  • 没有先验信息

在围棋中,我们有强大的先验(Policy Network),PUCT 能更好地利用这些信息。


数学推导

从多臂赌博机说起

PUCT 的灵感来自多臂赌博机(Multi-Armed Bandit) 问题。

想象你面前有 KK 台老虎机,每台的获胜概率不同但未知。你的目标是最大化总获胜次数。策略是:

  • 利用:拉看起来最好的那台
  • 探索:尝试其他台,可能发现更好的

UCB1 是这个问题的经典解法,PUCT 是其变体。

UCB 的理论基础

对于随机变量 XX,由 Hoeffding 不等式:

P(Xˉnμϵ)2exp(2nϵ2)P(|\bar{X}_n - \mu| \geq \epsilon) \leq 2 \exp(-2n\epsilon^2)

如果我们想要以 1/t41/t^4 的概率犯错,需要:

ϵ=2lntn\epsilon = \sqrt{\frac{2 \ln t}{n}}

这就是 UCB1 探索项的来源。

PUCT 的修改

PUCT 对经典 UCB 做了几个修改:

1. 加入先验概率

U(s,a)P(s,a)(探索项)U(s,a) \propto P(s,a) \cdot (\text{探索项})

这让探索集中在高概率的动作上。

2. 改变探索项形式

lnNn\sqrt{\frac{\ln N}{n}} 改为 N1+n\frac{\sqrt{N}}{1+n}

这加速了收敛:

比较(假设 N = 1000, n = 10):

UCB1: sqrt(ln(1000) / 10) = sqrt(0.69) ≈ 0.83
PUCT: sqrt(1000) / 11 ≈ 2.87

PUCT 给予更多探索奖励,但衰减更快

3. 可学习的先验

P(s,a)P(s,a) 来自神经网络,会随着训练改进。这让 MCTS 和神经网络形成正向循环。

为什么这个形式有效?

直观解释:

U(s,a)=cpuctP(s,a)N(s)1+N(s,a)U(s,a) = c_{\text{puct}} \cdot P(s,a) \cdot \frac{\sqrt{N(s)}}{1 + N(s,a)}

  1. P(s,a)P(s,a):「专家说这步棋有多好」
  2. N(s)\sqrt{N(s)}:「我们对这个局面了解多少」
  3. 1/(1+N(s,a))1/(1 + N(s,a)):「我们对这步棋了解多少」

组合起来:当我们对局面了解很多,但对某步棋了解很少,且专家认为这步棋不错时,应该去探索它


可视化理解

探索项的变化

让我们可视化探索项如何随访问次数变化:

U(s,a) = c_puct × P(s,a) × √N(s) / (1 + N(s,a))

假设 P(s,a) = 0.1, c_puct = 1.5, N(s) = 1600

N(s,a) | U(s,a)
--------|----------
0 | 6.00 ← 未访问,最高探索奖励
1 | 3.00
5 | 1.00
10 | 0.55
50 | 0.12
100 | 0.06
400 | 0.015 ← 访问多次后,探索奖励很小

不同先验概率的影响

假设 c_puct = 1.5, N(s) = 1600, N(s,a) = 0

P(s,a) | U(s,a)
--------|----------
0.30 | 18.00 ← 高概率动作有更多探索动力
0.10 | 6.00
0.03 | 1.80
0.01 | 0.60
0.001 | 0.06 ← 低概率动作几乎不被探索

交互式探索

尝试调整下面的 cpuctc_{\text{puct}} 参数,观察它如何影响 MCTS 的选择:

載入中...

AlphaGo 中的具体实现

AlphaGo Fan/Lee 的实现

原版 AlphaGo 使用稍微不同的公式:

U(s,a)=cpuctP(s,a)bN(s,b)1+N(s,a)U(s,a) = c_{\text{puct}} \cdot P(s,a) \cdot \frac{\sqrt{\sum_b N(s,b)}}{1 + N(s,a)}

并且 Q(s,a)Q(s,a) 的计算考虑了虚拟损失:

def get_ucb_score(node, action, c_puct=1.5):
Q = node.W[action] / (node.N[action] + 1) # 避免除以零
P = node.prior[action]
N_parent = sum(node.N.values())
N_child = node.N[action]

U = c_puct * P * math.sqrt(N_parent) / (1 + N_child)

return Q + U

AlphaGo Zero 的实现

AlphaGo Zero 使用更简洁的实现:

def select_action(node, c_puct=1.5):
"""选择 PUCT 分数最高的动作"""
N_parent = sum(node.visit_count.values())

def puct_score(action):
Q = node.value_sum[action] / (node.visit_count[action] + 1)
P = node.prior[action]
U = c_puct * P * math.sqrt(N_parent) / (1 + node.visit_count[action])
return Q + U

return max(node.legal_actions, key=puct_score)

处理未访问节点

N(s,a)=0N(s,a) = 0 时,Q(s,a)Q(s,a) 未定义。常见处理方式:

方法 1:使用父节点价值

Q = parent_value if N[action] == 0 else W[action] / N[action]

方法 2:使用初始值

Q = 0 if N[action] == 0 else W[action] / N[action]

方法 3:使用 FPU(First Play Urgency)

# 未访问节点使用较低的 Q 值
fpu_value = parent_Q - fpu_reduction
Q = fpu_value if N[action] == 0 else W[action] / N[action]

AlphaGo Zero 使用 FPU,这让搜索更倾向于先尝试访问过的节点。


实际调参经验

c_puct 的选择

cpuctc_{\text{puct}} 是最重要的超参数。实践中的指导原则:

1. 与网络品质相关

  • 网络很强(高准确率):可以用较小的 cpuctc_{\text{puct}}
  • 网络较弱:需要较大的 cpuctc_{\text{puct}} 来修正错误

2. 与搜索预算相关

  • 模拟次数多:可以用较大的 cpuctc_{\text{puct}}(有足够时间探索)
  • 模拟次数少:用较小的 cpuctc_{\text{puct}}(快速聚焦)

3. 与游戏特性相关

  • 战术性强的游戏:可能需要更多探索
  • 战略性强的游戏:可以更信任先验

典型值

项目cpuctc_{\text{puct}}
AlphaGo1.5
AlphaGo Zero1.0 - 2.0
AlphaZero1.25
KataGo0.5 - 1.0(动态调整)
Leela Zero1.5 - 2.0

动态调整

一些进阶实现使用动态 cpuctc_{\text{puct}}

def dynamic_cpuct(visit_count):
"""根据访问次数调整探索常数"""
base = 1.0
init = 1.5
log_base = 19652 # 调整参数

return math.log((visit_count + log_base + 1) / log_base) + init

这让搜索在早期更偏向探索,后期更偏向利用。

敏感度分析

cpuctc_{\text{puct}} 对棋力的影响:

实验(固定其他条件,变化 c_puct):

c_puct | 相对 Elo
-------|----------
0.5 | -50 ← 过度利用,错过好棋
1.0 | +20
1.5 | 0 ← 基准
2.0 | -10
3.0 | -30 ← 过度探索,浪费搜索
5.0 | -80

最佳值通常在 1.0-2.0 之间,但具体取决于网络品质和搜索预算。


进阶变体

PUCT 的变体

1. Polynomial PUCT (P-UCT)

U(s,a)=cP(s,a)N(s)α1+N(s,a)U(s,a) = c \cdot P(s,a) \cdot \frac{N(s)^\alpha}{1 + N(s,a)}

其中 α\alpha 是可调参数(通常 α=0.5\alpha = 0.5)。

2. 带噪音的 PUCT

在根节点加入 Dirichlet 噪音:

P(s,a)=(1ε)P(s,a)+εηaP'(s,a) = (1-\varepsilon) P(s,a) + \varepsilon \cdot \eta_a

其中 ηDir(α)\eta \sim \text{Dir}(\alpha)。这增加了探索的多样性。

3. UCT-like PUCT

U(s,a)=cP(s,a)ln(1+N(s)+cbase)1+N(s,a)U(s,a) = c \cdot P(s,a) \cdot \sqrt{\frac{\ln(1 + N(s) + c_{\text{base}})}{1 + N(s,a)}}

这结合了 UCT 的对数形式和 PUCT 的先验引导。

KataGo 的改进

KataGo 对 PUCT 做了多项改进:

1. 动态 cpuctc_{\text{puct}} 根据局面复杂度和搜索进度调整。

2. 价值预测的不确定性 考虑 Value Network 的预测信心。

3. 政策目标学习 直接学习 MCTS 访问分布,而非仅策略头输出。

其他选择公式

除了 PUCT,还有其他选择公式:

RAVE(Rapid Action Value Estimation)

QRAVE(s,a)=(1β)Q(s,a)+βQAMAF(s,a)Q_{\text{RAVE}}(s,a) = (1-\beta) Q(s,a) + \beta Q_{\text{AMAF}}(s,a)

使用「All Moves As First」统计来加速学习。

GRAVE(Generalized RAVE)

RAVE 的变体,使用祖先节点的统计信息。


理论分析

收敛性

PUCT 是否保证收敛到最佳解?

严格来说:没有像 UCB1 那样的理论保证。

实践中:在足够多的模拟后,PUCT 会收敛到高品质的解,因为:

  1. 探索项最终会趋近于零
  2. 所有动作最终都会被访问
  3. QQ 值会收敛到真实价值

复杂度分析

时间复杂度(每次模拟):

  • Selection:O(logN)O(\log N)(树的深度)
  • Expansion:O(A)O(A)(合法动作数,需要神经网络推理)
  • Evaluation:O(1)O(1)(Value Network)或 O(T)O(T)(Rollout,TT 是游戏长度)
  • Backpropagation:O(logN)O(\log N)

空间复杂度

  • 每个节点:O(A)O(A)(存储先验和访问统计)
  • 整棵树:O(NA)O(N \cdot A)NN 是节点数)

与 Minimax 的关系

cpuct0c_{\text{puct}} \to 0 且模拟次数 \to \infty 时,MCTS + PUCT 会近似于 Minimax 搜索。

但在有限预算下,PUCT 通常比 Minimax + Alpha-Beta 更有效率,因为它能更好地利用先验知识。


常见问题

Q:为什么不直接用 Policy Network 的输出作为选择?

A:Policy Network 可能会犯错。MCTS 的搜索能够:

  1. 验证神经网络的判断
  2. 发现神经网络忽略的好棋
  3. 修正神经网络的系统性偏见

实验显示,即使神经网络很强,加入搜索仍能显著提升棋力。

Q:PUCT 在哪些情况下表现不好?

  1. 先验概率完全错误:如果 Policy Network 把好棋评为低概率,PUCT 需要很多模拟才能修正

  2. 长期战术:PUCT 可能难以发现需要精确计算的长序列战术

  3. 对手模型缺失:PUCT 假设对手也用最佳策略,面对不合理的对手可能不是最优

Q:如何处理超大的动作空间?

一些技术:

  1. Policy Network 过滤:只考虑 P(s,a)>ϵP(s,a) > \epsilon 的动作
  2. 渐进展宽:先只展开少量动作,需要时再扩展
  3. 动态剪枝:移除明显差的动作

动画对应

本文涉及的核心概念与动画编号:

编号概念物理/数学对应
🎬 E4探索与利用多臂赌博机
🎬 C3PUCT 选择置信界限

总结

PUCT 公式是 AlphaGo MCTS 的核心,我们学习了:

  1. 公式结构Q+UQ + U,利用项加探索项
  2. 各项意义QQ 是经验价值,PP 是先验概率,NN 控制探索衰减
  3. 与 UCB1 的关系:PUCT 加入了先验并使用不同的探索项形式
  4. 数学推导:从多臂赌博机到 MCTS 选择
  5. 实际调参cpuctc_{\text{puct}} 的选择与影响
  6. 进阶变体:动态调整、噪音、KataGo 改进

PUCT 的优雅之处在于它简单而有效——用一个公式就平衡了探索与利用,并优雅地整合了神经网络的先验知识。


延伸阅读


参考资料

  1. Rosin, C. D. (2011). "Multi-armed bandits with episode context." Annals of Mathematics and Artificial Intelligence, 61(3), 203-230.
  2. Silver, D., et al. (2016). "Mastering the game of Go with deep neural networks and tree search." Nature, 529, 484-489.
  3. Silver, D., et al. (2017). "Mastering the game of Go without human knowledge." Nature, 550, 354-359.
  4. Kocsis, L., & Szepesvári, C. (2006). "Bandit based Monte-Carlo Planning." ECML.
  5. Auer, P., Cesa-Bianchi, N., & Fischer, P. (2002). "Finite-time analysis of the multiarmed bandit problem." Machine Learning, 47(2), 235-256.
  6. Wu, D., et al. (2019). "Accelerating Self-Play Learning in Go." arXiv preprint (KataGo 技术报告).

附录:完整实现范例

import math
import numpy as np
from typing import Dict, List, Optional

class MCTSNode:
"""MCTS 节点"""
def __init__(self, prior: float = 0.0):
self.prior = prior # P(s,a)
self.visit_count = 0 # N(s,a)
self.value_sum = 0.0 # W(s,a)
self.children: Dict[int, 'MCTSNode'] = {}

@property
def q_value(self) -> float:
"""计算 Q(s,a)"""
if self.visit_count == 0:
return 0.0
return self.value_sum / self.visit_count


class MCTS:
"""MCTS 搜索器,使用 PUCT"""

def __init__(
self,
policy_network,
value_network,
c_puct: float = 1.5,
num_simulations: int = 800
):
self.policy_network = policy_network
self.value_network = value_network
self.c_puct = c_puct
self.num_simulations = num_simulations

def search(self, root_state) -> Dict[int, float]:
"""执行 MCTS 搜索,返回动作的访问分布"""
root = MCTSNode()

# 展开根节点
policy = self.policy_network(root_state)
for action, prob in enumerate(policy):
if is_legal(root_state, action):
root.children[action] = MCTSNode(prior=prob)

# 执行模拟
for _ in range(self.num_simulations):
self._simulate(root, root_state)

# 返回访问分布
total_visits = sum(
child.visit_count for child in root.children.values()
)
return {
action: child.visit_count / total_visits
for action, child in root.children.items()
}

def _simulate(self, node: MCTSNode, state) -> float:
"""执行单次模拟"""

# 如果是终局状态
if is_terminal(state):
return get_result(state)

# 如果是叶节点,展开并评估
if not node.children:
policy = self.policy_network(state)
value = self.value_network(state)

for action, prob in enumerate(policy):
if is_legal(state, action):
node.children[action] = MCTSNode(prior=prob)

return value

# Selection:选择 PUCT 分数最高的动作
action = self._select_action(node)
child = node.children[action]
next_state = apply_action(state, action)

# 递归模拟
value = -self._simulate(child, next_state)

# Backpropagation:更新统计
child.visit_count += 1
child.value_sum += value

return value

def _select_action(self, node: MCTSNode) -> int:
"""使用 PUCT 公式选择动作"""
total_visits = sum(
child.visit_count for child in node.children.values()
)

def puct_score(action: int, child: MCTSNode) -> float:
# Q(s,a):平均价值
q = child.q_value

# U(s,a):探索加成
u = (
self.c_puct
* child.prior
* math.sqrt(total_visits)
/ (1 + child.visit_count)
)

return q + u

return max(
node.children.keys(),
key=lambda a: puct_score(a, node.children[a])
)


# 使用范例
def play_game():
policy_net = PolicyNetwork()
value_net = ValueNetwork()

mcts = MCTS(
policy_network=policy_net,
value_network=value_net,
c_puct=1.5,
num_simulations=1600
)

state = initial_state()

while not is_terminal(state):
# 执行 MCTS 搜索
visit_distribution = mcts.search(state)

# 选择访问次数最多的动作
action = max(visit_distribution.keys(),
key=lambda a: visit_distribution[a])

# 执行动作
state = apply_action(state, action)
print(f"Selected action {action} with visit ratio "
f"{visit_distribution[action]:.2%}")

print(f"Game result: {get_result(state)}")

这个实现展示了 PUCT 公式在 MCTS 中的核心角色。实际的 AlphaGo 实现还包括:

  • 并行搜索(虚拟损失)
  • 批次神经网络评估
  • 树的重用
  • 狄利克雷噪音
  • 温度控制等功能