Procedural-Code-Search-OnlyGraph-POJ104 / structural_encoder_v2.py
dv4aby's picture
Upload source code structural_encoder_v2.py
87d2e85 verified
import hashlib
from collections import defaultdict
from typing import Dict, List, Tuple, TYPE_CHECKING
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch
from torch_geometric.nn import HeteroConv, GATConv, global_mean_pool
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
import numpy as np
if TYPE_CHECKING:
import pandas as pd
# Import Builder from dataloader for inference/eval
from dataloader import CodeGraphBuilder
class RelationalGraphEncoder(nn.Module):
    """R-GNN encoder over the AST+CFG heterogeneous graph.

    Integer node ids of the ``ast``, ``token`` and ``stmt`` node types are
    embedded with one lookup table per type. They are refined by ``num_layers``
    rounds of heterogeneous message passing: one ``GATConv`` per relation in
    :attr:`EDGE_TYPES`, with messages summed per destination type and followed
    by ReLU. The result is mean-pooled per graph and node type, averaged over
    node types, and projected to ``out_dim``.
    """

    EDGE_TYPES = (
        ("ast", "ast_parent_child", "ast"),
        ("ast", "ast_child_parent", "ast"),
        ("ast", "ast_next_sibling", "ast"),
        ("ast", "ast_prev_sibling", "ast"),
        ("token", "token_to_ast", "ast"),
        ("ast", "ast_to_token", "token"),
        ("stmt", "cfg", "stmt"),
        ("stmt", "cfg_rev", "stmt"),
        ("stmt", "stmt_to_ast", "ast"),
        ("ast", "ast_to_stmt", "stmt"),
    )

    def __init__(
        self,
        hidden_dim: int = 256,
        out_dim: int = 768,
        num_layers: int = 2,
        ast_vocab_size: int = 2048,
        token_vocab_size: int = 8192,
        stmt_vocab_size: int = 512,
    ) -> None:
        """Build the embedding tables, GAT layers and output projection.

        Args:
            hidden_dim: Width of the node embeddings and of every GAT layer.
            out_dim: Size of the returned graph embedding.
            num_layers: Number of heterogeneous message-passing rounds.
            ast_vocab_size: Number of rows in the ``ast`` id embedding table.
            token_vocab_size: Number of rows in the ``token`` id embedding table.
            stmt_vocab_size: Number of rows in the ``stmt`` id embedding table.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        # NOTE(review): ids produced by the graph builder must be smaller than
        # the matching vocab size or nn.Embedding raises; confirm against
        # CodeGraphBuilder before changing these defaults.
        self.ast_encoder = nn.Embedding(ast_vocab_size, hidden_dim)
        self.token_encoder = nn.Embedding(token_vocab_size, hidden_dim)
        self.stmt_encoder = nn.Embedding(stmt_vocab_size, hidden_dim)
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            # (-1, -1): GATConv infers source/destination feature sizes lazily.
            hetero_modules = {
                edge_type: GATConv((-1, -1), hidden_dim, add_self_loops=False)
                for edge_type in self.EDGE_TYPES
            }
            self.convs.append(HeteroConv(hetero_modules, aggr="sum"))
        self.output_proj = nn.Linear(hidden_dim, out_dim)

    def _encode_nodes(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Embed the ``x`` ids of every node type.

        A node type that is missing from ``data``, or has no ``x`` attribute,
        gets an empty ``(0, hidden_dim)`` tensor. All three types therefore
        always appear in the returned dict.
        """
        device = self.ast_encoder.weight.device

        def get_embed(node_type: str, encoder: nn.Embedding) -> torch.Tensor:
            # Check membership first: indexing a HeteroData with an unknown
            # node type would create an empty store on it.
            if node_type not in data.node_types:
                return torch.zeros((0, self.hidden_dim), device=device)
            x = data[node_type].get("x")
            if x is None:
                return torch.zeros((0, self.hidden_dim), device=device)
            return encoder(x.to(device))

        return {
            "ast": get_embed("ast", self.ast_encoder),
            "token": get_embed("token", self.token_encoder),
            "stmt": get_embed("stmt", self.stmt_encoder),
        }

    def forward(self, data: HeteroData) -> torch.Tensor:
        """Encode a single heterogeneous code graph or a ``Batch`` of them.

        Args:
            data: ``HeteroData`` / ``Batch`` whose node stores carry integer
                ``x`` ids and whose edge stores carry ``edge_index``.

        Returns:
            Tensor of shape ``(num_graphs, out_dim)``. ``num_graphs`` is 1 when
            ``data`` has no ``num_graphs`` attribute (an unbatched graph).
        """
        device = next(self.parameters()).device
        data = data.to(device)
        x_dict = self._encode_nodes(data)

        # `edge_index_dict` is a property that re-collects every edge store,
        # so read it once. Only relations present in this batch are forwarded.
        available_edges = data.edge_index_dict
        edge_index_dict = {
            edge_type: available_edges[edge_type]
            for edge_type in self.EDGE_TYPES
            if edge_type in available_edges
        }

        for conv in self.convs:
            # HeteroConv only emits destination types that received messages,
            # so a node type without incoming relations drops out of x_dict.
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        # Global pooling, one (batch_size, hidden_dim) tensor per node type.
        batch_size = data.num_graphs if hasattr(data, "num_graphs") else 1
        pooled_embeddings = []
        for key, x in x_dict.items():
            if x.size(0) == 0:
                continue
            batch = getattr(data[key], "batch", None)
            if batch is not None:
                # Graphs with no node of this type get an all-zero row.
                pooled_embeddings.append(global_mean_pool(x, batch, size=batch_size))
            else:
                # Unbatched graph: one mean over all nodes -> (1, hidden_dim).
                pooled_embeddings.append(x.mean(dim=0, keepdim=True))

        if not pooled_embeddings:
            return torch.zeros((batch_size, self.out_dim), device=device)

        # Unweighted mean over node types: [num_types, B, dim] -> [B, dim].
        # NOTE(review): a graph lacking a node type still counts that type's
        # zero row in this mean. Kept for compatibility with trained weights.
        graph_repr = torch.stack(pooled_embeddings).mean(dim=0)
        return self.output_proj(graph_repr)
class GatedFusion(nn.Module):
    """Blend a text embedding with a graph embedding through a learned gate.

    The graph vector is linearly mapped into the text space. A per-dimension
    sigmoid gate, computed from the concatenation ``[text ; projected graph]``,
    then interpolates between the two: ``g * text + (1 - g) * projected``.
    """

    def __init__(self, text_dim: int, graph_dim: int) -> None:
        super().__init__()
        # Maps graph features (graph_dim) into the text space (text_dim).
        self.graph_proj = nn.Linear(graph_dim, text_dim)
        # One gate logit per text dimension, from both inputs side by side.
        self.gate = nn.Linear(text_dim * 2, text_dim)

    def forward(self, h_text: torch.Tensor, h_graph: torch.Tensor) -> torch.Tensor:
        projected = self.graph_proj(h_graph)
        weight = self.gate(torch.cat((h_text, projected), dim=-1)).sigmoid()
        return weight * h_text + (1.0 - weight) * projected
class StructuralEncoderV2(nn.Module):
    """Structural encoder that fuses GraphCodeBERT text features with AST+CFG graph context.

    The first-position hidden state of ``microsoft/graphcodebert-base`` and the
    output of ``RelationalGraphEncoder`` (projected to the same width) are
    combined by ``GatedFusion``.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
        """Load GraphCodeBERT and build the graph encoder and fusion head.

        Args:
            device: Device all sub-modules are moved to and inputs are sent to.
            graph_hidden_dim: Hidden width of the graph encoder.
            graph_layers: Number of message-passing layers in the graph encoder.
        """
        super().__init__()
        self.device = torch.device(device)
        self.text_tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
        self.text_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
        self.text_model.to(self.device)
        text_dim = self.text_model.config.hidden_size
        self.graph_encoder = RelationalGraphEncoder(
            hidden_dim=graph_hidden_dim,
            out_dim=text_dim,
            num_layers=graph_layers,
        )
        self.graph_encoder.to(self.device)
        self.fusion = GatedFusion(text_dim, text_dim)
        self.fusion.to(self.device)

    def encode_text(self, codes: List[str]) -> torch.Tensor:
        """Return the first-position hidden state for each code string.

        Inputs are padded to the longest sequence in the batch and truncated to
        512 tokens.

        Returns:
            Tensor of shape ``(len(codes), hidden_size)``.
        """
        inputs = self.text_tokenizer(
            codes,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.text_model(**inputs)
        return outputs.last_hidden_state[:, 0, :]

    def forward(self, codes: List[str], graph_batch: Batch | HeteroData) -> torch.Tensor:
        """Fuse text and graph embeddings for a batch of snippets.

        ``codes[i]`` is expected to correspond to graph ``i`` of
        ``graph_batch``; the two encoders must return the same number of rows.
        """
        text_embeddings = self.encode_text(codes)
        graph_embeddings = self.graph_encoder(graph_batch)
        return self.fusion(text_embeddings, graph_embeddings)

    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural V2 embeddings") -> np.ndarray:
        """Embed every snippet in ``df["code"]`` and return a float32 matrix.

        The module is put in eval mode for the duration of the call, so dropout
        in the text model is disabled and repeated calls give identical
        results. Every submodule's previous train/eval mode is restored
        afterwards, even if an error occurs.

        Args:
            df: DataFrame with a ``"code"`` column of source strings.
            batch_size: Number of snippets encoded per forward pass (>= 1).
            save_path: If given, the embeddings are also written with ``np.save``.
            desc: Progress-bar label.

        Returns:
            Array of shape ``(len(df), hidden_size)`` and dtype float32. It is
            ``(0, hidden_size)`` when ``df`` has no rows.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        # Local builder for inference.
        builder = CodeGraphBuilder()
        codes = df["code"].tolist()
        all_embeddings: List[torch.Tensor] = []

        # Remember each submodule's own mode so that, e.g., a deliberately
        # frozen (eval) text model inside a training parent is restored as-is.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                for start in tqdm(range(0, len(codes), batch_size), desc=desc):
                    batch_codes = codes[start:start + batch_size]
                    data_list = [builder.build(c) for c in batch_codes]
                    graph_batch = Batch.from_data_list(data_list)
                    fused = self.forward(batch_codes, graph_batch)
                    all_embeddings.append(fused.cpu())
        finally:
            for module, mode in previous_modes:
                module.training = mode

        if all_embeddings:
            embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        else:
            # torch.cat([]) would raise; return an empty, correctly-shaped matrix.
            embeddings = np.zeros((0, self.text_model.config.hidden_size), dtype="float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        """Load weights from ``checkpoint_path`` into this module.

        Accepts either a raw state dict or a dict that wraps it under the key
        ``"state_dict"``.

        Raises:
            ValueError: If ``checkpoint_path`` is empty.
        """
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        # SECURITY NOTE: torch.load unpickles the file, so only load trusted
        # checkpoints. PyTorch 2.6+ defaults to weights_only=True.
        state = torch.load(checkpoint_path, map_location=map_location)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)