跳至主要内容

Policy Network 詳解

喺圍棋嘅任一局面,合法嘅下法平均有 250 種。如果畀電腦隨機揀,佢永遠都唔可能落出好棋。

AlphaGo 嘅突破在於:佢學識咗「望一眼棋盤,就知道邊啲位置值得考慮」。

呢個能力,嚟自 Policy Network(策略網絡)


咩係 Policy Network?

核心功能

Policy Network 係一個深度卷積神經網絡,佢嘅任務係:

俾定當前棋盤狀態,輸出每個位置嘅落子概率

用數學表示:

p = f_θ(s)

其中:

  • s:當前棋盤狀態(19×19 嘅棋盤 + 其他特徵)
  • f_θ:Policy Network(θ 係網絡參數)
  • p:361 個位置嘅概率分布(包含 pass)

直覺理解

想像你係一位職業棋手。當你睇到一個局面,你嘅大腦會自動「亮起」幾個重要嘅位置——呢啲係你直覺認為值得考慮嘅點。

Policy Network 就係喺模擬呢個過程。

載入中...

上面嘅熱力圖顯示咗 Policy Network 嘅輸出。顏色越光嘅位置,模型認為越值得落子。

點解需要 Policy Network?

圍棋嘅搜索空間太大喇。如果唔加篩選咁搜索所有可能嘅走法:

策略每步考慮嘅着法搜索 10 步嘅節點數
全部考慮361361^10 ≈ 10^25
Policy Network 篩選~2020^10 ≈ 10^13

Policy Network 將搜索空間縮細咗 10^12 倍(一萬億倍)。


網絡架構

整體結構

AlphaGo 嘅 Policy Network 採用深度卷積神經網絡(CNN)架構:

輸入層 → 卷積層 ×12 → 輸出卷積層 → Softmax
↓ ↓ ↓ ↓
19×19×48 19×19×192 19×19×1 362 個概率

輸入層

輸入係 19×19×48 嘅特徵張量:

呢 48 個平面包含:

  • 黑子位置、白子位置
  • 近 8 手嘅歷史
  • 氣數、叫吃、征子等特徵
  • 合法性(邊啲位置可以落子)

卷積層

網絡包含 12 層卷積層,每層嘅配置:

參數數值說明
濾波器數量192每層輸出 192 個特徵圖
卷積核大小3×3(第一層 5×5)每次睇 3×3 嘅區域
填充方式same保持 19×19 嘅尺寸
激活函數ReLUmax(0, x)

點解係 192 個濾波器?

呢個係一個經驗值。太少會限制模型容量,太多會增加計算量同過擬合風險。DeepMind 團隊通過實驗確定 192 係一個好嘅平衡點。

點解係 3×3 卷積核?

3×3 係卷積神經網絡入面最常用嘅尺寸,原因:

  1. 足夠捕捉局部模式:圍棋入面嘅眼位、接、斷等都喺 3×3 範圍內
  2. 計算效率高:相比大卷積核,3×3 參數更少
  3. 可堆疊:多層 3×3 卷積可以達到大感受野嘅效果

第一層點解用 5×5?

第一層使用較大嘅 5×5 卷積核,係為咗喺輸入層就捕捉稍大範圍嘅模式(如小飛、跳)。呢個係一個設計選擇,之後嘅 AlphaGo Zero 統一使用 3×3。

ReLU 激活函數

每個卷積層後接 ReLU(Rectified Linear Unit)激活函數:

ReLU(x) = max(0, x)

點解用 ReLU?

  1. 計算簡單:只係攞最大值,比 sigmoid 快好多
  2. 緩解梯度消失:正區間梯度恆為 1
  3. 稀疏激活:負值被歸零,產生稀疏表示

輸出層

最後一層係特殊嘅卷積層:

19×19×192 → 卷積(1×1, 1個濾波器) → 19×19×1 → 展平 → 362維向量 → Softmax

1×1 卷積

輸出層使用 1×1 卷積,將 192 個通道壓縮為 1 個。呢個等價於對每個位置嘅 192 維特徵做線性組合。

Softmax 輸出

362 維向量(361 個棋盤位置 + 1 個 pass)經過 Softmax 函數:

Softmax(z_i) = exp(z_i) / Σ_j exp(z_j)

Softmax 確保輸出係合法嘅概率分布:

  • 所有值喺 0 到 1 之間
  • 所有值嘅和為 1

參數數量

等我哋計算網絡嘅總參數量:

計算參數數量
第一卷積層5×5×48×192 + 192230,592
中間卷積層 ×11(3×3×192×192 + 192) × 113,633,792
輸出卷積層1×1×192×1 + 1193
總計~3.9M

大約 390 萬個參數,以今日嘅標準嚟睇係一個細型網絡。


訓練目標同方法

訓練資料

Policy Network 使用監督學習,由人類棋譜入面學習。

資料來源:

  • KGS Go Server:業餘同職業棋手嘅對局
  • 大約 3000 萬局面:由 16 萬局對局入面取樣
  • 標籤:每個局面對應嘅人類下一步

交叉熵損失函數

訓練目標係最大化預測人類着法嘅概率。用交叉熵損失函數:

L(θ) = -Σ log p_θ(a | s)

其中:

  • s:棋盤狀態
  • a:人類實際落子嘅位置
  • p_θ(a | s):模型預測該位置嘅概率

直覺理解

交叉熵損失有一個簡單嘅含義:

當模型預測正確位置嘅概率越高,損失越低

如果人類落喺 K10,而模型俾 K10 嘅概率係:

  • 0.9 → 損失 = -log(0.9) ≈ 0.1(好低,正)
  • 0.1 → 損失 = -log(0.1) ≈ 2.3(好高,差)
  • 0.01 → 損失 = -log(0.01) ≈ 4.6(非常高,好差)

訓練過程

# 偽代碼
for epoch in range(num_epochs):
for batch in dataloader:
states, actions = batch

# 前向傳播
policy = network(states) # 361 維概率向量

# 計算損失(交叉熵)
loss = cross_entropy(policy, actions)

# 反向傳播
loss.backward()
optimizer.step()

訓練細節:

  • 優化器:SGD with momentum
  • 學習率:初始 0.003,逐步衰減
  • 批次大小:16
  • 訓練時間:大約 3 星期(50 GPUs)

資料增強

圍棋棋盤有 8 重對稱性(4 個旋轉 × 2 個鏡像)。每個訓練樣本可以變換為 8 個等價樣本:

原始 → 旋轉90° → 旋轉180° → 旋轉270°
↓ ↓ ↓ ↓
水平翻轉 → ...

咁樣有效訓練資料增加 8 倍,並確保模型學到嘅模式唔依賴方向。


訓練結果

57% 準確率

經過訓練,Policy Network 達到咗 57% 嘅 top-1 準確率

呢個意味住:俾定任意局面,模型有 57% 嘅機會預測出人類專家實際落嘅嗰一步。

呢個準確率高咩?

考慮到每個局面平均有 250 個合法着法,隨機估嘅準確率只有 0.4%。

方法Top-1 準確率
隨機估0.4%
之前最強嘅電腦圍棋~44%
AlphaGo Policy Network57%

提升 13 個百分點,睇落唔多,但意義重大。

棋力提升

純粹使用 Policy Network(唔加搜索)落棋,可以達到咩棋力?

配置Elo 評分大約等級
之前最強程式(Pachi)2,500業餘 4-5 段
Policy Network alone2,800業餘 6-7 段
+ MCTS 1600 simulations3,200+職業水平

單獨嘅 Policy Network 已經係業餘高段,加上 MCTS 之後更加躍升到職業水平。

點解只有 57%?

人類棋譜存在以下特性,限制咗準確率:

1. 多個好棋

好多局面有多步都係好棋。例如「掛角」同「守角」可能都係正確選擇。模型揀咗另一步好棋,會被算作「錯誤」。

2. 風格差異

唔同棋手有唔同風格。激進型棋手同穩健型棋手喺同一局面可能落唔同嘅棋。模型學到嘅係「平均」嘅風格。

3. 人類都會出錯

KGS 資料包含業餘棋手嘅對局,佢哋嘅選擇唔一定係最佳嘅。模型學到一啲「錯誤」係正常嘅。


喺 MCTS 入面嘅作用

Policy Network 喺 AlphaGo 嘅 MCTS 入面扮演兩個關鍵角色:

1. 引導搜索方向

喺 MCTS 嘅 Selection 階段,Policy Network 嘅輸出用嚟計算 UCB(Upper Confidence Bound):

UCB(s, a) = Q(s, a) + c_puct × P(s, a) × √(N(s)) / (1 + N(s, a))

其中 P(s, a) 就係 Policy Network 俾出嘅概率。

呢個意味住:

  • 高概率嘅着法會被優先探索
  • 低概率嘅着法都有機會被探索(因為有探索項)

2. 擴展節點嘅先驗

當 MCTS 擴展一個新節點嘅時候,Policy Network 提供所有子節點嘅先驗概率

展開節點 s:
for each action a:
child = Node()
child.prior = policy_network(s)[a] # 先驗概率
child.value = 0
child.visits = 0

呢啲先驗概率令 MCTS 「知道」邊啲子節點更值得探索,即使佢哋仲未被訪問過。


輕量版 vs 完整版

AlphaGo 實際上有兩個 Policy Network:

完整版(SL Policy Network)

  • 架構:13 層 CNN,192 filters
  • 準確率:57%
  • 推理時間:大約 3 毫秒/局面
  • 用途:MCTS 入面嘅 Selection 同 Expansion

輕量版(Rollout Policy Network)

  • 架構:線性模型 + 手工特徵
  • 準確率:24%
  • 推理時間:大約 2 微秒/局面(快 1500 倍)
  • 用途:快速模擬(rollout)

點解需要輕量版?

喺 MCTS 嘅 Simulation 階段,需要由當前節點一直落到遊戲結束,可能需要落 100+ 步。如果每步都用完整版 Policy Network,太慢喇。

輕量版雖然準確率只有 24%,但速度快 1500 倍。喺 rollout 入面,速度比精度更重要。

輕量版嘅特徵

輕量版使用手工設計嘅特徵,包括:

特徵類型範例
局部模式3×3 區域嘅棋子配置
全局特徵係咪喺邊角、大場
戰術特徵叫吃、征子、接應

呢啲特徵被輸入一個線性模型(冇隱藏層),計算速度極快。

AlphaGo Zero 嘅改進

之後嘅 AlphaGo Zero 完全棄用咗輕量版同 rollout。佢直接用 Value Network 評估葉節點,唔需要快速模擬。呢個係一個重大嘅簡化。


強化學習微調(RL Policy Network)

監督學習嘅局限

監督學習訓練嘅 Policy Network 有一個根本問題:

佢學嘅係「模仿人類」,而唔係「贏棋」

呢個意味住佢會學到人類嘅壞習慣,亦會喺人類從未遇過嘅局面表現唔好。

自我對弈強化

DeepMind 嘅解決方案係用策略梯度(Policy Gradient)方法進行強化學習:

1. 令 Policy Network 自我對弈
2. 記錄每盤棋嘅所有着法
3. 根據勝負調整參數:
- 贏咗 → 增加呢啲着法嘅概率
- 輸咗 → 減少呢啲着法嘅概率

REINFORCE 演算法

具體使用 REINFORCE 演算法:

∇J(θ) = E[Σ_t ∇log π_θ(a_t | s_t) × z]

其中:

  • z:呢盤棋嘅結果(+1 贏,-1 輸)
  • π_θ(a_t | s_t):喺狀態 s_t 選擇動作 a_t 嘅概率

結果

經過大約 1 日嘅自我對弈訓練(128 萬盤),RL Policy Network:

指標SL PolicyRL Policy
對戰 SL Policy50%80%
Elo 提升-+100

準確率可能略有下降(因為佢唔再完全模仿人類),但實際對戰勝率大幅提升。

由「模仿」到「創新」

強化學習令 Policy Network 學識咗一啲人類未曾諗過嘅着法。呢啲着法喺訓練資料入面從未出現,但佢哋係有效嘅。

呢個就係點解 AlphaGo 能夠落出「神之一手」——佢唔受人類經驗嘅限制。


視覺化分析

唔同局面嘅概率分布

等我哋睇吓 Policy Network 喺唔同局面下嘅輸出:

開局(佈局階段)

載入中...

開局時,概率主要集中喺:

  • 角部(佔角)
  • 邊上(掛角、守角)
  • 「大場」位置

呢個符合圍棋嘅基本原理:金角銀邊草肚皮。

戰鬥中嘅局面

載入中...

戰鬥時,概率集中喺:

  • 關鍵嘅切斷點
  • 叫吃、接應
  • 做眼、破眼

呢個顯示模型學識咗局部戰術。

收官階段

載入中...

收官時,概率分散喺各個官子點,需要精確計算目數。

隱藏層學到咩?

通過視覺化卷積層嘅輸出,我哋可以睇到模型學到嘅「特徵」:

  • 低層:基本形狀(眼、斷點)
  • 中層:戰術模式(叫吃、征子)
  • 高層:全局概念(勢力、厚薄)

呢個同人類認知圍棋嘅層次結構非常相似。


實作要點

PyTorch 實現

以下係一個簡化嘅 Policy Network 實現:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
def __init__(self, input_channels=48, num_filters=192, num_layers=12):
super().__init__()

# 第一卷積層(5×5)
self.conv1 = nn.Conv2d(input_channels, num_filters,
kernel_size=5, padding=2)

# 中間卷積層(3×3)×11
self.conv_layers = nn.ModuleList([
nn.Conv2d(num_filters, num_filters,
kernel_size=3, padding=1)
for _ in range(num_layers - 1)
])

# 輸出卷積層(1×1)
self.conv_out = nn.Conv2d(num_filters, 1, kernel_size=1)

def forward(self, x):
# x: (batch, 48, 19, 19)

# 第一層
x = F.relu(self.conv1(x))

# 中間層
for conv in self.conv_layers:
x = F.relu(conv(x))

# 輸出層
x = self.conv_out(x) # (batch, 1, 19, 19)

# 展平 + Softmax
x = x.view(x.size(0), -1) # (batch, 361)
x = F.softmax(x, dim=1)

return x

訓練循環

def train_step(model, optimizer, states, actions):
"""
states: (batch, 48, 19, 19) - 棋盤特徵
actions: (batch,) - 人類落子嘅位置 (0-360)
"""
# 前向傳播
policy = model(states) # (batch, 361)

# 交叉熵損失
loss = F.cross_entropy(
torch.log(policy + 1e-8), # 防止 log(0)
actions
)

# 反向傳播
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 計算準確率
predictions = policy.argmax(dim=1)
accuracy = (predictions == actions).float().mean()

return loss.item(), accuracy.item()

推理時嘅注意事項

喺實際對弈嘅時候,需要注意:

  1. 過濾非法着法:將非法位置嘅概率設為 0,然後重新歸一化
  2. 溫度調節:可以用溫度參數控制概率分布嘅「銳利度」
  3. 批次推理:喺 MCTS 入面可以批次處理多個局面
def get_move_probabilities(model, state, legal_moves, temperature=1.0):
"""獲取合法着法嘅概率分布"""
policy = model(state) # (361,)

# 只保留合法着法
mask = torch.zeros(361)
mask[legal_moves] = 1
policy = policy * mask

# 溫度調節
if temperature != 1.0:
policy = policy ** (1 / temperature)

# 重新歸一化
policy = policy / policy.sum()

return policy

動畫對應

本文涉及嘅核心概念同動畫編號:

編號概念物理/數學對應
E1Policy Network概率場
D9CNN 特徵提取濾波器響應
D3監督學習極大似然估計
H4策略梯度隨機優化

延伸閱讀


關鍵要點

  1. Policy Network 係概率分布生成器:輸入棋盤,輸出 361 個位置嘅概率
  2. 13 層 CNN + Softmax:深度卷積提取特徵,Softmax 輸出概率
  3. 57% 準確率:遠超之前嘅電腦圍棋程式
  4. 兩個版本:完整版用於 MCTS 決策,輕量版用於快速模擬
  5. 強化學習微調:由「模仿人類」進化到「追求勝利」

Policy Network 係 AlphaGo 嘅「直覺」——佢令 AI 能夠好似人類咁,快速識別出值得考慮嘅着法。


參考資料

  1. Silver, D., et al. (2016). "Mastering the game of Go with deep neural networks and tree search." Nature, 529, 484-489.
  2. Maddison, C. J., et al. (2014). "Move Evaluation in Go Using Deep Convolutional Neural Networks." arXiv:1412.6564.
  3. Sutton, R. S., & Barto, A. G. (2018). Reinforcement Learning: An Introduction. MIT Press.
  4. LeCun, Y., Bengio, Y., & Hinton, G. (2015). "Deep learning." Nature, 521, 436-444.