# Procedural-Reranker-Synthetic / structural_encoder_ablation.py
import torch
import torch.nn as nn
from typing import List
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Batch
from transformers import AutoTokenizer
# Import builder from dataloader for inference
from dataloader import CodeGraphBuilder
from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2, GatedFusion


class StructuralEncoderOnlyGraph(nn.Module):
"""
Ablation variant 1: Pure Structural Encoder.
Removes GraphCodeBERT and uses only the graph path (R-GNN).
"""
def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2, out_dim: int = 768):
super().__init__()
self.device = torch.device(device)
self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers)
self.graph_encoder.to(self.device)
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
# Ignore text inputs for OnlyGraph
return self.graph_encoder(graph_batch)
def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
builder = CodeGraphBuilder()
codes = df["code"].tolist()
batches = range(0, len(codes), batch_size)
        all_embeddings: List[torch.Tensor] = []
        self.eval()  # inference mode: disable dropout / batch-norm updates
        for start in tqdm(batches, desc=desc):
batch_codes = codes[start:start + batch_size]
data_list = [builder.build(c) for c in batch_codes]
            # Move the batched graphs onto the encoder's device before the forward pass
            graph_batch = Batch.from_data_list(data_list).to(self.device)
            # Dummy text inputs, kept only for signature compatibility with the fused model
            dummy_ids = torch.zeros((1, 1), dtype=torch.long, device=self.device)
            dummy_mask = torch.zeros((1, 1), dtype=torch.long, device=self.device)
with torch.no_grad():
out = self.forward(dummy_ids, dummy_mask, graph_batch)
all_embeddings.append(out.cpu())
embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
if save_path is not None:
np.save(save_path, embeddings)
return embeddings
def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
if not checkpoint_path:
raise ValueError("checkpoint_path must be provided")
state = torch.load(checkpoint_path, map_location=map_location)
if isinstance(state, dict) and "state_dict" in state:
state = state["state_dict"]
self.load_state_dict(state, strict=strict)
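# --- Usage sketch for the graph-only ablation (illustrative, not part of the
# original pipeline): assumes a pandas DataFrame `df` with a "code" column and a
# hypothetical checkpoint path.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = StructuralEncoderOnlyGraph(device=device)
#     model.load_checkpoint("checkpoints/only_graph.pt", map_location=device)
#     embeddings = model.generate_embeddings(df, batch_size=8, save_path="only_graph_embeddings.npy")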
class StructuralEncoderConcat(StructuralEncoderV2):
"""
Ablation variant 2: Concatenation Fusion.
Keeps both text and graph paths but fuses them via simple concatenation + projection
instead of Gated Fusion.
"""
def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
super().__init__(device, graph_hidden_dim, graph_layers)
        text_dim = self.text_model.config.hidden_size
        # The graph path is expected to output embeddings of the same size as the
        # text hidden state (as configured in StructuralEncoderV2).
        graph_dim = self.text_model.config.hidden_size
        self.concat_proj = nn.Linear(text_dim + graph_dim, text_dim)
        self.concat_proj.to(self.device)
        # Drop the parent's GatedFusion module; concatenation + projection replaces it.
        del self.fusion
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
text_embeddings = self.encode_text(input_ids, attention_mask)
graph_embeddings = self.graph_encoder(graph_batch)
combined = torch.cat([text_embeddings, graph_embeddings], dim=-1)
return self.concat_proj(combined)
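

if __name__ == "__main__":
    # Minimal smoke test for the concatenation-fusion ablation (a sketch, not the
    # original evaluation script). Assumptions: the text backbone inside
    # StructuralEncoderV2 is GraphCodeBERT ("microsoft/graphcodebert-base") and
    # CodeGraphBuilder.build() accepts a raw source string, as in
    # generate_embeddings() above.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = StructuralEncoderConcat(device=device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
    sample = "def add(a, b):\n    return a + b"
    enc = tokenizer(sample, return_tensors="pt", truncation=True, max_length=256)
    graph_batch = Batch.from_data_list([CodeGraphBuilder().build(sample)]).to(device)

    with torch.no_grad():
        emb = model(enc["input_ids"].to(device), enc["attention_mask"].to(device), graph_batch)
    print("Concat-fusion embedding shape:", tuple(emb.shape))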