Dataset: angeluriot/chess_games
Recurrent transformer that predicts the best chess move from a board state given as FEN. It is trained from scratch and uses separate from-square and to-square prediction heads. A single shared transformer block is applied 8 times, and an iteration embedding marks each pass (in the style of CORnet-S and the Universal Transformer). The output is always a legal move (zero fallbacks).
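The following is a minimal PyTorch sketch of the recurrence described above. One `nn.TransformerEncoderLayer` has its weights reused on every pass, and a learned iteration embedding is added before each pass. Two linear heads then produce one from-logit and one to-logit per square. This is not the repo's `RecurrentTransformer`. The class name `RecurrentSketch`, the 13-token piece vocabulary, the embedding scheme and the default sizes are assumptions. The sketch also reads only the 64 board tokens, whereas the real model additionally takes turn, castling and en-passant inputs.

```python
import torch
import torch.nn as nn


class RecurrentSketch(nn.Module):
    """Hypothetical minimal recurrent transformer: one shared block, N passes."""

    def __init__(self, vocab_size=13, d_model=64, nhead=4, d_ff=128,
                 num_iterations=8, dropout=0.0):
        super().__init__()
        self.num_iterations = num_iterations
        self.piece_emb = nn.Embedding(vocab_size, d_model)      # piece id per square (assumed vocab)
        self.pos_emb = nn.Parameter(torch.zeros(64, d_model))   # learned per-square embedding
        self.iter_emb = nn.Embedding(num_iterations, d_model)   # tells the block which pass it is on
        # A single block whose weights are reused on every iteration.
        self.block = nn.TransformerEncoderLayer(
            d_model, nhead, d_ff, dropout, batch_first=True)
        self.from_head = nn.Linear(d_model, 1)  # one "from" logit per square
        self.to_head = nn.Linear(d_model, 1)    # one "to" logit per square

    def forward(self, board):
        # board: (B, 64) integer piece ids -> h: (B, 64, d_model)
        h = self.piece_emb(board) + self.pos_emb
        for i in range(self.num_iterations):
            # Same block every time; only the added iteration embedding changes.
            h = self.block(h + self.iter_emb.weight[i])
        return self.from_head(h).squeeze(-1), self.to_head(h).squeeze(-1)


model = RecurrentSketch().eval()
board = torch.randint(0, 13, (2, 64))  # random stand-in for two encoded positions
with torch.no_grad():
    from_logits, to_logits = model(board)
print(from_logits.shape, to_logits.shape)  # torch.Size([2, 64]) torch.Size([2, 64])
```

Because the block is shared, each extra iteration adds compute. In this sketch it adds only one `d_model`-sized vector of parameters (its iteration embedding).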
Part of: INFOMTALC 2025/26 (Utrecht University, MSc Applied Data Science) — chess tournament assignment.
At inference, every legal move is scored as from_logits[src] + to_logits[dst], and the highest-scoring move is returned in UCI notation (e.g. e2e4). There is no sampling, so the output is deterministic for a given position.

Requires the chess_exam package, which provides the Player base class and Game. Install it, then use the model through the tournament player class from the assignment repo:
```python
# Install the tournament framework first:
# git clone https://github.com/bylinina/chess_exam.git && cd chess_exam && pip install -e .
from chess_tournament import Game, RandomPlayer
from player import TransformerPlayer  # from the assignment repo that contains model.py + player.py

tp = TransformerPlayer("RecurrentTransformer")  # downloads this model from HF on first use
rp = RandomPlayer("Random")
game = Game(tp, rp, max_half_moves=200)
outcome, scores, fallbacks = game.play()
print(outcome, fallbacks)
```
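This is a stdlib-only sketch of the move-selection rule stated earlier: each legal move is scored as from_logits[src] + to_logits[dst], and the highest score wins. The helper names `square_index` and `pick_move` are hypothetical, and so is the square numbering (a1 = 0 through h8 = 63). The list of legal UCI moves is taken as given; with python-chess, for example, it could come from `[m.uci() for m in board.legal_moves]`. Restricting the argmax to legal moves is what makes an illegal output impossible.

```python
FILES = "abcdefgh"


def square_index(name):
    """Map 'a1'..'h8' to 0..63 (a1=0, b1=1, ..., h8=63); assumed convention."""
    return (int(name[1]) - 1) * 8 + FILES.index(name[0])


def pick_move(from_logits, to_logits, legal_uci_moves):
    """Return the legal move maximising from_logits[src] + to_logits[dst]."""
    def score(move):
        src, dst = square_index(move[:2]), square_index(move[2:4])
        return from_logits[src] + to_logits[dst]
    # max() keeps the first of equally scored moves. This matters for
    # promotions (e7e8q/r/b/n), which this scoring cannot tell apart.
    return max(legal_uci_moves, key=score)


# Toy logits: "from e2" is strongly preferred, and "to e4" beats "to e3".
from_logits = [0.0] * 64
to_logits = [0.0] * 64
from_logits[square_index("e2")] = 2.0
to_logits[square_index("e4")] = 1.5
to_logits[square_index("e3")] = 1.0
legal = ["e2e3", "e2e4", "g1f3"]
print(pick_move(from_logits, to_logits, legal))  # e2e4
```

How the actual player breaks ties between promotion pieces is not stated here. The sketch simply returns whichever tied move comes first.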
Loading only the PyTorch state dict (no player):
```python
import json

import torch
from huggingface_hub import hf_hub_download

from model import RecurrentTransformer  # needs model.py from the assignment repo

config_path = hf_hub_download("Izzent/recurrent-transformer-chess", "config.json")
weights_path = hf_hub_download("Izzent/recurrent-transformer-chess", "model.pt")

with open(config_path) as f:
    config = json.load(f)

model = RecurrentTransformer.from_config(config)
state = torch.load(weights_path, map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()

# Forward pass expects a batch dict: board (B,64), turn (B,1), castling (B,4), ep (B,1)
# Use BoardTokenizer.encode(fen) to get these from a FEN string.
```
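`BoardTokenizer.encode(fen)` is the supported way to build the forward-pass input. Purely to illustrate the four fields and their shapes (board (B,64), turn (B,1), castling (B,4), ep (B,1)), the sketch below encodes a FEN by hand as a batch of one. Several choices are assumptions: the name `encode_fen`, the piece vocabulary (0 = empty, 1 to 12 = `PNBRQKpnbrqk`), the a1 = 0 square order, turn = 1 for White, and ep = 64 for "no en-passant square". Its output will therefore not necessarily match `BoardTokenizer` and should not be fed to the trained weights.

```python
PIECES = "PNBRQKpnbrqk"  # hypothetical vocab: 0 = empty, 1..12 = these pieces
FILES = "abcdefgh"


def encode_fen(fen):
    """Hypothetical FEN -> batch-of-one dict with the shapes listed above."""
    placement, side, castling, ep = fen.split()[:4]
    board = [0] * 64
    for rank_idx, row in enumerate(placement.split("/")):  # FEN starts at rank 8
        rank = 7 - rank_idx
        file = 0
        for ch in row:
            if ch.isdigit():
                file += int(ch)  # run of empty squares
            else:
                board[rank * 8 + file] = PIECES.index(ch) + 1
                file += 1
    ep_square = 64 if ep == "-" else (int(ep[1]) - 1) * 8 + FILES.index(ep[0])
    return {
        "board": [board],                                    # (1, 64)
        "turn": [[1 if side == "w" else 0]],                 # (1, 1)
        "castling": [[int(c in castling) for c in "KQkq"]],  # (1, 4)
        "ep": [[ep_square]],                                 # (1, 1); 64 = none
    }


# Position after 1. e4
batch = encode_fen("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1")
print(batch["turn"], batch["castling"], batch["ep"])  # [[0]] [[1, 1, 1, 1]] [[20]]
```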
Files:
- config.json: model config (d_model, nhead, d_ff, num_iterations, dropout).
- model.pt: PyTorch state dict (weights only).

All rights reserved.