# 5_compare_gcn_pools_reddit_nomask.py (PROTEINS graph classification)
# Compares:
#   - Plain GCN
#   - GCN + LRMC hard clustering (contract via S^T A S, S^T X)
#   - GCN + DiffPool (1 layer, dense per-graph)
#   - GCN + gPool (TopKPooling)
#
# Also provides utilities to export PROTEINS into per-graph edgelists (1-based)
# and to consume LRMC seeds dumped per graph from Java.
#
# Usage examples:
#
# 1) Export per-graph edge lists (for Java LRMC seeder)
#    python 5_compare_gcn_pools_reddit_nomask.py --export_edgelists --out_dir ./proteins_edgelists
#
# 2) Train/eval with precomputed LRMC seeds (JSON files in seeds_dir)
#    python 5_compare_gcn_pools_reddit_nomask.py --seeds_dir ./proteins_seeds --pool_ratio 0.5 --epochs 200
#
# 3) Train/eval without LRMC (still runs Plain/DiffPool/gPool)
#    python 5_compare_gcn_pools_reddit_nomask.py --pool_ratio 0.5 --epochs 200
#

import argparse
import json
import math
import os
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, TopKPooling
from torch_geometric.utils import (
    to_dense_adj,
    add_self_loops,
    remove_self_loops,
    coalesce,
    subgraph,
    degree,
)

Device = torch.device


# -----------------------------
# Utilities
# -----------------------------
class DatasetWithGID(torch.utils.data.Dataset):
    """Wrap a PyG dataset and attach each item's dataset index as ``data.gid``.

    ``gid`` is a 1-element long tensor so that batching keeps one id per graph;
    consumers in this file flatten it with ``data.gid.view(-1)``.
    """

    def __init__(self, base):
        self.base = base
        # Expose common attributes for convenience
        for attr in ("num_classes", "num_features"):
            if hasattr(base, attr):
                setattr(self, attr, getattr(base, attr))

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        data = self.base[idx]
        data.gid = torch.tensor([idx], dtype=torch.long)
        return data


def set_seed(seed: int):
    """Seed Python's and torch's RNGs and request deterministic cuDNN kernels."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def export_proteins_edgelists(root: Path, out_dir: Path):
    """Write one ``graph_XXXXXX.txt`` edge list per PROTEINS graph into *out_dir*.

    Each line is ``"u v"`` with 1-based node ids local to that graph; every
    column of ``edge_index`` is written as-is (no de-duplication of directions).
    """
    ds = TUDataset(root=str(root), name='PROTEINS')
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, data in enumerate(ds):
        # 1-based ids per graph
        with (out_dir / f"graph_{i:06d}.txt").open('w') as f:
            f.writelines(f"{u+1} {v+1}\n" for u, v in data.edge_index.t().tolist())
    print(f"[export] Wrote {len(ds)} edge lists to {out_dir}")


def load_lrmc_seeds_dir(seeds_dir: Optional[Path]) -> Optional[List[Optional[List[List[int]]]]]:
    """Load per-graph LRMC clusters from a directory of JSON files.

    Files are expected to be named ``graph_000000.json`` or ``<i>.json`` with
    structure ``{"clusters": [{"members": [...]}, ...]}`` (``"nodes"`` is
    accepted as an alternative key to ``"members"``).

    Returns:
        A list indexed by graph id whose entries are lists of clusters (each a
        list of node indices), with ``None`` for graph ids that have no file;
        or ``None`` when *seeds_dir* is ``None`` or contains no usable file.
    """
    if seeds_dir is None:
        return None
    by_graph: Dict[int, List[List[int]]] = {}
    for p in sorted(seeds_dir.glob("*.json")):
        stem = p.stem
        try:
            gi = int(stem.split('_')[-1]) if stem.startswith("graph_") else int(stem)
        except ValueError:
            # File name does not encode a graph index; not a seed file.
            continue
        if gi < 0:
            # A negative id cannot address a slot in the output list.
            continue
        obj = json.loads(p.read_text())
        clusters = []
        for c in obj.get("clusters", []):
            mem = c.get("members") or c.get("nodes") or []
            clusters.append([int(x) for x in mem])
        by_graph[gi] = clusters
    if not by_graph:
        print(f"[warn] no seed jsons found in {seeds_dir}")
        return None
    # Convert to list ordered by graph idx with possible Nones for missing
    out: List[Optional[List[List[int]]]] = [None] * (max(by_graph) + 1)
    for i, clusters in by_graph.items():
        out[i] = clusters
    return out


def _degree_as_feature(data):
    """Dataset transform: use the out-degree of each node as its only feature."""
    data.x = degree(data.edge_index[0], num_nodes=data.num_nodes).view(-1, 1)
    return data


# -----------------------------
# Models
# -----------------------------
class PlainGCN(nn.Module):
    """Two GCN layers, global mean pooling, then a linear classifier."""

    def __init__(self, in_dim: int, hidden: int, num_classes: int):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
        self.conv2 = GCNConv(hidden, hidden, add_self_loops=True, normalize=True)
        self.lin = nn.Linear(hidden, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)


class gPoolNet(nn.Module):
    """GCN -> TopKPooling -> GCN -> global mean pooling -> linear classifier."""

    def __init__(self, in_dim: int, hidden: int, num_classes: int, ratio: float):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)
        self.pool1 = TopKPooling(hidden, ratio=ratio)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)


class DiffPoolOneShot(nn.Module):
    """One DiffPool layer applied per graph (dense inside each graph).

    The assignment head has a fixed width of ``max(1, round(ratio * 50))``
    columns; at runtime only the first ``K_i = max(1, round(ratio * n_i))``
    columns are used, so the effective number of clusters per graph is capped
    by that fixed width.
    """

    def __init__(self, in_dim: int, hidden: int, num_classes: int, ratio: float):
        super().__init__()
        self.gnn_embed = GCNConv(in_dim, hidden)
        # Fixed upper bound on K; sliced down per graph in forward().
        self.assign = nn.Linear(hidden, max(1, int(round(ratio * 50))))
        self.post_conv = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, num_classes)
        self.ratio = ratio

    def forward(self, data):
        # Process each graph in the batch independently (avoids padding logic).
        x, edge_index, batch = data.x, data.edge_index, data.batch
        out = []
        for gi in range(int(data.num_graphs)):
            subset_idx = torch.where(batch == gi)[0]
            if subset_idx.numel() == 0:
                # Keep one row per graph so logits stay aligned with data.y.
                out.append(x.new_zeros(1, self.lin.in_features))
                continue
            # Relabel edge_index to [0..n_i-1] for this subgraph
            eidx_i, _ = subgraph(subset_idx, edge_index, relabel_nodes=True, num_nodes=x.size(0))
            xi = x[subset_idx]

            # embed nodes
            zi = F.relu(self.gnn_embed(xi, eidx_i))
            n_i = zi.size(0)
            K_i = max(1, int(round(self.ratio * n_i)))
            Si = F.softmax(self.assign(zi)[:, :K_i], dim=-1)  # [n_i, <=K_i]
            Ai = to_dense_adj(eidx_i, max_num_nodes=n_i)[0]   # [n_i, n_i]
            Zi = Si.t() @ zi                                   # [K, hidden]
            Ai_pooled = Si.t() @ Ai @ Si                       # [K, K]

            # Convert Ai_pooled to sparse for GCNConv
            rows, cols = (Ai_pooled > 0).nonzero(as_tuple=True)
            if rows.numel() == 0:
                # Fallback to identity if empty. Use the real row count of Zi:
                # K_i may exceed the width of the assignment head.
                loops = torch.arange(Zi.size(0), device=Zi.device)
                edge_index_coarse = torch.stack([loops, loops])
                edge_weight = None
            else:
                edge_index_coarse = torch.stack([rows, cols], dim=0)
                edge_weight = Ai_pooled[rows, cols]

            Zi2 = F.relu(self.post_conv(Zi, edge_index_coarse, edge_weight))
            out.append(Zi2.mean(dim=0, keepdim=True))

        Xg = torch.cat(out, dim=0) if out else torch.zeros(0, self.lin.in_features, device=x.device)
        return self.lin(Xg)


class LRMCPoolNet(nn.Module):
    """Hard cluster assignment from LRMC seeds per graph.

    Falls back to plain mean pooling of the node embeddings for any graph whose
    seeds are missing, empty, or inconsistent with the graph's node count.
    """

    def __init__(self, in_dim: int, hidden: int, num_classes: int, pool_ratio: float,
                 seeds_by_graph: Optional[List[Optional[List[List[int]]]]] = None):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.post_conv = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, num_classes)
        self.pool_ratio = pool_ratio
        self.seeds_by_graph = seeds_by_graph
        self._dbg_printed = False   # one-time debug line already emitted?
        self._bad_warned = set()    # graph ids already warned about

    @staticmethod
    def _normalize_clusters(clusters: List[List[int]], n_i: int) -> Optional[List[List[int]]]:
        """Return sorted, de-duplicated, 0-based clusters for a graph of n_i nodes.

        The index base is decided once per graph: if every member lies in
        ``[1, n_i]`` the seeds are treated as 1-based and shifted down by one
        (heuristic; 0-based seeds that never mention node 0 are
        indistinguishable). Returns ``[]`` when there are no non-empty clusters
        and ``None`` when any index falls outside the graph.
        """
        non_empty = [[int(m) for m in mem] for mem in clusters if mem]
        if not non_empty:
            return []
        lo = min(min(mem) for mem in non_empty)
        hi = max(max(mem) for mem in non_empty)
        shift = 1 if (lo >= 1 and hi <= n_i) else 0
        if lo - shift < 0 or hi - shift >= n_i:
            return None
        return [sorted({m - shift for m in mem}) for mem in non_empty]

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        z = F.relu(self.conv1(x, edge_index))
        z = F.relu(self.conv2(z, edge_index))

        # global ids for graphs in this batch, one per graph
        gids = data.gid.view(-1)
        pooled = []
        for gi_local, gi_global in enumerate(gids.tolist()):
            idx_nodes = torch.where(batch == gi_local)[0]
            zi = z[idx_nodes]
            n_i = zi.size(0)

            clusters = None
            if self.seeds_by_graph is not None and 0 <= gi_global < len(self.seeds_by_graph):
                clusters = self.seeds_by_graph[gi_global]
            if not self._dbg_printed:
                self._dbg_printed = True
                print(f"[LRMC] used_seeds={clusters is not None} K={len(clusters) if clusters else 0} n={n_i}")

            if n_i == 0:
                # Mean over zero nodes would be NaN; emit a zero row instead.
                pooled.append(z.new_zeros(1, z.size(1)))
                continue

            normalized = self._normalize_clusters(clusters, n_i) if clusters else []
            if normalized is None:
                # fallback if seeds look inconsistent with this graph
                if gi_global not in self._bad_warned:
                    print(
                        f"[LRMC] Warning: seed indices out of range for graph gid={gi_global} (n={n_i}); falling back to mean pooling for this graph.")
                    self._bad_warned.add(gi_global)
                pooled.append(zi.mean(dim=0, keepdim=True))
                continue
            if not normalized:
                # Fallback: identity pooling (no change)
                pooled.append(zi.mean(dim=0, keepdim=True))
                continue

            # target K and simple coarsen: keep the largest clusters
            targetK = max(1, int(round(self.pool_ratio * n_i)))
            if len(normalized) > targetK:
                normalized = sorted(normalized, key=len, reverse=True)[:targetK]

            # Cluster id per node; with overlapping clusters the later one wins.
            cid = torch.full((n_i,), -1, dtype=torch.long, device=z.device)
            for k, mem in enumerate(normalized):
                cid[torch.tensor(mem, dtype=torch.long, device=z.device)] = k
            num_clusters = len(normalized)

            # give stragglers their own clusters up to targetK
            for u in torch.where(cid < 0)[0].tolist():
                cid[u] = num_clusters
                num_clusters += 1
                if num_clusters >= targetK:
                    break
            if (cid < 0).any():
                # any still unassigned -> dump to last cluster
                cid[cid < 0] = num_clusters - 1

            # mean of node embeddings within each cluster
            Zi = z.new_zeros(num_clusters, z.size(1))
            Zi.index_add_(0, cid, zi)
            counts = torch.bincount(cid, minlength=num_clusters).clamp(min=1).view(-1, 1).to(Zi.dtype)
            Zi = Zi / counts

            # Build coarse edge_index via cluster assignment of both endpoints
            eidx_i, _ = subgraph(idx_nodes, edge_index, relabel_nodes=True, num_nodes=z.size(0))
            e_coarse = coalesce(torch.stack([cid[eidx_i[0]], cid[eidx_i[1]]], dim=0),
                                num_nodes=num_clusters)

            # Post-pool convolution on (Zi, e_coarse), then one vector per graph
            Zi2 = F.relu(self.post_conv(Zi, e_coarse))
            pooled.append(Zi2.mean(dim=0, keepdim=True))

        Xg = torch.cat(pooled, dim=0) if pooled else torch.zeros(0, self.lin.in_features, device=x.device)
        return self.lin(Xg)


# -----------------------------
# Training & eval
# -----------------------------
def train_epoch(model, loader, device, opt):
    """Run one training epoch; return the mean cross-entropy loss per graph."""
    model.train()
    total = 0.0
    for data in loader:
        data = data.to(device)
        opt.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        opt.step()
        total += loss.detach().item() * data.num_graphs
    return total / len(loader.dataset)


@torch.no_grad()
def evaluate(model, loader, device):
    """Return classification accuracy over *loader* (0.0 if it yields no graphs)."""
    model.eval()
    correct = 0
    total = 0
    for data in loader:
        data = data.to(device)
        pred = model(data).argmax(dim=-1)
        correct += int((pred == data.y).sum())
        total += data.num_graphs
    return correct / total if total else 0.0


# -----------------------------
# Main
# -----------------------------
def main(args):
    """Train and evaluate every model on PROTEINS with an 80/10/10 split."""
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')

    # Dataset
    base = TUDataset(root=str(args.data_root), name='PROTEINS')
    in_dim = base.num_features
    if not in_dim:
        # Fallback: if no features, use degree as a single feature. A transform
        # is required because indexing the dataset returns copies.
        base = TUDataset(root=str(args.data_root), name='PROTEINS', transform=_degree_as_feature)
        in_dim = 1
    # DatasetWithGID attaches the global id in __getitem__, so every batch
    # knows which dataset graph each sample is.
    ds = DatasetWithGID(base)
    num_classes = base.num_classes

    # Splits
    N = len(ds)
    n_train = int(0.8 * N)
    n_val = int(0.1 * N)
    n_test = N - n_train - n_val
    train_set, val_set, test_set = random_split(
        ds, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(args.seed))

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    seeds_by_graph = load_lrmc_seeds_dir(Path(args.seeds_dir)) if args.seeds_dir else None

    # Models
    hidden = args.hidden
    plain = PlainGCN(in_dim, hidden, num_classes).to(device)
    lrmc = LRMCPoolNet(in_dim, hidden, num_classes, pool_ratio=args.pool_ratio,
                       seeds_by_graph=seeds_by_graph).to(device)
    gpool = gPoolNet(in_dim, hidden, num_classes, ratio=args.pool_ratio).to(device)
    diffp = DiffPoolOneShot(in_dim, hidden, num_classes, ratio=args.pool_ratio).to(device)

    # Train each model separately
    for name, model in [('PlainGCN', plain), ('L-RMC', lrmc), ('gPool', gpool), ('DiffPool', diffp)]:
        opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
        best_val = 0.0
        best_state = None
        for epoch in range(1, args.epochs + 1):
            loss = train_epoch(model, train_loader, device, opt)
            acc_val = evaluate(model, val_loader, device)
            if acc_val >= best_val:
                best_val = acc_val
                # clone(): state_dict() tensors share storage with the live
                # parameters and .cpu() is a no-op on CPU, so without a copy
                # the snapshot would keep changing as training continues.
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            if epoch % 20 == 0 or epoch == args.epochs:
                print(f"{name} epoch {epoch}: loss={loss:.4f}, val_acc={acc_val:.3f}")
        if best_state is not None:
            model.load_state_dict(best_state, strict=False)
        acc_test = evaluate(model, test_loader, device)
        print(f"{name:8s} test_acc={acc_test:.3f}")


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default="./data", help="Root dir for TUDataset(PROTEINS)")
    p.add_argument("--export_edgelists", action="store_true", help="Export per-graph edgelists for Java LRMC seeder")
    p.add_argument("--out_dir", type=str, default="./proteins_edgelists", help="Where to write edge lists")
    p.add_argument("--seeds_dir", type=str, default="", help="Directory with per-graph LRMC seed JSON files")
    p.add_argument("--pool_ratio", type=float, default=0.5)
    p.add_argument("--hidden", type=int, default=64)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--cpu", action="store_true")
    args = p.parse_args()

    if args.export_edgelists:
        export_proteins_edgelists(Path(args.data_root), Path(args.out_dir))
    else:
        main(args)