Lewati ke konten utama

Detail Implementasi MCTS

Artikel ini menganalisis secara mendalam detail implementasi Monte Carlo Tree Search (MCTS) di KataGo, termasuk struktur data, strategi seleksi, dan teknik paralelisasi.


Tinjauan Empat Langkah MCTS


Struktur Data Node

Data Inti

Setiap node MCTS perlu menyimpan:

class MCTSNode:
def __init__(self, state, parent=None, prior=0.0):
# Informasi dasar
self.state = state # Status papan
self.parent = parent # Node induk
self.children = {} # Dictionary child node {action: node}
self.action = None # Aksi untuk mencapai node ini

# Informasi statistik
self.visit_count = 0 # N(s): Jumlah kunjungan
self.value_sum = 0.0 # W(s): Total nilai
self.prior = prior # P(s,a): Probabilitas prior

# Untuk pencarian paralel
self.virtual_loss = 0 # Virtual loss
self.is_expanded = False # Apakah sudah diekspansi

@property
def value(self):
"""Q(s) = W(s) / N(s)"""
if self.visit_count == 0:
return 0.0
return self.value_sum / self.visit_count

Optimasi Memori

KataGo menggunakan berbagai teknik untuk mengurangi penggunaan memori:

# Menggunakan array numpy bukan Python dict
class OptimizedNode:
__slots__ = ['visit_count', 'value_sum', 'prior', 'children_indices']

def __init__(self):
self.visit_count = np.int32(0)
self.value_sum = np.float32(0.0)
self.prior = np.float32(0.0)
self.children_indices = None # Alokasi ditunda

Selection: Seleksi PUCT

Formula PUCT

Skor Seleksi = Q(s,a) + U(s,a)

Di mana:
Q(s,a) = W(s,a) / N(s,a) # Nilai rata-rata
U(s,a) = c_puct × P(s,a) × √(N(s)) / (1 + N(s,a)) # Komponen eksplorasi

Penjelasan Parameter

SimbolArtiNilai Tipikal
Q(s,a)Nilai rata-rata aksi a[-1, +1]
P(s,a)Probabilitas prior dari neural network[0, 1]
N(s)Jumlah kunjungan node indukInteger
N(s,a)Jumlah kunjungan aksi aInteger
c_puctKonstanta eksplorasi1.0 ~ 2.5

Implementasi

def select_child(self, c_puct=1.5):
"""Pilih child node dengan skor PUCT tertinggi"""
best_score = -float('inf')
best_action = None
best_child = None

# Akar kuadrat jumlah kunjungan node induk
sqrt_parent_visits = math.sqrt(self.visit_count)

for action, child in self.children.items():
# Nilai Q (nilai rata-rata)
if child.visit_count > 0:
q_value = child.value_sum / child.visit_count
else:
q_value = 0.0

# Nilai U (komponen eksplorasi)
u_value = c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count)

# Total skor
score = q_value + u_value

if score > best_score:
best_score = score
best_action = action
best_child = child

return best_action, best_child

Keseimbangan Eksplorasi vs Eksploitasi


Expansion: Ekspansi Node

Kondisi Ekspansi

Saat mencapai leaf node, gunakan neural network untuk ekspansi:

def expand(self, policy_probs, legal_moves):
"""Ekspansi node, buat child node untuk semua aksi legal"""
for action in legal_moves:
if action not in self.children:
prior = policy_probs[action] # Probabilitas dari neural network
child_state = self.state.play(action)
self.children[action] = MCTSNode(
state=child_state,
parent=self,
prior=prior
)

self.is_expanded = True
def get_legal_moves(state):
"""Dapatkan semua aksi legal"""
legal = []
for i in range(361):
x, y = i // 19, i % 19
if state.is_legal(x, y):
legal.append(i)

# Tambahkan pass
legal.append(361)

return legal

Evaluation: Evaluasi Neural Network

Evaluasi Tunggal

def evaluate(self, state):
"""Gunakan neural network untuk mengevaluasi posisi"""
# Encode fitur input
features = encode_state(state) # (22, 19, 19)
features = torch.tensor(features).unsqueeze(0) # (1, 22, 19, 19)

# Inferensi neural network
with torch.no_grad():
output = self.network(features)

policy = output['policy'][0].numpy() # (362,)
value = output['value'][0].item() # scalar

return policy, value

Evaluasi Batch (Optimasi Kunci)

GPU paling efisien saat inferensi batch:

class BatchedEvaluator:
def __init__(self, network, batch_size=8):
self.network = network
self.batch_size = batch_size
self.pending = [] # List (state, callback) yang menunggu evaluasi

def request_evaluation(self, state, callback):
"""Minta evaluasi, otomatis eksekusi saat batch penuh"""
self.pending.append((state, callback))

if len(self.pending) >= self.batch_size:
self.flush()

def flush(self):
"""Eksekusi evaluasi batch"""
if not self.pending:
return

# Siapkan input batch
states = [s for s, _ in self.pending]
features = torch.stack([encode_state(s) for s in states])

# Inferensi batch
with torch.no_grad():
outputs = self.network(features)

# Callback hasil
for i, (_, callback) in enumerate(self.pending):
policy = outputs['policy'][i].numpy()
value = outputs['value'][i].item()
callback(policy, value)

self.pending.clear()

Backpropagation: Update Balik

Backprop Dasar

def backpropagate(self, value):
"""Backprop dari leaf node ke root node, update informasi statistik"""
node = self

while node is not None:
node.visit_count += 1
node.value_sum += value

# Perspektif bergantian: nilai lawan adalah kebalikannya
value = -value

node = node.parent

Pentingnya Perspektif Bergantian

Perspektif Hitam: value = +0.6 (Hitam unggul)

Jalur backprop:
Leaf node (giliran Hitam): value_sum += +0.6

Parent node (giliran Putih): value_sum += -0.6 ← Tidak menguntungkan untuk Putih

Grandparent node (giliran Hitam): value_sum += +0.6

...

Paralelisasi: Virtual Loss

Masalah

Saat multi-thread mencari secara bersamaan, mungkin semua memilih node yang sama:

Thread 1: Pilih node A (Q=0.6, N=100)
Thread 2: Pilih node A (Q=0.6, N=100) ← Duplikat!
Thread 3: Pilih node A (Q=0.6, N=100) ← Duplikat!

Solusi: Virtual Loss

Saat memilih node, tambahkan "virtual loss" lebih dulu, membuat thread lain tidak mau memilihnya:

VIRTUAL_LOSS = 3  # Nilai virtual loss

def select_with_virtual_loss(self):
"""Seleksi dengan virtual loss"""
action, child = self.select_child()

# Tambahkan virtual loss
child.visit_count += VIRTUAL_LOSS
child.value_sum -= VIRTUAL_LOSS # Pura-pura kalah

return action, child

def backpropagate_with_virtual_loss(self, value):
"""Hapus virtual loss saat backprop"""
node = self

while node is not None:
# Hapus virtual loss
node.visit_count -= VIRTUAL_LOSS
node.value_sum += VIRTUAL_LOSS

# Update normal
node.visit_count += 1
node.value_sum += value

value = -value
node = node.parent

Efek

Thread 1: Pilih node A, tambahkan virtual loss
Nilai Q node A turun sementara

Thread 2: Pilih node B (karena A terlihat lebih buruk)

Thread 3: Pilih node C

→ Thread berbeda mengeksplorasi cabang berbeda, meningkatkan efisiensi

Implementasi Pencarian Lengkap

class MCTS:
def __init__(self, network, c_puct=1.5, num_simulations=800):
self.network = network
self.c_puct = c_puct
self.num_simulations = num_simulations
self.evaluator = BatchedEvaluator(network)

def search(self, root_state):
"""Eksekusi pencarian MCTS"""
root = MCTSNode(root_state)

# Ekspansi root node
policy, value = self.evaluate(root_state)
legal_moves = get_legal_moves(root_state)
root.expand(policy, legal_moves)

# Eksekusi simulasi
for _ in range(self.num_simulations):
node = root
path = [node]

# Selection: Turun pohon
while node.is_expanded and node.children:
action, node = node.select_child(self.c_puct)
path.append(node)

# Expansion + Evaluation
if not node.is_expanded:
policy, value = self.evaluate(node.state)
legal_moves = get_legal_moves(node.state)

if legal_moves:
node.expand(policy, legal_moves)

# Backpropagation
for n in reversed(path):
n.visit_count += 1
n.value_sum += value
value = -value

# Pilih aksi dengan kunjungan terbanyak
best_action = max(root.children.items(),
key=lambda x: x[1].visit_count)[0]

return best_action

def evaluate(self, state):
features = encode_state(state)
features = torch.tensor(features).unsqueeze(0)

with torch.no_grad():
output = self.network(features)

return output['policy'][0].numpy(), output['value'][0].item()

Teknik Lanjutan

Noise Dirichlet

Tambahkan noise di root node saat pelatihan untuk meningkatkan eksplorasi:

def add_dirichlet_noise(root, alpha=0.03, epsilon=0.25):
"""Tambahkan noise Dirichlet di root node"""
noise = np.random.dirichlet([alpha] * len(root.children))

for i, child in enumerate(root.children.values()):
child.prior = (1 - epsilon) * child.prior + epsilon * noise[i]

Parameter Temperatur

Mengontrol keacakan pemilihan aksi:

def select_action_with_temperature(root, temperature=1.0):
"""Pilih aksi berdasarkan jumlah kunjungan dan temperatur"""
visits = np.array([c.visit_count for c in root.children.values()])
actions = list(root.children.keys())

if temperature == 0:
# Pilihan greedy
return actions[np.argmax(visits)]
else:
# Pilih berdasarkan distribusi probabilitas dari jumlah kunjungan
probs = visits ** (1 / temperature)
probs = probs / probs.sum()
return np.random.choice(actions, p=probs)

Penggunaan Ulang Pohon

Langkah baru bisa menggunakan ulang pohon pencarian sebelumnya:

def reuse_tree(root, action):
"""Gunakan ulang subtree"""
if action in root.children:
new_root = root.children[action]
new_root.parent = None
return new_root
else:
return None # Perlu buat pohon baru

Ringkasan Optimasi Performa

TeknikEfek
Evaluasi BatchUtilisasi GPU dari 10% → 80%+
Virtual LossEfisiensi multi-thread meningkat 3-5x
Penggunaan Ulang PohonKurangi cold start, hemat 30%+ komputasi
Memory PoolKurangi overhead alokasi memori

Bacaan Lanjutan