"""
ML Model Loader and Utilities

Handles loading and using the conflict prediction model and package embeddings.
Each artifact is loaded from a local file if available; otherwise it is
downloaded from the Hugging Face Hub and cached locally.
"""
|
|
|
|
|
import json
import pickle
import shutil
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
from packaging.requirements import InvalidRequirement, Requirement
|
|
|
|
|
|
|
|
# huggingface_hub is optional: without it, model artifacts can only come from
# local files. HF_HUB_AVAILABLE records which situation we are in.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: huggingface_hub not available. Models must be loaded locally.")
else:
    HF_HUB_AVAILABLE = True
|
|
|
|
|
|
|
|
class ConflictPredictor:
    """Load and use the conflict prediction model.

    The pickled model is read from a local file first. If that file is
    missing or unreadable and ``huggingface_hub`` is installed, it is
    downloaded from the Hub and copied to the local path for next time.
    """

    # Artifact name, used both for the default local path and on the Hub.
    _MODEL_FILENAME = "conflict_predictor.pkl"

    # Packages whose presence is one-hot encoded by ``extract_features``.
    # The order fixes feature positions and must match the training code.
    _COMMON_PACKAGES = (
        'torch', 'pytorch-lightning', 'tensorflow', 'keras', 'fastapi', 'pydantic',
        'numpy', 'pandas', 'scipy', 'scikit-learn', 'matplotlib', 'seaborn',
        'requests', 'httpx', 'sqlalchemy', 'alembic', 'uvicorn', 'starlette',
        'langchain', 'openai', 'chromadb', 'redis', 'celery', 'gunicorn',
        'pillow', 'opencv-python', 'beautifulsoup4', 'scrapy', 'plotly', 'jax',
    )

    # Package pairs whose joint presence gets its own feature, in feature order.
    _PAIR_FEATURES = (
        ('torch', 'pytorch-lightning'),
        ('tensorflow', 'keras'),
        ('fastapi', 'pydantic'),
    )

    def __init__(self, model_path: Optional[Path] = None, repo_id: str = "ysakhale/dependency-conflict-models"):
        """Initialize the conflict predictor.

        Args:
            model_path: Local path to the pickled model (``Path`` or ``str``).
                Defaults to ``models/conflict_predictor.pkl`` next to this
                module; a Hub download is cached at this path.
            repo_id: Hugging Face repository ID to download from if the local
                file is missing or cannot be loaded.

        Load failures are reported with ``print`` rather than raised; in that
        case ``self.model`` stays ``None`` and ``predict`` returns
        ``(False, 0.0)``.
        """
        self.repo_id = repo_id
        self.model = None

        if model_path is None:
            model_path = Path(__file__).parent / "models" / self._MODEL_FILENAME
        # Also accept plain string paths.
        model_path = Path(model_path)
        self.model_path = model_path

        if model_path.exists():
            try:
                self.model = self._load_pickle(model_path)
                print(f"Loaded conflict prediction model from {model_path}")
                return
            except Exception as e:
                # Corrupt/incompatible file: fall through to the Hub download,
                # which on success replaces the local copy.
                print(f"Could not load conflict prediction model from local: {e}")

        if HF_HUB_AVAILABLE and self._load_from_hub(model_path):
            return

        print("Warning: Conflict prediction model not available")

    def _load_from_hub(self, cache_path: Path) -> bool:
        """Download the model from ``self.repo_id``, load it, and cache it at ``cache_path``.

        Returns True if ``self.model`` was loaded, False otherwise.
        """
        try:
            print(f"Model not found locally. Downloading from Hugging Face Hub: {self.repo_id}")
            downloaded_path = hf_hub_download(
                repo_id=self.repo_id,
                filename=self._MODEL_FILENAME,
                repo_type="model",
            )
            # SECURITY: unpickling can execute arbitrary code, so this fully
            # trusts the remote repository's contents. Consider pinning a
            # revision or a non-executable format if the repo is not yours.
            self.model = self._load_pickle(downloaded_path)
            print("Loaded conflict prediction model from Hugging Face Hub")
        except Exception as e:
            print(f"Could not download model from Hugging Face Hub: {e}")
            return False

        self._cache_locally(downloaded_path, cache_path)
        return True

    @staticmethod
    def _cache_locally(source, destination: Path) -> None:
        """Best-effort copy of a downloaded artifact to ``destination``."""
        try:
            destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(source, destination)
            print(f"Cached model locally at {destination}")
        except OSError as e:
            # Caching is only an optimisation (the target directory may be
            # read-only, for example), so report it and carry on.
            print(f"Could not cache model locally: {e}")

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        with open(path, 'rb') as f:
            return pickle.load(f)

    def extract_features(self, requirements_text: str) -> np.ndarray:
        """Extract the model's feature vector from requirements text (same as training).

        Blank lines, lines starting with ``#`` and lines that are not valid
        requirement specifications are skipped. Package names are
        lower-cased. The returned 37-element vector is laid out as:

        * ``[0]``     unique package count / 20, capped at 1.0
        * ``[1]``     number of parsed lines with a version specifier divided
                      by the unique package count (at least 1)
        * ``[2]``     mean per-line specificity / 3, where a line scores 3 if
                      its specifier contains ``==``, 2 if it contains ``>=``
                      or ``<=``, 1 for any other specifier and 0 if unpinned
        * ``[3]``     duplicate-package flag (see the NOTE below)
        * ``[4:34]``  presence of each name in ``_COMMON_PACKAGES``
        * ``[34:37]`` joint presence of each pair in ``_PAIR_FEATURES``
        """
        packages = {}  # lower-cased name -> specifier of its first occurrence
        num_packages = 0
        has_pins = 0
        version_specificity = []

        for raw_line in requirements_text.strip().split('\n'):
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue

            try:
                req = Requirement(line)
            except InvalidRequirement:
                # Not a requirement (e.g. a pip option line); ignore it.
                continue

            pkg_name = req.name.lower()
            specifier = str(req.specifier) if req.specifier else ''

            if pkg_name not in packages:
                packages[pkg_name] = specifier
                num_packages += 1

            # Pins and specificity are counted per parsed line, repeats included.
            if specifier:
                has_pins += 1
                if '==' in specifier:
                    version_specificity.append(3)
                elif '>=' in specifier or '<=' in specifier:
                    version_specificity.append(2)
                else:
                    version_specificity.append(1)
            else:
                version_specificity.append(0)

        feature_vec = [
            min(num_packages / 20.0, 1.0),
            has_pins / max(num_packages, 1),
            np.mean(version_specificity) / 3.0 if version_specificity else 0,
            # NOTE(review): ``num_packages`` only grows for new names, so it
            # always equals ``len(packages)`` and this flag is always 0. It is
            # kept as-is because the features must mirror training; changing
            # it here alone would feed the model values it never saw. Confirm
            # against the training code before fixing.
            1 if len(packages) < num_packages else 0,
        ]
        feature_vec.extend(1 if pkg in packages else 0 for pkg in self._COMMON_PACKAGES)
        feature_vec.extend(
            1 if (first in packages and second in packages) else 0
            for first, second in self._PAIR_FEATURES
        )

        return np.array(feature_vec)

    def predict(self, requirements_text: str) -> Tuple[bool, float]:
        """
        Predict if requirements have conflicts.

        Returns:
            (has_conflict, confidence_score). ``(False, 0.0)`` is returned
            when no model is loaded or when prediction raises.
        """
        if self.model is None:
            return False, 0.0

        try:
            features = self.extract_features(requirements_text).reshape(1, -1)

            prediction = self.model.predict(features)[0]
            probability = self.model.predict_proba(features)[0]

            has_conflict = bool(prediction)
            # Assumes the probability columns are ordered [no conflict, conflict]
            # (sklearn orders them by ``classes_``, e.g. [0, 1]) — confirm
            # for the trained model.
            confidence = float(probability[1] if has_conflict else probability[0])

            return has_conflict, confidence
        except Exception as e:
            print(f"Error in conflict prediction: {e}")
            return False, 0.0
|
|
|
|
|
|
|
|
class PackageEmbeddings:
    """Load and use package embeddings for similarity matching.

    ``self.embeddings`` maps package names to embedding vectors. Lookups use
    lower-cased names. Vectors computed on the fly are stored as lists.
    Presumably the JSON file has the same shape (lower-cased name -> list of
    floats); confirm against the file that produces it.
    """

    # Artifact name, used both for the default local path and on the Hub.
    _EMBEDDINGS_FILENAME = "package_embeddings.json"

    # Sentence-transformer used to embed names missing from the loaded table.
    _SENTENCE_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

    def __init__(self, embeddings_path: Optional[Path] = None, repo_id: str = "ysakhale/dependency-conflict-models"):
        """Initialize package embeddings.

        Args:
            embeddings_path: Local path to the embeddings JSON (``Path`` or
                ``str``). Defaults to ``models/package_embeddings.json`` next
                to this module; a Hub download is cached at this path.
            repo_id: Hugging Face repository ID to download from if the local
                file is missing or cannot be loaded.

        Load failures are reported with ``print`` rather than raised; in that
        case ``self.embeddings`` stays empty.
        """
        self.repo_id = repo_id
        self.embeddings = {}
        self.model = None
        # Set when sentence-transformers fails to import, so the import is not
        # retried (and the warning not reprinted) on every lookup.
        self._model_unavailable = False

        if embeddings_path is None:
            embeddings_path = Path(__file__).parent / "models" / self._EMBEDDINGS_FILENAME
        # Also accept plain string paths.
        embeddings_path = Path(embeddings_path)
        self.embeddings_path = embeddings_path

        if embeddings_path.exists():
            try:
                self.embeddings = self._read_json(embeddings_path)
                print(f"Loaded {len(self.embeddings)} package embeddings from {embeddings_path}")
                return
            except Exception as e:
                print(f"Could not load embeddings from local: {e}")

        if HF_HUB_AVAILABLE and self._load_from_hub(embeddings_path):
            return

        print("Warning: Package embeddings not available")

    def _load_from_hub(self, cache_path: Path) -> bool:
        """Download embeddings from ``self.repo_id``, load them, and cache them at ``cache_path``.

        Returns True if ``self.embeddings`` was loaded, False otherwise.
        """
        try:
            print(f"Embeddings not found locally. Downloading from Hugging Face Hub: {self.repo_id}")
            downloaded_path = hf_hub_download(
                repo_id=self.repo_id,
                filename=self._EMBEDDINGS_FILENAME,
                repo_type="model",
            )
            self.embeddings = self._read_json(downloaded_path)
            print(f"Loaded {len(self.embeddings)} package embeddings from Hugging Face Hub")
        except Exception as e:
            print(f"Could not download embeddings from Hugging Face Hub: {e}")
            return False

        self._cache_locally(downloaded_path, cache_path)
        return True

    @staticmethod
    def _cache_locally(source, destination: Path) -> None:
        """Best-effort copy of a downloaded artifact to ``destination``."""
        try:
            destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(source, destination)
            print(f"Cached embeddings locally at {destination}")
        except OSError as e:
            # Caching is only an optimisation (the target directory may be
            # read-only, for example), so report it and carry on.
            print(f"Could not cache embeddings locally: {e}")

    @staticmethod
    def _read_json(path):
        """Parse and return the JSON document at ``path``."""
        # JSON is UTF-8 by specification; don't depend on the platform default.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_model(self):
        """Lazily load the sentence transformer model.

        Returns the model, or None if ``sentence-transformers`` cannot be
        imported. The failure is remembered, so the warning is printed once.
        """
        # getattr tolerates instances created without __init__.
        if self.model is None and not getattr(self, "_model_unavailable", False):
            try:
                from sentence_transformers import SentenceTransformer
                self.model = SentenceTransformer(self._SENTENCE_MODEL_NAME)
            except ImportError:
                self._model_unavailable = True
                print("⚠️ sentence-transformers not available, embedding similarity disabled")
                return None
        return self.model

    def get_embedding(self, package_name: str) -> Optional[np.ndarray]:
        """Return the embedding for ``package_name``, or None if unavailable.

        A stored vector is used if the lower-cased name is in
        ``self.embeddings``. Otherwise ``package_name`` is encoded with the
        sentence-transformer model (when available), and the result is cached
        under the lower-cased name.
        """
        package_lower = package_name.lower()

        if package_lower in self.embeddings:
            return np.array(self.embeddings[package_lower])

        model = self._load_model()
        if model is None:
            return None

        embedding = model.encode([package_name])[0]
        # Store as a plain list, the same form as entries loaded from JSON.
        self.embeddings[package_lower] = embedding.tolist()
        return embedding

    def find_similar(self, package_name: str, top_k: int = 5, threshold: float = 0.6) -> List[Tuple[str, float]]:
        """
        Find similar packages using cosine similarity.

        Compares ``package_name`` against every other stored embedding. Zero
        vectors are skipped because cosine similarity is undefined for them.

        Returns:
            Up to ``top_k`` (package_name, similarity_score) tuples with
            score >= ``threshold``, highest first. The result is empty when
            ``top_k`` <= 0 or no query embedding is available.
        """
        if top_k <= 0:
            return []

        query_emb = self.get_embedding(package_name)
        if query_emb is None:
            return []

        query_norm = np.linalg.norm(query_emb)
        if query_norm == 0:
            return []

        query_key = package_name.lower()
        similarities = []

        for pkg, emb in self.embeddings.items():
            if pkg == query_key:
                continue

            emb_array = np.asarray(emb)
            emb_norm = np.linalg.norm(emb_array)
            if emb_norm == 0:
                continue

            similarity = np.dot(query_emb, emb_array) / (query_norm * emb_norm)
            if similarity >= threshold:
                similarities.append((pkg, float(similarity)))

        similarities.sort(key=lambda item: item[1], reverse=True)
        return similarities[:top_k]

    def get_best_match(self, package_name: str, threshold: float = 0.7) -> Optional[str]:
        """Return the most similar package name scoring >= ``threshold``, or None."""
        similar = self.find_similar(package_name, top_k=1, threshold=threshold)
        return similar[0][0] if similar else None
|
|
|
|
|
|