Saltar al contenido principal

Arquitectura de redes neuronales en detalle

Este artículo analiza en profundidad la arquitectura completa de la red neuronal de KataGo, desde la codificación de características de entrada hasta el diseño de salida múltiple.


Visión general de la arquitectura

KataGo utiliza un diseño de red neuronal única con salida múltiple:


Codificación de características de entrada

Visión general de planos de características

KataGo utiliza 22 planos de características (19×19×22), cada plano es una matriz de 19×19:

PlanoContenidoDescripción
0Piedras propias1 = hay piedra propia, 0 = no
1Piedras del oponente1 = hay piedra del oponente, 0 = no
2Puntos vacíos1 = punto vacío, 0 = hay piedra
3-10Estado históricoCambios del tablero en los últimos 8 movimientos
11Punto de ko1 = este punto es ko prohibido, 0 = jugable
12-17Codificación de libertadesCadenas con 1, 2, 3... libertades
18-21Codificación de reglasReglas chinas/japonesas, komi, etc.

Apilamiento de estado histórico

Para que la red neuronal entienda los cambios dinámicos de la posición, KataGo apila los estados del tablero de los últimos 8 movimientos:

# Codificación de estado histórico (concepto)
def encode_history(game_history, current_player):
features = []

for t in range(8): # Últimos 8 movimientos
if t < len(game_history):
board = game_history[-(t+1)]
# Codificar piedras propias/oponente en ese momento
features.append(encode_board(board, current_player))
else:
# Historia insuficiente, rellenar con ceros
features.append(np.zeros((19, 19)))

return np.stack(features, axis=0)

Codificación de reglas

KataGo soporta múltiples reglas, comunicándolas a la red neuronal a través de planos de características:

# Codificación de reglas (concepto)
def encode_rules(rules, komi):
rule_features = np.zeros((4, 19, 19))

# Tipo de regla (one-hot)
if rules == "chinese":
rule_features[0] = 1.0
elif rules == "japanese":
rule_features[1] = 1.0

# Komi normalizado
normalized_komi = komi / 15.0 # Normalizado a [-1, 1]
rule_features[2] = normalized_komi

# Jugador actual
rule_features[3] = 1.0 if current_player == BLACK else 0.0

return rule_features

Red troncal: Torre residual

Estructura del bloque residual

KataGo utiliza la estructura Pre-activation ResNet:

Ejemplo de código

class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.bn1 = nn.BatchNorm2d(channels)
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

def forward(self, x):
residual = x

out = self.bn1(x)
out = F.relu(out)
out = self.conv1(out)

out = self.bn2(out)
out = F.relu(out)
out = self.conv2(out)

return out + residual # Conexión residual

Capa de pooling global

Una de las innovaciones clave de KataGo: agregar pooling global en los bloques residuales, permitiendo que la red vea información global:

class GlobalPoolingBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
self.fc = nn.Linear(channels, channels)

def forward(self, x):
# Ruta local
local = self.conv(x)

# Ruta global
global_pool = x.mean(dim=[2, 3]) # Pooling promedio global
global_fc = self.fc(global_pool)
global_broadcast = global_fc.unsqueeze(2).unsqueeze(3)
global_broadcast = global_broadcast.expand(-1, -1, 19, 19)

# Fusión
return local + global_broadcast

¿Por qué se necesita pooling global?

Las convoluciones tradicionales solo ven localmente (campo receptivo 3×3), incluso apilando muchas capas, la percepción de información global sigue siendo limitada. El pooling global permite que la red "vea" directamente:

  • La diferencia de cantidad de piedras en todo el tablero
  • La distribución global de influencia
  • El juicio general de la posición

Diseño de cabezas de salida

Policy Head (Cabeza de política)

Produce la probabilidad de jugar en cada posición:

class PolicyHead(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, 2, 1) # Conv. 1×1
self.bn = nn.BatchNorm2d(2)
self.fc = nn.Linear(2 * 19 * 19, 362) # 361 + pass

def forward(self, x):
out = F.relu(self.bn(self.conv(x)))
out = out.view(out.size(0), -1)
out = self.fc(out)
return F.softmax(out, dim=1) # Distribución de probabilidad

Formato de salida: Vector de 362 dimensiones

  • Índices 0-360: Probabilidad de jugar en las 361 posiciones del tablero
  • Índice 361: Probabilidad de pasar

Value Head (Cabeza de valor)

Produce la tasa de victoria de la posición actual:

class ValueHead(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, 1, 1)
self.bn = nn.BatchNorm2d(1)
self.fc1 = nn.Linear(19 * 19, 256)
self.fc2 = nn.Linear(256, 1)

def forward(self, x):
out = F.relu(self.bn(self.conv(x)))
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = torch.tanh(self.fc2(out)) # Salida de -1 a +1
return out

Formato de salida: Valor único [-1, +1]

  • +1: Victoria segura propia
  • -1: Victoria segura del oponente
  • 0: Posición equilibrada

Score Head (Cabeza de puntuación)

Exclusivo de KataGo, predice la diferencia final de puntos:

class ScoreHead(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, 1, 1)
self.bn = nn.BatchNorm2d(1)
self.fc1 = nn.Linear(19 * 19, 256)
self.fc2 = nn.Linear(256, 1)

def forward(self, x):
out = F.relu(self.bn(self.conv(x)))
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = self.fc2(out) # Salida sin restricción
return out

Formato de salida: Valor único (puntos)

  • Positivo: Ventaja propia
  • Negativo: Ventaja del oponente

Ownership Head (Cabeza de propiedad)

Predice la propiedad final de cada punto:

class OwnershipHead(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, 32, 1)
self.bn = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 1, 1)

def forward(self, x):
out = F.relu(self.bn(self.conv1(x)))
out = torch.tanh(self.conv2(out)) # Cada punto de -1 a +1
return out.view(out.size(0), -1) # Aplanar a 361

Formato de salida: Vector de 361 dimensiones, cada valor en [-1, +1]

  • +1: Ese punto pertenece al territorio propio
  • -1: Ese punto pertenece al territorio del oponente
  • 0: Zona neutral o disputada

Diferencias con AlphaZero

AspectoAlphaZeroKataGo
Cabezas de salida2 (Policy + Value)4 (+ Score + Ownership)
Pooling globalNo
Características de entrada17 planos22 planos (con codificación de reglas)
Bloques residualesResNet estándarPre-activation + pooling global
Soporte multi-reglasNo (a través de codificación de características)

Escala del modelo

KataGo ofrece modelos de diferentes escalas:

ModeloBloques res.CanalesParámetrosEscenario de uso
b10c12810128~5MCPU, pruebas rápidas
b18c38418384~75MGPU general
b40c25640256~95MGPU de gama alta
b60c32060320~200MGPU de alta gama

Convención de nomenclatura: b{número de bloques residuales}c{número de canales}


Implementación completa de la red

class KataGoNetwork(nn.Module):
def __init__(self, num_blocks=18, channels=384):
super().__init__()

# Convolución inicial
self.initial_conv = nn.Conv2d(22, channels, 3, padding=1)
self.initial_bn = nn.BatchNorm2d(channels)

# Torre residual
self.residual_blocks = nn.ModuleList([
ResidualBlock(channels) for _ in range(num_blocks)
])

# Bloques de pooling global (insertar uno cada varios bloques residuales)
self.global_pooling_blocks = nn.ModuleList([
GlobalPoolingBlock(channels) for _ in range(num_blocks // 6)
])

# Cabezas de salida
self.policy_head = PolicyHead(channels)
self.value_head = ValueHead(channels)
self.score_head = ScoreHead(channels)
self.ownership_head = OwnershipHead(channels)

def forward(self, x):
# Convolución inicial
out = F.relu(self.initial_bn(self.initial_conv(x)))

# Torre residual
gp_idx = 0
for i, block in enumerate(self.residual_blocks):
out = block(out)

# Insertar pooling global después de cada 6 bloques residuales
if (i + 1) % 6 == 0 and gp_idx < len(self.global_pooling_blocks):
out = self.global_pooling_blocks[gp_idx](out)
gp_idx += 1

# Cabezas de salida
policy = self.policy_head(out)
value = self.value_head(out)
score = self.score_head(out)
ownership = self.ownership_head(out)

return {
'policy': policy,
'value': value,
'score': score,
'ownership': ownership
}

Lectura adicional