import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Any
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Batch
from transformers import AutoTokenizer

# Import builder from dataloader for inference
from dataloader import CodeGraphBuilder
from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2, GatedFusion


class StructuralEncoderOnlyGraph(nn.Module):
    """
    Ablation variant 1: Pure Structural Encoder.
    Removes GraphCodeBERT and uses only the graph path (R-GNN).
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256,
                 graph_layers: int = 2, out_dim: int = 768):
        super().__init__()
        self.device = torch.device(device)
        self.graph_encoder = RelationalGraphEncoder(
            hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers
        )
        self.graph_encoder.to(self.device)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                graph_batch: Batch) -> torch.Tensor:
        # Text inputs are ignored: this ablation uses only the graph path.
        return self.graph_encoder(graph_batch)

    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8,
                            save_path: str | None = None,
                            desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
        builder = CodeGraphBuilder()
        codes = df["code"].tolist()
        batches = range(0, len(codes), batch_size)
        all_embeddings: List[torch.Tensor] = []

        self.eval()  # inference mode: disable dropout etc.
        for start in tqdm(batches, desc=desc):
            batch_codes = codes[start:start + batch_size]
            data_list = [builder.build(c) for c in batch_codes]
            graph_batch = Batch.from_data_list(data_list).to(self.device)

            # Dummy text inputs, kept only for signature compatibility with the full model.
            dummy_ids = torch.zeros((1, 1), dtype=torch.long, device=self.device)
            dummy_mask = torch.zeros((1, 1), dtype=torch.long, device=self.device)

            with torch.no_grad():
                out = self.forward(dummy_ids, dummy_mask, graph_batch)
            all_embeddings.append(out.cpu())

        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str,
                        map_location: str | torch.device = "cpu",
                        strict: bool = True) -> None:
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        state = torch.load(checkpoint_path, map_location=map_location)
        # Unwrap Lightning-style checkpoints that nest weights under "state_dict".
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)


class StructuralEncoderConcat(StructuralEncoderV2):
    """
    Ablation variant 2: Concatenation Fusion.
    Keeps both text and graph paths but fuses them via simple
    concatenation + projection instead of Gated Fusion.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256,
                 graph_layers: int = 2):
        super().__init__(device, graph_hidden_dim, graph_layers)
        text_dim = self.text_model.config.hidden_size
        # The graph path is projected to the text hidden size in the base encoder.
        graph_dim = self.text_model.config.hidden_size
        self.concat_proj = nn.Linear(text_dim + graph_dim, text_dim)
        self.concat_proj.to(self.device)
        del self.fusion  # drop the unused GatedFusion module inherited from the base class

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                graph_batch: Batch) -> torch.Tensor:
        text_embeddings = self.encode_text(input_ids, attention_mask)
        graph_embeddings = self.graph_encoder(graph_batch)
        combined = torch.cat([text_embeddings, graph_embeddings], dim=-1)
        return self.concat_proj(combined)
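

# --- Minimal usage sketch (illustrative only) ---
# The CSV path, checkpoint path, and the "code" column below are assumptions
# for demonstration, not part of the project; adapt them to your own data layout.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Ablation 1: graph-only encoder, optionally restored from a checkpoint.
    only_graph = StructuralEncoderOnlyGraph(device=device)
    # only_graph.load_checkpoint("checkpoints/only_graph.pt", map_location=device)  # hypothetical path

    df = pd.read_csv("data/eval_snippets.csv")  # hypothetical file; expects a "code" column
    embeddings = only_graph.generate_embeddings(
        df, batch_size=8, save_path="only_graph_embeddings.npy"
    )
    print("OnlyGraph embeddings:", embeddings.shape)

    # Ablation 2: concatenation fusion (text + graph). Its forward follows the parent
    # StructuralEncoderV2 signature (token ids, attention mask, graph batch), so the
    # tokenization/graph-building pipeline of the full model applies unchanged.
    # concat_model = StructuralEncoderConcat(device=device)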