Saltar al contenido principal

Detalles de implementación de MCTS

Este artículo analiza en profundidad los detalles de implementación de Monte Carlo Tree Search (MCTS) en KataGo, incluyendo estructuras de datos, estrategias de selección y técnicas de paralelización.


Repaso de los cuatro pasos de MCTS


Estructura de datos del nodo

Datos principales

Cada nodo MCTS necesita almacenar:

class MCTSNode:
def __init__(self, state, parent=None, prior=0.0):
# Información básica
self.state = state # Estado del tablero
self.parent = parent # Nodo padre
self.children = {} # Diccionario de hijos {acción: nodo}
self.action = None # Acción para llegar a este nodo

# Información estadística
self.visit_count = 0 # N(s): Número de visitas
self.value_sum = 0.0 # W(s): Suma de valores
self.prior = prior # P(s,a): Probabilidad a priori

# Para búsqueda paralela
self.virtual_loss = 0 # Pérdida virtual
self.is_expanded = False # Si ya está expandido

@property
def value(self):
"""Q(s) = W(s) / N(s)"""
if self.visit_count == 0:
return 0.0
return self.value_sum / self.visit_count

Optimización de memoria

KataGo usa varias técnicas para reducir el uso de memoria:

# Usar arrays numpy en lugar de dict de Python
class OptimizedNode:
__slots__ = ['visit_count', 'value_sum', 'prior', 'children_indices']

def __init__(self):
self.visit_count = np.int32(0)
self.value_sum = np.float32(0.0)
self.prior = np.float32(0.0)
self.children_indices = None # Asignación diferida

Selection: Selección PUCT

Fórmula PUCT

Puntuación de selección = Q(s,a) + U(s,a)

Donde:
Q(s,a) = W(s,a) / N(s,a) # Valor promedio
U(s,a) = c_puct × P(s,a) × √(N(s)) / (1 + N(s,a)) # Término de exploración

Explicación de parámetros

SímboloSignificadoValor típico
Q(s,a)Valor promedio de la acción a[-1, +1]
P(s,a)Probabilidad a priori de la red neuronal[0, 1]
N(s)Número de visitas del nodo padreEntero
N(s,a)Número de visitas de la acción aEntero
c_puctConstante de exploración1.0 ~ 2.5

Implementación

def select_child(self, c_puct=1.5):
"""Seleccionar el nodo hijo con mayor puntuación PUCT"""
best_score = -float('inf')
best_action = None
best_child = None

# Raíz cuadrada del número de visitas del padre
sqrt_parent_visits = math.sqrt(self.visit_count)

for action, child in self.children.items():
# Valor Q (valor promedio)
if child.visit_count > 0:
q_value = child.value_sum / child.visit_count
else:
q_value = 0.0

# Valor U (término de exploración)
u_value = c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count)

# Puntuación total
score = q_value + u_value

if score > best_score:
best_score = score
best_action = action
best_child = child

return best_action, best_child

Balance entre exploración y explotación

Etapa inicial: N(s,a) pequeño
├── U(s,a) grande → Dominado por exploración
└── Acciones con alta probabilidad a priori se exploran primero

Etapa posterior: N(s,a) grande
├── U(s,a) pequeño → Dominado por explotación
└── Q(s,a) domina, se eligen acciones conocidas como buenas

Expansion: Expansión de nodos

Condición de expansión

Al llegar a un nodo hoja, expandir usando la red neuronal:

def expand(self, policy_probs, legal_moves):
"""Expandir el nodo, crear nodos hijos para todas las acciones legales"""
for action in legal_moves:
if action not in self.children:
prior = policy_probs[action] # Probabilidad predicha por la red
child_state = self.state.play(action)
self.children[action] = MCTSNode(
state=child_state,
parent=self,
prior=prior
)

self.is_expanded = True

Filtrado de movimientos legales

def get_legal_moves(state):
"""Obtener todos los movimientos legales"""
legal = []
for i in range(361):
x, y = i // 19, i % 19
if state.is_legal(x, y):
legal.append(i)

# Añadir pass
legal.append(361)

return legal

Evaluation: Evaluación con red neuronal

Evaluación única

def evaluate(self, state):
"""Usar la red neuronal para evaluar la posición"""
# Codificar características de entrada
features = encode_state(state) # (22, 19, 19)
features = torch.tensor(features).unsqueeze(0) # (1, 22, 19, 19)

# Inferencia de la red neuronal
with torch.no_grad():
output = self.network(features)

policy = output['policy'][0].numpy() # (362,)
value = output['value'][0].item() # escalar

return policy, value

Evaluación por lotes (optimización clave)

La GPU es más eficiente con inferencia por lotes:

class BatchedEvaluator:
def __init__(self, network, batch_size=8):
self.network = network
self.batch_size = batch_size
self.pending = [] # Lista de (state, callback) pendientes

def request_evaluation(self, state, callback):
"""Solicitar evaluación, ejecutar automáticamente cuando el lote esté lleno"""
self.pending.append((state, callback))

if len(self.pending) >= self.batch_size:
self.flush()

def flush(self):
"""Ejecutar evaluación por lotes"""
if not self.pending:
return

# Preparar entrada por lotes
states = [s for s, _ in self.pending]
features = torch.stack([encode_state(s) for s in states])

# Inferencia por lotes
with torch.no_grad():
outputs = self.network(features)

# Callback de resultados
for i, (_, callback) in enumerate(self.pending):
policy = outputs['policy'][i].numpy()
value = outputs['value'][i].item()
callback(policy, value)

self.pending.clear()

Backpropagation: Actualización de retropropagación

Retropropagación básica

def backpropagate(self, value):
"""Retropropagar desde el nodo hoja hasta la raíz, actualizar estadísticas"""
node = self

while node is not None:
node.visit_count += 1
node.value_sum += value

# Perspectiva alterna: el valor del oponente es opuesto
value = -value

node = node.parent

Importancia de la perspectiva alterna

Perspectiva de Negro: value = +0.6 (favorable para Negro)

Camino de retropropagación:
Nodo hoja (turno de Negro): value_sum += +0.6

Nodo padre (turno de Blanco): value_sum += -0.6 ← Desfavorable para Blanco

Nodo abuelo (turno de Negro): value_sum += +0.6

...

Paralelización: Pérdida virtual

Problema

Cuando múltiples hilos buscan simultáneamente, pueden elegir el mismo nodo:

Thread 1: Selecciona nodo A (Q=0.6, N=100)
Thread 2: Selecciona nodo A (Q=0.6, N=100) ← ¡Repetido!
Thread 3: Selecciona nodo A (Q=0.6, N=100) ← ¡Repetido!

Solución: Pérdida virtual

Al seleccionar un nodo, agregar primero una "pérdida virtual" para que otros hilos no quieran seleccionarlo:

VIRTUAL_LOSS = 3  # Valor de pérdida virtual

def select_with_virtual_loss(self):
"""Selección con pérdida virtual"""
action, child = self.select_child()

# Agregar pérdida virtual
child.visit_count += VIRTUAL_LOSS
child.value_sum -= VIRTUAL_LOSS # Simular pérdida

return action, child

def backpropagate_with_virtual_loss(self, value):
"""Retropropagar eliminando pérdida virtual"""
node = self

while node is not None:
# Eliminar pérdida virtual
node.visit_count -= VIRTUAL_LOSS
node.value_sum += VIRTUAL_LOSS

# Actualización normal
node.visit_count += 1
node.value_sum += value

value = -value
node = node.parent

Efecto

Thread 1: Selecciona nodo A, agrega pérdida virtual
El valor Q de A baja temporalmente

Thread 2: Selecciona nodo B (porque A parece peor)

Thread 3: Selecciona nodo C

→ Diferentes hilos exploran diferentes ramas, mejorando la eficiencia

Implementación completa de búsqueda

class MCTS:
def __init__(self, network, c_puct=1.5, num_simulations=800):
self.network = network
self.c_puct = c_puct
self.num_simulations = num_simulations
self.evaluator = BatchedEvaluator(network)

def search(self, root_state):
"""Ejecutar búsqueda MCTS"""
root = MCTSNode(root_state)

# Expandir nodo raíz
policy, value = self.evaluate(root_state)
legal_moves = get_legal_moves(root_state)
root.expand(policy, legal_moves)

# Ejecutar simulaciones
for _ in range(self.num_simulations):
node = root
path = [node]

# Selection: Bajar por el árbol
while node.is_expanded and node.children:
action, node = node.select_child(self.c_puct)
path.append(node)

# Expansion + Evaluation
if not node.is_expanded:
policy, value = self.evaluate(node.state)
legal_moves = get_legal_moves(node.state)

if legal_moves:
node.expand(policy, legal_moves)

# Backpropagation
for n in reversed(path):
n.visit_count += 1
n.value_sum += value
value = -value

# Elegir la acción con más visitas
best_action = max(root.children.items(),
key=lambda x: x[1].visit_count)[0]

return best_action

def evaluate(self, state):
features = encode_state(state)
features = torch.tensor(features).unsqueeze(0)

with torch.no_grad():
output = self.network(features)

return output['policy'][0].numpy(), output['value'][0].item()

Técnicas avanzadas

Ruido de Dirichlet

Agregar ruido en el nodo raíz durante el entrenamiento para aumentar la exploración:

def add_dirichlet_noise(root, alpha=0.03, epsilon=0.25):
"""Agregar ruido de Dirichlet en el nodo raíz"""
noise = np.random.dirichlet([alpha] * len(root.children))

for i, child in enumerate(root.children.values()):
child.prior = (1 - epsilon) * child.prior + epsilon * noise[i]

Parámetro de temperatura

Controlar la aleatoriedad en la selección de acciones:

def select_action_with_temperature(root, temperature=1.0):
"""Seleccionar acción según visitas y temperatura"""
visits = np.array([c.visit_count for c in root.children.values()])
actions = list(root.children.keys())

if temperature == 0:
# Selección voraz
return actions[np.argmax(visits)]
else:
# Seleccionar según distribución de probabilidad basada en visitas
probs = visits ** (1 / temperature)
probs = probs / probs.sum()
return np.random.choice(actions, p=probs)

Reutilización del árbol

Un nuevo movimiento puede reutilizar el árbol de búsqueda anterior:

def reuse_tree(root, action):
"""Reutilizar subárbol"""
if action in root.children:
new_root = root.children[action]
new_root.parent = None
return new_root
else:
return None # Necesita crear un nuevo árbol

Resumen de optimización de rendimiento

TécnicaEfecto
Evaluación por lotesUtilización de GPU de 10% → 80%+
Pérdida virtualEficiencia multi-hilo mejora 3-5x
Reutilización del árbolReduce arranque en frío, ahorra 30%+ de cálculo
Pool de memoriaReduce overhead de asignación de memoria

Lectura adicional