メインコンテンツまでスキップ

MCTS実装詳細

本記事では、KataGoにおけるモンテカルロ木探索(MCTS)の実装詳細を、データ構造、選択戦略、並列化技術を含めて詳しく解説します。


MCTSの4ステップ復習


ノードのデータ構造

コアデータ

各MCTSノードは以下を保存する必要があります:

class MCTSNode:
def __init__(self, state, parent=None, prior=0.0):
# 基本情報
self.state = state # 盤面状態
self.parent = parent # 親ノード
self.children = {} # 子ノード辞書 {action: node}
self.action = None # このノードに到達した手

# 統計情報
self.visit_count = 0 # N(s):訪問回数
self.value_sum = 0.0 # W(s):価値の合計
self.prior = prior # P(s,a):事前確率

# 並列探索用
self.virtual_loss = 0 # 仮想損失
self.is_expanded = False # 展開済みかどうか

@property
def value(self):
"""Q(s) = W(s) / N(s)"""
if self.visit_count == 0:
return 0.0
return self.value_sum / self.visit_count

メモリ最適化

KataGoはメモリ使用量を削減するために複数の技術を使用しています:

# Python dictの代わりにnumpy配列を使用
class OptimizedNode:
__slots__ = ['visit_count', 'value_sum', 'prior', 'children_indices']

def __init__(self):
self.visit_count = np.int32(0)
self.value_sum = np.float32(0.0)
self.prior = np.float32(0.0)
self.children_indices = None # 遅延割り当て

Selection:PUCT選択

PUCT公式

選択スコア = Q(s,a) + U(s,a)

ここで:
Q(s,a) = W(s,a) / N(s,a) # 平均価値
U(s,a) = c_puct × P(s,a) × √(N(s)) / (1 + N(s,a)) # 探索項

パラメータの説明

記号意味典型値
Q(s,a)手aの平均価値[-1, +1]
P(s,a)ニューラルネットワークの事前確率[0, 1]
N(s)親ノードの訪問回数整数
N(s,a)手aの訪問回数整数
c_puct探索定数1.0 ~ 2.5

実装

def select_child(self, c_puct=1.5):
"""PUCTスコア最大の子ノードを選択"""
best_score = -float('inf')
best_action = None
best_child = None

# 親ノードの訪問回数の平方根
sqrt_parent_visits = math.sqrt(self.visit_count)

for action, child in self.children.items():
# Q値(平均価値)
if child.visit_count > 0:
q_value = child.value_sum / child.visit_count
else:
q_value = 0.0

# U値(探索項)
u_value = c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count)

# 総スコア
score = q_value + u_value

if score > best_score:
best_score = score
best_action = action
best_child = child

return best_action, best_child

探索と活用のバランス

初期:N(s,a)が小さい
├── U(s,a)が大きい → 探索が主体
└── 高い事前確率の手が優先的に探索される

後期:N(s,a)が大きい
├── U(s,a)が小さい → 活用が主体
└── Q(s,a)が支配し、既知の良い手を選択

Expansion:ノード展開

展開条件

リーフノードに到達したとき、ニューラルネットワークで展開します:

def expand(self, policy_probs, legal_moves):
"""ノードを展開し、全ての合法手の子ノードを作成"""
for action in legal_moves:
if action not in self.children:
prior = policy_probs[action] # ニューラルネットワーク予測の確率
child_state = self.state.play(action)
self.children[action] = MCTSNode(
state=child_state,
parent=self,
prior=prior
)

self.is_expanded = True

合法手のフィルタリング

def get_legal_moves(state):
"""全ての合法手を取得"""
legal = []
for i in range(361):
x, y = i // 19, i % 19
if state.is_legal(x, y):
legal.append(i)

# パスを追加
legal.append(361)

return legal

Evaluation:ニューラルネットワーク評価

単一評価

def evaluate(self, state):
"""ニューラルネットワークで局面を評価"""
# 入力特徴をエンコード
features = encode_state(state) # (22, 19, 19)
features = torch.tensor(features).unsqueeze(0) # (1, 22, 19, 19)

# ニューラルネットワーク推論
with torch.no_grad():
output = self.network(features)

policy = output['policy'][0].numpy() # (362,)
value = output['value'][0].item() # スカラー

return policy, value

バッチ評価(重要な最適化)

GPUはバッチ推論で最も効率的です:

class BatchedEvaluator:
def __init__(self, network, batch_size=8):
self.network = network
self.batch_size = batch_size
self.pending = [] # 評価待ちの (state, callback) リスト

def request_evaluation(self, state, callback):
"""評価をリクエストし、バッチが満たされたら自動実行"""
self.pending.append((state, callback))

if len(self.pending) >= self.batch_size:
self.flush()

def flush(self):
"""バッチ評価を実行"""
if not self.pending:
return

# バッチ入力を準備
states = [s for s, _ in self.pending]
features = torch.stack([encode_state(s) for s in states])

# バッチ推論
with torch.no_grad():
outputs = self.network(features)

# 結果をコールバック
for i, (_, callback) in enumerate(self.pending):
policy = outputs['policy'][i].numpy()
value = outputs['value'][i].item()
callback(policy, value)

self.pending.clear()

Backpropagation:逆伝播更新

基本的な逆伝播

def backpropagate(self, value):
"""リーフノードからルートノードまで逆伝播し、統計情報を更新"""
node = self

while node is not None:
node.visit_count += 1
node.value_sum += value

# 視点の交代:相手の価値は逆になる
value = -value

node = node.parent

視点交代の重要性

黒の視点:value = +0.6(黒有利)

逆伝播経路:
リーフノード(黒の手番): value_sum += +0.6

親ノード(白の手番): value_sum += -0.6 ← 白にとっては不利

祖父ノード(黒の手番): value_sum += +0.6

...

並列化:仮想損失

問題

マルチスレッドで同時探索すると、全て同じノードを選択する可能性があります:

Thread 1: ノードAを選択(Q=0.6, N=100)
Thread 2: ノードAを選択(Q=0.6, N=100)← 重複!
Thread 3: ノードAを選択(Q=0.6, N=100)← 重複!

解決策:仮想損失

ノード選択時に、まず「仮想損失」を追加して、他のスレッドがそれを選びたくなくなるようにします:

VIRTUAL_LOSS = 3  # 仮想損失値

def select_with_virtual_loss(self):
"""仮想損失付きの選択"""
action, child = self.select_child()

# 仮想損失を追加
child.visit_count += VIRTUAL_LOSS
child.value_sum -= VIRTUAL_LOSS # 負けたと仮定

return action, child

def backpropagate_with_virtual_loss(self, value):
"""逆伝播時に仮想損失を除去"""
node = self

while node is not None:
# 仮想損失を除去
node.visit_count -= VIRTUAL_LOSS
node.value_sum += VIRTUAL_LOSS

# 通常の更新
node.visit_count += 1
node.value_sum += value

value = -value
node = node.parent

効果

Thread 1: ノードAを選択、仮想損失を追加
Aの Q値が一時的に低下

Thread 2: ノードBを選択(Aが悪く見えるため)

Thread 3: ノードCを選択

→ 異なるスレッドが異なる分岐を探索し、効率が向上

完全な探索実装

class MCTS:
def __init__(self, network, c_puct=1.5, num_simulations=800):
self.network = network
self.c_puct = c_puct
self.num_simulations = num_simulations
self.evaluator = BatchedEvaluator(network)

def search(self, root_state):
"""MCTS探索を実行"""
root = MCTSNode(root_state)

# ルートノードを展開
policy, value = self.evaluate(root_state)
legal_moves = get_legal_moves(root_state)
root.expand(policy, legal_moves)

# シミュレーションを実行
for _ in range(self.num_simulations):
node = root
path = [node]

# Selection:木を下る
while node.is_expanded and node.children:
action, node = node.select_child(self.c_puct)
path.append(node)

# Expansion + Evaluation
if not node.is_expanded:
policy, value = self.evaluate(node.state)
legal_moves = get_legal_moves(node.state)

if legal_moves:
node.expand(policy, legal_moves)

# Backpropagation
for n in reversed(path):
n.visit_count += 1
n.value_sum += value
value = -value

# 訪問回数最多の手を選択
best_action = max(root.children.items(),
key=lambda x: x[1].visit_count)[0]

return best_action

def evaluate(self, state):
features = encode_state(state)
features = torch.tensor(features).unsqueeze(0)

with torch.no_grad():
output = self.network(features)

return output['policy'][0].numpy(), output['value'][0].item()

高度な技術

ディリクレノイズ

訓練時にルートノードにノイズを追加して探索を増やします:

def add_dirichlet_noise(root, alpha=0.03, epsilon=0.25):
"""ルートノードにディリクレノイズを追加"""
noise = np.random.dirichlet([alpha] * len(root.children))

for i, child in enumerate(root.children.values()):
child.prior = (1 - epsilon) * child.prior + epsilon * noise[i]

温度パラメータ

手の選択のランダム性を制御します:

def select_action_with_temperature(root, temperature=1.0):
"""訪問回数と温度に基づいて手を選択"""
visits = np.array([c.visit_count for c in root.children.values()])
actions = list(root.children.keys())

if temperature == 0:
# 貪欲選択
return actions[np.argmax(visits)]
else:
# 訪問回数の確率分布に基づいて選択
probs = visits ** (1 / temperature)
probs = probs / probs.sum()
return np.random.choice(actions, p=probs)

木の再利用

新しい手で以前の探索木を再利用できます:

def reuse_tree(root, action):
"""部分木を再利用"""
if action in root.children:
new_root = root.children[action]
new_root.parent = None
return new_root
else:
return None # 新しい木を作成する必要あり

パフォーマンス最適化のまとめ

技術効果
バッチ評価GPU使用率が10% → 80%以上に
仮想損失マルチスレッド効率が3-5倍向上
木の再利用コールドスタートを削減、計算を30%以上節約
メモリプールメモリ割り当てオーバーヘッドを削減

関連記事