跳至主要内容

PUCT 公式詳解

喺上一篇文章入面,我哋概述咗 MCTS 與神經網絡嘅結合。而家,等我哋深入探討 MCTS 揀選階段嘅核心——PUCT 公式

呢條公式睇落簡單,卻係 AlphaGo 成功嘅關鍵之一。佢優雅噉平衡咗「利用已知嘅好著法」同「探索可能更好嘅著法」呢兩個睇落矛盾嘅目標。


PUCT 公式

公式定義

AlphaGo 使用嘅 PUCT(Predictor Upper Confidence Trees) 公式:

a=argmaxa[Q(s,a)+U(s,a)]a^* = \arg\max_a \left[ Q(s,a) + U(s,a) \right]

其中:

U(s,a)=cpuctP(s,a)N(s)1+N(s,a)U(s,a) = c_{\text{puct}} \cdot P(s,a) \cdot \frac{\sqrt{N(s)}}{1 + N(s,a)}

完整展開:

a=argmaxa[Q(s,a)+cpuctP(s,a)N(s)1+N(s,a)]a^* = \arg\max_a \left[ Q(s,a) + c_{\text{puct}} \cdot P(s,a) \cdot \frac{\sqrt{N(s)}}{1 + N(s,a)} \right]

符號說明

符號意義來源
Q(s,a)Q(s,a)動作 aa 嘅平均價值MCTS 統計
P(s,a)P(s,a)動作 aa 嘅先驗機率Policy Network
N(s)N(s)父節點嘅訪問次數MCTS 統計
N(s,a)N(s,a)子節點嘅訪問次數MCTS 統計
cpuctc_{\text{puct}}探索常數超參數

直觀理解

PUCT 公式可以分為兩部分:

總分數 = Q(s,a)        + U(s,a)
↓ ↓
利用項 探索項
「呢步棋幾好?」 「呢步棋值得再探索嗎?」

利用項 Q(s,a)

  • 呢步棋過去嘅平均表現
  • 訪問越多,估計越準確
  • 鼓勵揀選「已知好」嘅著法

探索項 U(s,a)

  • 呢步棋仲有幾多探索價值
  • 訪問越少,探索獎勵越高
  • 鼓勵嘗試「可能好」嘅著法

各項嘅意義

Q(s,a):平均價值

Q(s,a)Q(s,a) 係由節點 (s,a)(s,a) 出發嘅所有模擬嘅平均結果:

Q(s,a)=W(s,a)N(s,a)=iziN(s,a)Q(s,a) = \frac{W(s,a)}{N(s,a)} = \frac{\sum_i z_i}{N(s,a)}

其中 zi{1,+1}z_i \in \{-1, +1\} 係第 ii 次模擬嘅結果。

特性

  • 範圍:[1,+1][-1, +1]
  • 初始值:未定義(需要至少一次訪問)
  • 隨住訪問次數增加趨於穩定

解讀

  • Q=0.6Q = 0.6:呢步棋嘅勝率約 80%(因為 Q=2×勝率1Q = 2 \times \text{勝率} - 1
  • Q=0Q = 0:勝負各半
  • Q=0.3Q = -0.3:呢步棋嘅勝率約 35%

P(s,a):先驗機率

P(s,a)P(s,a) 嚟自 Policy Network 嘅輸出:

P(s,a)=πθ(as)P(s,a) = \pi_\theta(a|s)

特性

  • 範圍:[0,1][0, 1],且 aP(s,a)=1\sum_a P(s,a) = 1
  • 喺節點首次展開嗰陣計算
  • 反映神經網絡對「呢步棋有幾好」嘅判斷

作用

  • 高機率嘅動作會被優先探索
  • 即使訪問次數為 0,都有探索動力
  • 引導搜索聚焦喺「睇落合理」嘅著法

N(s) 與 N(s,a):訪問次數

N(s)N(s) 係父節點嘅總訪問次數:

N(s)=aN(s,a)N(s) = \sum_a N(s,a)

探索項入面嘅角色

N(s)1+N(s,a)\frac{\sqrt{N(s)}}{1 + N(s,a)}

呢個分數嘅行為:

  • N(s,a)=0N(s,a) = 0 嗰陣,分數 = N(s)\sqrt{N(s)}(最大探索動力)
  • 隨住 N(s,a)N(s,a) 增加,分數下降
  • N(s,a)N(s)N(s,a) \gg \sqrt{N(s)} 嗰陣,分數趨近於 0

呢個確保咗:

  1. 每個動作至少被探索一次(如果 P(s,a)>0P(s,a) > 0
  2. 探索動力隨訪問遞減
  3. 最終揀選由 QQ 值主導

c_puct:探索常數

cpuctc_{\text{puct}} 控制探索同利用嘅平衡:

cpuctc_{\text{puct}}效果
較細(如 0.5)更偏向利用,快速聚焦喺好嘅著法
適中(如 1-2)平衡探索同利用
較大(如 5)更偏向探索,嘗試更多可能性

AlphaGo 使用嘅值:cpuct=1.5c_{\text{puct}} = 1.5(根據論文)。


與 UCB1 嘅關係

UCB1 公式回顧

傳統 MCTS 使用嘅 UCB1 公式:

UCB1(s,a)=Xˉs,a+clnN(s)N(s,a)\text{UCB1}(s,a) = \bar{X}_{s,a} + c \sqrt{\frac{\ln N(s)}{N(s,a)}}

其中 Xˉs,a\bar{X}_{s,a} 係平均回報。

比較

方面UCB1PUCT
利用項Xˉs,a\bar{X}_{s,a}(平均回報)Q(s,a)Q(s,a)(平均價值)
探索項lnNn\sqrt{\frac{\ln N}{n}}(信賴界限)PN1+nP \cdot \frac{\sqrt{N}}{1+n}(先驗引導)
先驗資訊使用 Policy Network
探索衰減對數衰減線性衰減

PUCT 嘅優勢

  1. 利用先驗知識:Policy Network 提供嘅 P(s,a)P(s,a) 令搜索一開始就聚焦喺合理嘅著法上面

  2. 更快嘅收斂:線性衰減(1/(1+n)1/(1+n))比對數衰減(1/lnN/n1/\sqrt{\ln N / n})更快令搜索聚焦

  3. 可調控嘅探索P(s,a)P(s,a)cpuctc_{\text{puct}} 提供咗更多調控探索嘅手段

理論背景

UCB1 有嚴格嘅理論保證(regret bound),但呢啲保證假設:

  • 每個臂(動作)係獨立嘅
  • 無先驗資訊

喺圍棋入面,我哋有強大嘅先驗(Policy Network),PUCT 能更好噉利用呢啲資訊。


數學推導

由多臂賭博機講起

PUCT 嘅靈感嚟自多臂賭博機(Multi-Armed Bandit) 問題。

想像你面前有 KK 部老虎機,每部嘅獲勝機率唔同但未知。你嘅目標係最大化總獲勝次數。策略係:

  • 利用:拉睇落最好嗰部
  • 探索:試其他部,可能發現更好嘅

UCB1 係呢個問題嘅經典解法,PUCT 係佢嘅變體。

UCB 嘅理論基礎

對於隨機變數 XX,由 Hoeffding 不等式:

P(Xˉnμϵ)2exp(2nϵ2)P(|\bar{X}_n - \mu| \geq \epsilon) \leq 2 \exp(-2n\epsilon^2)

如果我哋想要以 1/t41/t^4 嘅機率犯錯,需要:

ϵ=2lntn\epsilon = \sqrt{\frac{2 \ln t}{n}}

呢個就係 UCB1 探索項嘅來源。

PUCT 嘅修改

PUCT 對經典 UCB 做咗幾個修改:

1. 加入先驗機率

U(s,a)P(s,a)(探索項)U(s,a) \propto P(s,a) \cdot (\text{探索項})

呢個令探索集中喺高機率嘅動作上面。

2. 改變探索項形式

lnNn\sqrt{\frac{\ln N}{n}} 改為 N1+n\frac{\sqrt{N}}{1+n}

呢個加速咗收斂:

比較(假設 N = 1000, n = 10):

UCB1: sqrt(ln(1000) / 10) = sqrt(0.69) ≈ 0.83
PUCT: sqrt(1000) / 11 ≈ 2.87

PUCT 畀更多探索獎勵,但衰減更快

3. 可學習嘅先驗

P(s,a)P(s,a) 嚟自神經網絡,會隨住訓練改進。呢個令 MCTS 同神經網絡形成正向循環。

點解呢個形式有效?

直觀解釋:

U(s,a)=cpuctP(s,a)N(s)1+N(s,a)U(s,a) = c_{\text{puct}} \cdot P(s,a) \cdot \frac{\sqrt{N(s)}}{1 + N(s,a)}

  1. P(s,a)P(s,a):「專家話呢步棋有幾好」
  2. N(s)\sqrt{N(s)}:「我哋對呢個局面了解幾多」
  3. 1/(1+N(s,a))1/(1 + N(s,a)):「我哋對呢步棋了解幾多」

組合起嚟:當我哋對局面了解好多,但對某步棋了解好少,且專家認為呢步棋唔錯嗰陣,應該去探索佢


視覺化理解

探索項嘅變化

等我哋視覺化探索項點樣隨訪問次數變化:

U(s,a) = c_puct × P(s,a) × √N(s) / (1 + N(s,a))

假設 P(s,a) = 0.1, c_puct = 1.5, N(s) = 1600

N(s,a) | U(s,a)
--------|----------
0 | 6.00 ← 未訪問,最高探索獎勵
1 | 3.00
5 | 1.00
10 | 0.55
50 | 0.12
100 | 0.06
400 | 0.015 ← 訪問多次後,探索獎勵好細

唔同先驗機率嘅影響

假設 c_puct = 1.5, N(s) = 1600, N(s,a) = 0

P(s,a) | U(s,a)
--------|----------
0.30 | 18.00 ← 高機率動作有更多探索動力
0.10 | 6.00
0.03 | 1.80
0.01 | 0.60
0.001 | 0.06 ← 低機率動作幾乎唔被探索

互動式探索

嘗試調整下面嘅 cpuctc_{\text{puct}} 參數,觀察佢點樣影響 MCTS 嘅揀選:

載入中...

AlphaGo 入面嘅具體實現

AlphaGo Fan/Lee 嘅實現

原版 AlphaGo 使用稍微唔同嘅公式:

U(s,a)=cpuctP(s,a)bN(s,b)1+N(s,a)U(s,a) = c_{\text{puct}} \cdot P(s,a) \cdot \frac{\sqrt{\sum_b N(s,b)}}{1 + N(s,a)}

並且 Q(s,a)Q(s,a) 嘅計算考慮咗虛擬損失:

def get_ucb_score(node, action, c_puct=1.5):
Q = node.W[action] / (node.N[action] + 1) # 避免除以零
P = node.prior[action]
N_parent = sum(node.N.values())
N_child = node.N[action]

U = c_puct * P * math.sqrt(N_parent) / (1 + N_child)

return Q + U

AlphaGo Zero 嘅實現

AlphaGo Zero 使用更簡潔嘅實現:

def select_action(node, c_puct=1.5):
"""揀選 PUCT 分數最高嘅動作"""
N_parent = sum(node.visit_count.values())

def puct_score(action):
Q = node.value_sum[action] / (node.visit_count[action] + 1)
P = node.prior[action]
U = c_puct * P * math.sqrt(N_parent) / (1 + node.visit_count[action])
return Q + U

return max(node.legal_actions, key=puct_score)

處理未訪問節點

N(s,a)=0N(s,a) = 0 嗰陣,Q(s,a)Q(s,a) 未定義。常見處理方式:

方法 1:使用父節點價值

Q = parent_value if N[action] == 0 else W[action] / N[action]

方法 2:使用初始值

Q = 0 if N[action] == 0 else W[action] / N[action]

方法 3:使用 FPU(First Play Urgency)

# 未訪問節點使用較低嘅 Q 值
fpu_value = parent_Q - fpu_reduction
Q = fpu_value if N[action] == 0 else W[action] / N[action]

AlphaGo Zero 使用 FPU,呢個令搜索更傾向於先嘗試訪問過嘅節點。


實際調參經驗

c_puct 嘅選擇

cpuctc_{\text{puct}} 係最重要嘅超參數。實踐入面嘅指導原則:

1. 同網絡質素相關

  • 網絡好強(高準確率):可以用較細嘅 cpuctc_{\text{puct}}
  • 網絡較弱:需要較大嘅 cpuctc_{\text{puct}} 嚟修正錯誤

2. 同搜索預算相關

  • 模擬次數多:可以用較大嘅 cpuctc_{\text{puct}}(有足夠時間探索)
  • 模擬次數少:用較細嘅 cpuctc_{\text{puct}}(快速聚焦)

3. 同遊戲特性相關

  • 戰術性強嘅遊戲:可能需要更多探索
  • 戰略性強嘅遊戲:可以更信任先驗

典型值

專案cpuctc_{\text{puct}}
AlphaGo1.5
AlphaGo Zero1.0 - 2.0
AlphaZero1.25
KataGo0.5 - 1.0(動態調整)
Leela Zero1.5 - 2.0

動態調整

一啲進階實現使用動態 cpuctc_{\text{puct}}

def dynamic_cpuct(visit_count):
"""根據訪問次數調整探索常數"""
base = 1.0
init = 1.5
log_base = 19652 # 調整參數

return math.log((visit_count + log_base + 1) / log_base) + init

呢個令搜索喺早期更偏向探索,後期更偏向利用。

敏感度分析

cpuctc_{\text{puct}} 對棋力嘅影響:

實驗(固定其他條件,變化 c_puct):

c_puct | 相對 Elo
-------|----------
0.5 | -50 ← 過度利用,錯過好棋
1.0 | +20
1.5 | 0 ← 基準
2.0 | -10
3.0 | -30 ← 過度探索,嘥搜索
5.0 | -80

最佳值通常喺 1.0-2.0 之間,但具體取決於網絡質素同搜索預算。


進階變體

PUCT 嘅變體

1. Polynomial PUCT (P-UCT)

U(s,a)=cP(s,a)N(s)α1+N(s,a)U(s,a) = c \cdot P(s,a) \cdot \frac{N(s)^\alpha}{1 + N(s,a)}

其中 α\alpha 係可調參數(通常 α=0.5\alpha = 0.5)。

2. 帶噪音嘅 PUCT

喺根節點加入 Dirichlet 噪音:

P(s,a)=(1ε)P(s,a)+εηaP'(s,a) = (1-\varepsilon) P(s,a) + \varepsilon \cdot \eta_a

其中 ηDir(α)\eta \sim \text{Dir}(\alpha)。呢個增加咗探索嘅多樣性。

3. UCT-like PUCT

U(s,a)=cP(s,a)ln(1+N(s)+cbase)1+N(s,a)U(s,a) = c \cdot P(s,a) \cdot \sqrt{\frac{\ln(1 + N(s) + c_{\text{base}})}{1 + N(s,a)}}

呢個結合咗 UCT 嘅對數形式同 PUCT 嘅先驗引導。

KataGo 嘅改進

KataGo 對 PUCT 做咗多項改進:

1. 動態 cpuctc_{\text{puct}} 根據局面複雜度同搜索進度調整。

2. 價值預測嘅不確定性 考慮 Value Network 嘅預測信心。

3. 政策目標學習 直接學習 MCTS 訪問分佈,而唔係淨係策略頭輸出。

其他揀選公式

除咗 PUCT,仲有其他揀選公式:

RAVE(Rapid Action Value Estimation)

QRAVE(s,a)=(1β)Q(s,a)+βQAMAF(s,a)Q_{\text{RAVE}}(s,a) = (1-\beta) Q(s,a) + \beta Q_{\text{AMAF}}(s,a)

使用「All Moves As First」統計嚟加速學習。

GRAVE(Generalized RAVE)

RAVE 嘅變體,使用祖先節點嘅統計資訊。


理論分析

收斂性

PUCT 係咪保證收斂到最佳解?

嚴格嚟講:無好似 UCB1 咁嘅理論保證。

實踐入面:喺足夠多嘅模擬之後,PUCT 會收斂到高質素嘅解,因為:

  1. 探索項最終會趨近於零
  2. 所有動作最終都會被訪問
  3. QQ 值會收斂到真實價值

複雜度分析

時間複雜度(每次模擬):

  • Selection:O(logN)O(\log N)(樹嘅深度)
  • Expansion:O(A)O(A)(合法動作數,需要神經網絡推理)
  • Evaluation:O(1)O(1)(Value Network)或 O(T)O(T)(Rollout,TT 係遊戲長度)
  • Backpropagation:O(logN)O(\log N)

空間複雜度

  • 每個節點:O(A)O(A)(儲存先驗同訪問統計)
  • 成棵樹:O(NA)O(N \cdot A)NN 係節點數)

與 Minimax 嘅關係

cpuct0c_{\text{puct}} \to 0 且模擬次數 \to \infty 嗰陣,MCTS + PUCT 會近似於 Minimax 搜索。

但喺有限預算下,PUCT 通常比 Minimax + Alpha-Beta 更有效率,因為佢能更好噉利用先驗知識。


常見問題

Q:點解唔直接用 Policy Network 嘅輸出做揀選?

A:Policy Network 可能會出錯。MCTS 嘅搜索能夠:

  1. 驗證神經網絡嘅判斷
  2. 發現神經網絡忽略嘅好棋
  3. 修正神經網絡嘅系統性偏見

實驗顯示,即使神經網絡好強,加入搜索仍能顯著提升棋力。

Q:PUCT 喺邊啲情況下表現唔好?

  1. 先驗機率完全錯誤:如果 Policy Network 將好棋評為低機率,PUCT 需要好多模擬先能修正

  2. 長期戰術:PUCT 可能難以發現需要精確計算嘅長序列戰術

  3. 對手模型缺失:PUCT 假設對手都用最佳策略,面對唔合理嘅對手可能唔係最優

Q:點樣處理超大嘅動作空間?

一啲技術:

  1. Policy Network 過濾:只考慮 P(s,a)>ϵP(s,a) > \epsilon 嘅動作
  2. 漸進展寬:先只展開少量動作,需要嗰陣先擴展
  3. 動態剪枝:移除明顯差嘅動作

動畫對應

本文涉及嘅核心概念與動畫編號:

編號概念物理/數學對應
🎬 E4探索與利用多臂賭博機
🎬 C3PUCT 揀選信賴界限

總結

PUCT 公式係 AlphaGo MCTS 嘅核心,我哋學習咗:

  1. 公式結構Q+UQ + U,利用項加探索項
  2. 各項意義QQ 係經驗價值,PP 係先驗機率,NN 控制探索衰減
  3. 與 UCB1 嘅關係:PUCT 加入咗先驗並使用唔同嘅探索項形式
  4. 數學推導:由多臂賭博機到 MCTS 揀選
  5. 實際調參cpuctc_{\text{puct}} 嘅選擇同影響
  6. 進階變體:動態調整、噪音、KataGo 改進

PUCT 嘅優雅之處在於佢簡單而有效——用一條公式就平衡咗探索同利用,並優雅噉整合咗神經網絡嘅先驗知識。


延伸閱讀


參考資料

  1. Rosin, C. D. (2011). "Multi-armed bandits with episode context." Annals of Mathematics and Artificial Intelligence, 61(3), 203-230.
  2. Silver, D., et al. (2016). "Mastering the game of Go with deep neural networks and tree search." Nature, 529, 484-489.
  3. Silver, D., et al. (2017). "Mastering the game of Go without human knowledge." Nature, 550, 354-359.
  4. Kocsis, L., & Szepesvári, C. (2006). "Bandit based Monte-Carlo Planning." ECML.
  5. Auer, P., Cesa-Bianchi, N., & Fischer, P. (2002). "Finite-time analysis of the multiarmed bandit problem." Machine Learning, 47(2), 235-256.
  6. Wu, D., et al. (2019). "Accelerating Self-Play Learning in Go." arXiv preprint (KataGo 技術報告).

附錄:完整實現範例

import math
import numpy as np
from typing import Dict, List, Optional

class MCTSNode:
"""MCTS 節點"""
def __init__(self, prior: float = 0.0):
self.prior = prior # P(s,a)
self.visit_count = 0 # N(s,a)
self.value_sum = 0.0 # W(s,a)
self.children: Dict[int, 'MCTSNode'] = {}

@property
def q_value(self) -> float:
"""計算 Q(s,a)"""
if self.visit_count == 0:
return 0.0
return self.value_sum / self.visit_count


class MCTS:
"""MCTS 搜索器,使用 PUCT"""

def __init__(
self,
policy_network,
value_network,
c_puct: float = 1.5,
num_simulations: int = 800
):
self.policy_network = policy_network
self.value_network = value_network
self.c_puct = c_puct
self.num_simulations = num_simulations

def search(self, root_state) -> Dict[int, float]:
"""執行 MCTS 搜索,返回動作嘅訪問分佈"""
root = MCTSNode()

# 展開根節點
policy = self.policy_network(root_state)
for action, prob in enumerate(policy):
if is_legal(root_state, action):
root.children[action] = MCTSNode(prior=prob)

# 執行模擬
for _ in range(self.num_simulations):
self._simulate(root, root_state)

# 返回訪問分佈
total_visits = sum(
child.visit_count for child in root.children.values()
)
return {
action: child.visit_count / total_visits
for action, child in root.children.items()
}

def _simulate(self, node: MCTSNode, state) -> float:
"""執行單次模擬"""

# 如果係終局狀態
if is_terminal(state):
return get_result(state)

# 如果係葉節點,展開並評估
if not node.children:
policy = self.policy_network(state)
value = self.value_network(state)

for action, prob in enumerate(policy):
if is_legal(state, action):
node.children[action] = MCTSNode(prior=prob)

return value

# Selection:揀選 PUCT 分數最高嘅動作
action = self._select_action(node)
child = node.children[action]
next_state = apply_action(state, action)

# 遞迴模擬
value = -self._simulate(child, next_state)

# Backpropagation:更新統計
child.visit_count += 1
child.value_sum += value

return value

def _select_action(self, node: MCTSNode) -> int:
"""使用 PUCT 公式揀選動作"""
total_visits = sum(
child.visit_count for child in node.children.values()
)

def puct_score(action: int, child: MCTSNode) -> float:
# Q(s,a):平均價值
q = child.q_value

# U(s,a):探索加成
u = (
self.c_puct
* child.prior
* math.sqrt(total_visits)
/ (1 + child.visit_count)
)

return q + u

return max(
node.children.keys(),
key=lambda a: puct_score(a, node.children[a])
)


# 使用範例
def play_game():
policy_net = PolicyNetwork()
value_net = ValueNetwork()

mcts = MCTS(
policy_network=policy_net,
value_network=value_net,
c_puct=1.5,
num_simulations=1600
)

state = initial_state()

while not is_terminal(state):
# 執行 MCTS 搜索
visit_distribution = mcts.search(state)

# 揀選訪問次數最多嘅動作
action = max(visit_distribution.keys(),
key=lambda a: visit_distribution[a])

# 執行動作
state = apply_action(state, action)
print(f"Selected action {action} with visit ratio "
f"{visit_distribution[action]:.2%}")

print(f"Game result: {get_result(state)}")

呢個實現展示咗 PUCT 公式喺 MCTS 入面嘅核心角色。實際嘅 AlphaGo 實現仲包括:

  • 並行搜索(虛擬損失)
  • 批次神經網絡評估
  • 樹嘅重用
  • 狄利克雷噪音
  • 溫度控制等功能