跳到主要内容

Policy Network 详解

在围棋的任一局面,合法的下法平均有 250 种。如果让电脑随机选择,它永远不可能下出好棋。

AlphaGo 的突破在于:它学会了「看一眼棋盘,就知道哪些位置值得考虑」。

这个能力,来自 Policy Network(策略网络)


什么是 Policy Network?

核心功能

Policy Network 是一个深度卷积神经网络,它的任务是:

给定当前棋盘状态,输出每个位置的落子概率

用数学表示:

p = f_θ(s)

其中:

  • s:当前棋盘状态(19×19 的棋盘 + 其他特征)
  • f_θ:Policy Network(θ 是网络参数)
  • p:361 个位置的概率分布(包含 pass)

直觉理解

想像你是一位职业棋手。当你看到一个局面,你的大脑会自动「亮起」几个重要的位置——这些是你直觉认为值得考虑的点。

Policy Network 就是在模拟这个过程。

載入中...

上面的热力图显示了 Policy Network 的输出。颜色越亮的位置,模型认为越值得下。

为什么需要 Policy Network?

围棋的搜索空间太大了。如果不加筛选地搜索所有可能的走法:

策略每步考虑的着法搜索 10 步的节点数
全部考虑361361^10 ≈ 10^25
Policy Network 筛选~2020^10 ≈ 10^13

Policy Network 将搜索空间缩小了 10^12 倍(一万亿倍)。


网络架构

整体结构

AlphaGo 的 Policy Network 采用深度卷积神经网络(CNN)架构:

输入层 → 卷积层 ×12 → 输出卷积层 → Softmax
↓ ↓ ↓ ↓
19×19×48 19×19×192 19×19×1 362 个概率

输入层

输入是 19×19×48 的特征张量:

这 48 个平面包含:

  • 黑子位置、白子位置
  • 近 8 手的历史
  • 气数、叫吃、征子等特征
  • 合法性(哪些位置可以下)

卷积层

网络包含 12 层卷积层,每层的配置:

参数数值说明
滤波器数量192每层输出 192 个特征图
卷积核大小3×3(第一层 5×5)每次看 3×3 的区域
填充方式same保持 19×19 的尺寸
激活函数ReLUmax(0, x)

为什么是 192 个滤波器?

这是一个经验值。太少会限制模型容量,太多会增加计算量和过拟合风险。DeepMind 团队通过实验确定 192 是一个好的平衡点。

为什么是 3×3 卷积核?

3×3 是卷积神经网络中最常用的尺寸,原因:

  1. 足够捕捉局部模式:围棋中的眼位、接、断等都在 3×3 范围内
  2. 计算效率高:相比大卷积核,3×3 参数更少
  3. 可堆叠:多层 3×3 卷积可以达到大感受野的效果

第一层为什么用 5×5?

第一层使用较大的 5×5 卷积核,是为了在输入层就捕捉稍大范围的模式(如小飞、跳)。这是一个设计选择,后来的 AlphaGo Zero 统一使用 3×3。

ReLU 激活函数

每个卷积层后接 ReLU(Rectified Linear Unit)激活函数:

ReLU(x) = max(0, x)

为什么用 ReLU?

  1. 计算简单:只是取最大值,比 sigmoid 快很多
  2. 缓解梯度消失:正区间梯度恒为 1
  3. 稀疏激活:负值被归零,产生稀疏表示

输出层

最后一层是特殊的卷积层:

19×19×192 → 卷积(1×1, 1个滤波器) → 19×19×1 → 展平 → 362维向量 → Softmax

1×1 卷积

输出层使用 1×1 卷积,将 192 个通道压缩为 1 个。这等价于对每个位置的 192 维特征做线性组合。

Softmax 输出

362 维向量(361 个棋盘位置 + 1 个 pass)经过 Softmax 函数:

Softmax(z_i) = exp(z_i) / Σ_j exp(z_j)

Softmax 确保输出是合法的概率分布:

  • 所有值在 0 到 1 之间
  • 所有值的和为 1

参数数量

让我们计算网络的总参数量:

计算参数数量
第一卷积层5×5×48×192 + 192230,592
中间卷积层 ×11(3×3×192×192 + 192) × 113,633,792
输出卷积层1×1×192×1 + 1193
总计~3.9M

390 万个参数,以今天的标准来看是一个小型网络。


训练目标与方法

训练资料

Policy Network 使用监督学习,从人类棋谱中学习。

资料来源:

  • KGS Go Server:业余和职业棋手的对局
  • 约 3000 万局面:从 16 万局对局中取样
  • 标签:每个局面对应的人类下一步

交叉熵损失函数

训练目标是最大化预测人类着法的概率。用交叉熵损失函数:

L(θ) = -Σ log p_θ(a | s)

其中:

  • s:棋盘状态
  • a:人类实际下的位置
  • p_θ(a | s):模型预测该位置的概率

直觉理解

交叉熵损失有一个简单的含义:

当模型预测正确位置的概率越高,损失越低

如果人类下在 K10,而模型给 K10 的概率是:

  • 0.9 → 损失 = -log(0.9) ≈ 0.1(很低,好)
  • 0.1 → 损失 = -log(0.1) ≈ 2.3(很高,差)
  • 0.01 → 损失 = -log(0.01) ≈ 4.6(非常高,很差)

训练过程

# 伪代码
for epoch in range(num_epochs):
for batch in dataloader:
states, actions = batch

# 前向传播
policy = network(states) # 361 维概率向量

# 计算损失(交叉熵)
loss = cross_entropy(policy, actions)

# 反向传播
loss.backward()
optimizer.step()

训练细节:

  • 优化器:SGD with momentum
  • 学习率:初始 0.003,逐步衰减
  • 批次大小:16
  • 训练时间:约 3 周(50 GPUs)

资料增强

围棋棋盘有 8 重对称性(4 个旋转 × 2 个镜像)。每个训练样本可以变换为 8 个等价样本:

原始 → 旋转90° → 旋转180° → 旋转270°
↓ ↓ ↓ ↓
水平翻转 → ...

这让有效训练资料增加 8 倍,并确保模型学到的模式不依赖方向。


训练结果

57% 准确率

经过训练,Policy Network 达到了 57% 的 top-1 准确率

这意味着:给定任意局面,模型有 57% 的机会预测出人类专家实际下的那一步。

这个准确率高吗?

考虑到每个局面平均有 250 个合法着法,随机猜测的准确率只有 0.4%。

方法Top-1 准确率
随机猜测0.4%
之前最强的电脑围棋~44%
AlphaGo Policy Network57%

提升 13 个百分点,看起来不多,但意义重大。

棋力提升

纯粹使用 Policy Network(不加搜索)下棋,可以达到什么棋力?

配置Elo 评分大约等级
之前最强程式(Pachi)2,500业余 4-5 段
Policy Network alone2,800业余 6-7 段
+ MCTS 1600 simulations3,200+职业水平

单独的 Policy Network 就已经是业余高段,加上 MCTS 后更是跃升到职业水平。

为什么只有 57%?

人类棋谱存在以下特性,限制了准确率:

1. 多个好棋

很多局面有多步都是好棋。例如「挂角」和「守角」可能都是正确选择。模型选了另一步好棋,会被算作「错误」。

2. 风格差异

不同棋手有不同风格。激进型棋手和稳健型棋手在同一局面可能下不同的棋。模型学到的是「平均」的风格。

3. 人类也会犯错

KGS 资料包含业余棋手的对局,他们的选择不一定是最佳的。模型学到一些「错误」是正常的。


在 MCTS 中的作用

Policy Network 在 AlphaGo 的 MCTS 中扮演两个关键角色:

1. 引导搜索方向

在 MCTS 的 Selection 阶段,Policy Network 的输出用于计算 UCB(Upper Confidence Bound):

UCB(s, a) = Q(s, a) + c_puct × P(s, a) × √(N(s)) / (1 + N(s, a))

其中 P(s, a) 就是 Policy Network 给出的概率。

这意味着:

  • 高概率的着法会被优先探索
  • 低概率的着法也有机会被探索(因为有探索项)

2. 扩展节点的先验

当 MCTS 扩展一个新节点时,Policy Network 提供所有子节点的先验概率

展开节点 s:
for each action a:
child = Node()
child.prior = policy_network(s)[a] # 先验概率
child.value = 0
child.visits = 0

这些先验概率让 MCTS 「知道」哪些子节点更值得探索,即使它们还没被访问过。


轻量版 vs 完整版

AlphaGo 实际上有两个 Policy Network:

完整版(SL Policy Network)

  • 架构:13 层 CNN,192 filters
  • 准确率:57%
  • 推理时间:约 3 毫秒/局面
  • 用途:MCTS 中的 Selection 和 Expansion

轻量版(Rollout Policy Network)

  • 架构:线性模型 + 手工特征
  • 准确率:24%
  • 推理时间:约 2 微秒/局面(快 1500 倍)
  • 用途:快速模拟(rollout)

为什么需要轻量版?

在 MCTS 的 Simulation 阶段,需要从当前节点一直下到游戏结束,可能需要下 100+ 步。如果每步都用完整版 Policy Network,太慢了。

轻量版虽然准确率只有 24%,但速度快 1500 倍。在 rollout 中,速度比精度更重要。

轻量版的特征

轻量版使用手工设计的特征,包括:

特征类型范例
局部模式3×3 区域的棋子配置
全局特征是否在边角、大场
战术特征叫吃、征子、接应

这些特征被输入一个线性模型(没有隐藏层),计算速度极快。

AlphaGo Zero 的改进

后来的 AlphaGo Zero 完全弃用了轻量版和 rollout。它直接用 Value Network 评估叶节点,不需要快速模拟。这是一个重大的简化。


强化学习微调(RL Policy Network)

监督学习的局限

监督学习训练的 Policy Network 有一个根本问题:

它学的是「模仿人类」,而非「赢棋」

这意味着它会学到人类的坏习惯,也会在人类从未遇过的局面表现不佳。

自我对弈强化

DeepMind 的解决方案是用策略梯度(Policy Gradient)方法进行强化学习:

1. 让 Policy Network 自我对弈
2. 记录每盘棋的所有着法
3. 根据胜负调整参数:
- 赢了 → 增加这些着法的概率
- 输了 → 减少这些着法的概率

REINFORCE 演算法

具体使用 REINFORCE 演算法:

∇J(θ) = E[Σ_t ∇log π_θ(a_t | s_t) × z]

其中:

  • z:这盘棋的结果(+1 赢,-1 输)
  • π_θ(a_t | s_t):在状态 s_t 选择动作 a_t 的概率

结果

经过约 1 天的自我对弈训练(128 万盘),RL Policy Network:

指标SL PolicyRL Policy
对战 SL Policy50%80%
Elo 提升-+100

准确率可能略有下降(因为它不再完全模仿人类),但实际对战胜率大幅提升。

从「模仿」到「创新」

强化学习让 Policy Network 学会了一些人类未曾想过的着法。这些着法在训练资料中从未出现,但它们是有效的。

这就是为什么 AlphaGo 能下出「神之一手」——它不受人类经验的限制。


视觉化分析

不同局面的概率分布

让我们看看 Policy Network 在不同局面下的输出:

开局(布局阶段)

載入中...

开局时,概率主要集中在:

  • 角部(占角)
  • 边上(挂角、守角)
  • 「大场」位置

这符合围棋的基本原理:金角银边草肚皮。

战斗中的局面

載入中...

战斗时,概率集中在:

  • 关键的切断点
  • 叫吃、接应
  • 做眼、破眼

这显示模型学会了局部战术。

收官阶段

載入中...

收官时,概率分散在各个官子点,需要精确计算目数。

隐藏层学到什么?

通过视觉化卷积层的输出,我们可以看到模型学到的「特征」:

  • 低层:基本形状(眼、断点)
  • 中层:战术模式(叫吃、征子)
  • 高层:全局概念(势力、厚薄)

这与人类认知围棋的层次结构非常相似。


实作要点

PyTorch 实现

以下是一个简化的 Policy Network 实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
def __init__(self, input_channels=48, num_filters=192, num_layers=12):
super().__init__()

# 第一卷积层(5×5)
self.conv1 = nn.Conv2d(input_channels, num_filters,
kernel_size=5, padding=2)

# 中间卷积层(3×3)×11
self.conv_layers = nn.ModuleList([
nn.Conv2d(num_filters, num_filters,
kernel_size=3, padding=1)
for _ in range(num_layers - 1)
])

# 输出卷积层(1×1)
self.conv_out = nn.Conv2d(num_filters, 1, kernel_size=1)

def forward(self, x):
# x: (batch, 48, 19, 19)

# 第一层
x = F.relu(self.conv1(x))

# 中间层
for conv in self.conv_layers:
x = F.relu(conv(x))

# 输出层
x = self.conv_out(x) # (batch, 1, 19, 19)

# 展平 + Softmax
x = x.view(x.size(0), -1) # (batch, 361)
x = F.softmax(x, dim=1)

return x

训练循环

def train_step(model, optimizer, states, actions):
"""
states: (batch, 48, 19, 19) - 棋盘特征
actions: (batch,) - 人类下的位置 (0-360)
"""
# 前向传播
policy = model(states) # (batch, 361)

# 交叉熵损失
loss = F.cross_entropy(
torch.log(policy + 1e-8), # 防止 log(0)
actions
)

# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 计算准确率
predictions = policy.argmax(dim=1)
accuracy = (predictions == actions).float().mean()

return loss.item(), accuracy.item()

推理时的注意事项

在实际对弈时,需要注意:

  1. 过滤非法着法:将非法位置的概率设为 0,然后重新归一化
  2. 温度调节:可以用温度参数控制概率分布的「锐利度」
  3. 批次推理:在 MCTS 中可以批次处理多个局面
def get_move_probabilities(model, state, legal_moves, temperature=1.0):
"""获取合法着法的概率分布"""
policy = model(state) # (361,)

# 只保留合法着法
mask = torch.zeros(361)
mask[legal_moves] = 1
policy = policy * mask

# 温度调节
if temperature != 1.0:
policy = policy ** (1 / temperature)

# 重新归一化
policy = policy / policy.sum()

return policy

动画对应

本文涉及的核心概念与动画编号:

编号概念物理/数学对应
🎬 E1Policy Network概率场
🎬 D9CNN 特征提取滤波器响应
🎬 D3监督学习极大似然估计
🎬 H4策略梯度随机优化

延伸阅读


关键要点

  1. Policy Network 是概率分布生成器:输入棋盘,输出 361 个位置的概率
  2. 13 层 CNN + Softmax:深度卷积提取特征,Softmax 输出概率
  3. 57% 准确率:远超之前的电脑围棋程式
  4. 两个版本:完整版用于 MCTS 决策,轻量版用于快速模拟
  5. 强化学习微调:从「模仿人类」进化到「追求胜利」

Policy Network 是 AlphaGo 的「直觉」——它让 AI 能够像人类一样,快速识别出值得考虑的着法。


参考资料

  1. Silver, D., et al. (2016). "Mastering the game of Go with deep neural networks and tree search." Nature, 529, 484-489.
  2. Maddison, C. J., et al. (2014). "Move Evaluation in Go Using Deep Convolutional Neural Networks." arXiv:1412.6564.
  3. Sutton, R. S., & Barto, A. G. (2018). Reinforcement Learning: An Introduction. MIT Press.
  4. LeCun, Y., Bengio, Y., & Hinton, G. (2015). "Deep learning." Nature, 521, 436-444.