import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Any
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Batch
from transformers import AutoTokenizer

from dataloader import CodeGraphBuilder
from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2, GatedFusion


class StructuralEncoderOnlyGraph(nn.Module):
    """
    Ablation variant 1: Pure Structural Encoder.
    Removes GraphCodeBERT and uses only the graph path (R-GNN).
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2, out_dim: int = 768):
        super().__init__()
        self.device = torch.device(device)

        # Graph-only path: a relational GNN encoder, no text encoder.
        self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers)
        self.graph_encoder.to(self.device)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
        # input_ids and attention_mask are accepted only to keep the interface
        # compatible with the full model; this variant uses the graph path alone.
        return self.graph_encoder(graph_batch)

    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
        builder = CodeGraphBuilder()
        codes = df["code"].tolist()
        batches = range(0, len(codes), batch_size)
        all_embeddings: List[torch.Tensor] = []

        for start in tqdm(batches, desc=desc):
            batch_codes = codes[start:start + batch_size]

            # Build one graph per snippet, collate them into a single batch, and
            # move the batch onto the same device as the graph encoder.
            data_list = [builder.build(c) for c in batch_codes]
            graph_batch = Batch.from_data_list(data_list).to(self.device)

            # forward() ignores the text inputs, so placeholder tensors suffice here.
            dummy_ids = torch.zeros((1, 1), dtype=torch.long, device=self.device)
            dummy_mask = torch.zeros((1, 1), dtype=torch.long, device=self.device)

            with torch.no_grad():
                out = self.forward(dummy_ids, dummy_mask, graph_batch)
            all_embeddings.append(out.cpu())

        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        state = torch.load(checkpoint_path, map_location=map_location)
        # Unwrap checkpoints that store the weights under a "state_dict" key.
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)

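# A minimal usage sketch for the graph-only ablation (illustrative only, not part of
# the training pipeline). The checkpoint and output paths below are hypothetical.
#
#   encoder = StructuralEncoderOnlyGraph(device="cuda" if torch.cuda.is_available() else "cpu")
#   encoder.load_checkpoint("checkpoints/only_graph.pt")   # hypothetical path
#   embs = encoder.generate_embeddings(df, batch_size=8,
#                                      save_path="only_graph_embeddings.npy")

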
class StructuralEncoderConcat(StructuralEncoderV2):
    """
    Ablation variant 2: Concatenation Fusion.
    Keeps both text and graph paths but fuses them via simple concatenation +
    projection instead of Gated Fusion.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
        super().__init__(device, graph_hidden_dim, graph_layers)

        # The parent model projects the graph path to the text encoder's hidden size,
        # so both embeddings share the same width.
        text_dim = self.text_model.config.hidden_size
        graph_dim = self.text_model.config.hidden_size

        self.concat_proj = nn.Linear(text_dim + graph_dim, text_dim)
        self.concat_proj.to(self.device)

        # Drop the parent's gated-fusion module; this variant never uses it.
        del self.fusion

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
        text_embeddings = self.encode_text(input_ids, attention_mask)
        graph_embeddings = self.graph_encoder(graph_batch)

        # Concatenation fusion: stack the two views and project back to the text width.
        combined = torch.cat([text_embeddings, graph_embeddings], dim=-1)
        return self.concat_proj(combined)
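

# A rough end-to-end sketch of how the concat-fusion ablation might be exercised.
# The CSV path, tokenizer name, and max length below are placeholders / assumptions,
# not values taken from this repository's training configuration; StructuralEncoderV2
# may expose its own tokenizer in practice.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    df = pd.read_csv("data/code_snippets.csv")  # hypothetical file with a "code" column

    model = StructuralEncoderConcat(device=device)
    model.eval()

    # GraphCodeBERT tokenizer is assumed here to match the text path.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
    builder = CodeGraphBuilder()

    codes = df["code"].tolist()[:8]
    enc = tokenizer(codes, padding=True, truncation=True, max_length=512, return_tensors="pt")
    graph_batch = Batch.from_data_list([builder.build(c) for c in codes]).to(device)

    with torch.no_grad():
        embeddings = model(enc["input_ids"].to(device), enc["attention_mask"].to(device), graph_batch)
    print(embeddings.shape)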