Pular para o conteúdo principal

Detalhes de Implementação do MCTS

Este artigo analisa em profundidade os detalhes de implementação da Busca em Árvore de Monte Carlo (MCTS) no KataGo, incluindo estruturas de dados, estratégias de seleção e técnicas de paralelização.


Revisão das Quatro Etapas do MCTS


Estrutura de Dados do Nó

Dados Principais

Cada nó do MCTS precisa armazenar:

class MCTSNode:
def __init__(self, state, parent=None, prior=0.0):
# Informações básicas
self.state = state # Estado do tabuleiro
self.parent = parent # Nó pai
self.children = {} # Dicionário de filhos {ação: nó}
self.action = None # Ação que levou a este nó

# Informações estatísticas
self.visit_count = 0 # N(s): Contagem de visitas
self.value_sum = 0.0 # W(s): Soma dos valores
self.prior = prior # P(s,a): Probabilidade a priori

# Para busca paralela
self.virtual_loss = 0 # Perda virtual
self.is_expanded = False # Se já foi expandido

@property
def value(self):
"""Q(s) = W(s) / N(s)"""
if self.visit_count == 0:
return 0.0
return self.value_sum / self.visit_count

Otimização de Memória

O KataGo usa várias técnicas para reduzir o uso de memória:

# Usando arrays numpy em vez de dict Python
class OptimizedNode:
__slots__ = ['visit_count', 'value_sum', 'prior', 'children_indices']

def __init__(self):
self.visit_count = np.int32(0)
self.value_sum = np.float32(0.0)
self.prior = np.float32(0.0)
self.children_indices = None # Alocação adiada

Selection: Seleção PUCT

Fórmula PUCT

Pontuação de seleção = Q(s,a) + U(s,a)

Onde:
Q(s,a) = W(s,a) / N(s,a) # Valor médio
U(s,a) = c_puct × P(s,a) × √(N(s)) / (1 + N(s,a)) # Termo de exploração

Descrição dos Parâmetros

SímboloSignificadoValor típico
Q(s,a)Valor médio da ação a[-1, +1]
P(s,a)Probabilidade a priori da rede neural[0, 1]
N(s)Contagem de visitas do nó paiInteiro
N(s,a)Contagem de visitas da ação aInteiro
c_puctConstante de exploração1.0 ~ 2.5

Implementação

def select_child(self, c_puct=1.5):
"""Seleciona o nó filho com maior pontuação PUCT"""
best_score = -float('inf')
best_action = None
best_child = None

# Raiz quadrada das visitas do nó pai
sqrt_parent_visits = math.sqrt(self.visit_count)

for action, child in self.children.items():
# Valor Q (valor médio)
if child.visit_count > 0:
q_value = child.value_sum / child.visit_count
else:
q_value = 0.0

# Valor U (termo de exploração)
u_value = c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count)

# Pontuação total
score = q_value + u_value

if score > best_score:
best_score = score
best_action = action
best_child = child

return best_action, best_child

Equilíbrio entre Exploração e Aproveitamento

Fase inicial: N(s,a) pequeno
├── U(s,a) grande → Foco na exploração
└── Ações com alta probabilidade a priori são exploradas primeiro

Fase posterior: N(s,a) grande
├── U(s,a) pequeno → Foco no aproveitamento
└── Q(s,a) domina, seleciona ações conhecidamente boas

Expansion: Expansão de Nós

Condições de Expansão

Ao chegar em um nó folha, usa a rede neural para expandir:

def expand(self, policy_probs, legal_moves):
"""Expande o nó, criando nós filhos para todas as jogadas legais"""
for action in legal_moves:
if action not in self.children:
prior = policy_probs[action] # Probabilidade prevista pela rede neural
child_state = self.state.play(action)
self.children[action] = MCTSNode(
state=child_state,
parent=self,
prior=prior
)

self.is_expanded = True

Filtragem de Jogadas Legais

def get_legal_moves(state):
"""Obtém todas as jogadas legais"""
legal = []
for i in range(361):
x, y = i // 19, i % 19
if state.is_legal(x, y):
legal.append(i)

# Adiciona pass
legal.append(361)

return legal

Evaluation: Avaliação pela Rede Neural

Avaliação Única

def evaluate(self, state):
"""Avalia a posição usando a rede neural"""
# Codifica recursos de entrada
features = encode_state(state) # (22, 19, 19)
features = torch.tensor(features).unsqueeze(0) # (1, 22, 19, 19)

# Inferência da rede neural
with torch.no_grad():
output = self.network(features)

policy = output['policy'][0].numpy() # (362,)
value = output['value'][0].item() # escalar

return policy, value

Avaliação em Lote (Otimização Crucial)

A GPU é mais eficiente com inferência em lote:

class BatchedEvaluator:
def __init__(self, network, batch_size=8):
self.network = network
self.batch_size = batch_size
self.pending = [] # Lista de (estado, callback) aguardando avaliação

def request_evaluation(self, state, callback):
"""Solicita avaliação, executa automaticamente quando o lote está cheio"""
self.pending.append((state, callback))

if len(self.pending) >= self.batch_size:
self.flush()

def flush(self):
"""Executa avaliação em lote"""
if not self.pending:
return

# Prepara entrada em lote
states = [s for s, _ in self.pending]
features = torch.stack([encode_state(s) for s in states])

# Inferência em lote
with torch.no_grad():
outputs = self.network(features)

# Retorna resultados via callback
for i, (_, callback) in enumerate(self.pending):
policy = outputs['policy'][i].numpy()
value = outputs['value'][i].item()
callback(policy, value)

self.pending.clear()

Backpropagation: Atualização por Retropropagação

Retropropagação Básica

def backpropagate(self, value):
"""Retropropaga do nó folha até a raiz, atualizando estatísticas"""
node = self

while node is not None:
node.visit_count += 1
node.value_sum += value

# Alternância de perspectiva: o valor para o oponente é oposto
value = -value

node = node.parent

Importância da Alternância de Perspectiva

Perspectiva do Preto: value = +0.6 (Preto em vantagem)

Caminho de retropropagação:
Nó folha (Preto joga): value_sum += +0.6

Nó pai (Branco joga): value_sum += -0.6 ← Desfavorável para Branco

Nó avô (Preto joga): value_sum += +0.6

...

Paralelização: Perda Virtual

O Problema

Quando múltiplas threads buscam simultaneamente, podem todas selecionar o mesmo nó:

Thread 1: Seleciona nó A (Q=0.6, N=100)
Thread 2: Seleciona nó A (Q=0.6, N=100) ← Repetido!
Thread 3: Seleciona nó A (Q=0.6, N=100) ← Repetido!

Solução: Perda Virtual

Ao selecionar um nó, primeiro adiciona "perda virtual" para que outras threads não queiram selecioná-lo:

VIRTUAL_LOSS = 3  # Valor da perda virtual

def select_with_virtual_loss(self):
"""Seleção com perda virtual"""
action, child = self.select_child()

# Adiciona perda virtual
child.visit_count += VIRTUAL_LOSS
child.value_sum -= VIRTUAL_LOSS # Finge que perdeu

return action, child

def backpropagate_with_virtual_loss(self, value):
"""Retropropagação removendo perda virtual"""
node = self

while node is not None:
# Remove perda virtual
node.visit_count -= VIRTUAL_LOSS
node.value_sum += VIRTUAL_LOSS

# Atualização normal
node.visit_count += 1
node.value_sum += value

value = -value
node = node.parent

Efeito

Thread 1: Seleciona nó A, adiciona perda virtual
Valor Q de A cai temporariamente

Thread 2: Seleciona nó B (porque A parece pior agora)

Thread 3: Seleciona nó C

→ Diferentes threads exploram diferentes ramos, aumentando eficiência

Implementação Completa da Busca

class MCTS:
def __init__(self, network, c_puct=1.5, num_simulations=800):
self.network = network
self.c_puct = c_puct
self.num_simulations = num_simulations
self.evaluator = BatchedEvaluator(network)

def search(self, root_state):
"""Executa busca MCTS"""
root = MCTSNode(root_state)

# Expande nó raiz
policy, value = self.evaluate(root_state)
legal_moves = get_legal_moves(root_state)
root.expand(policy, legal_moves)

# Executa simulações
for _ in range(self.num_simulations):
node = root
path = [node]

# Selection: Desce pela árvore
while node.is_expanded and node.children:
action, node = node.select_child(self.c_puct)
path.append(node)

# Expansion + Evaluation
if not node.is_expanded:
policy, value = self.evaluate(node.state)
legal_moves = get_legal_moves(node.state)

if legal_moves:
node.expand(policy, legal_moves)

# Backpropagation
for n in reversed(path):
n.visit_count += 1
n.value_sum += value
value = -value

# Seleciona a ação com mais visitas
best_action = max(root.children.items(),
key=lambda x: x[1].visit_count)[0]

return best_action

def evaluate(self, state):
features = encode_state(state)
features = torch.tensor(features).unsqueeze(0)

with torch.no_grad():
output = self.network(features)

return output['policy'][0].numpy(), output['value'][0].item()

Técnicas Avançadas

Ruído de Dirichlet

Adiciona ruído no nó raiz durante o treinamento para aumentar a exploração:

def add_dirichlet_noise(root, alpha=0.03, epsilon=0.25):
"""Adiciona ruído de Dirichlet ao nó raiz"""
noise = np.random.dirichlet([alpha] * len(root.children))

for i, child in enumerate(root.children.values()):
child.prior = (1 - epsilon) * child.prior + epsilon * noise[i]

Parâmetro de Temperatura

Controla a aleatoriedade na seleção de ações:

def select_action_with_temperature(root, temperature=1.0):
"""Seleciona ação baseado na contagem de visitas e temperatura"""
visits = np.array([c.visit_count for c in root.children.values()])
actions = list(root.children.keys())

if temperature == 0:
# Seleção gulosa
return actions[np.argmax(visits)]
else:
# Seleciona de acordo com distribuição de probabilidade das visitas
probs = visits ** (1 / temperature)
probs = probs / probs.sum()
return np.random.choice(actions, p=probs)

Reutilização de Árvore

A nova jogada pode reutilizar a árvore de busca anterior:

def reuse_tree(root, action):
"""Reutiliza subárvore"""
if action in root.children:
new_root = root.children[action]
new_root.parent = None
return new_root
else:
return None # Precisa criar nova árvore

Resumo de Otimizações de Desempenho

TécnicaEfeito
Avaliação em loteUtilização da GPU de 10% → 80%+
Perda virtualEficiência multi-thread aumenta 3-5x
Reutilização de árvoreReduz cold start, economiza 30%+ de computação
Pool de memóriaReduz overhead de alocação de memória

Leitura Adicional