gbyuvd's picture
Update app.py
30aefbc verified
raw
history blame
13.6 kB
import gradio as gr
import torch
import sys
import os
from pathlib import Path
import importlib.util
import huggingface_hub
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import Descriptors, rdMolDescriptors
import numpy as np
from PIL import Image
import io
class SimpleMolecularApp:
    """Gradio-facing wrapper around the ChemQ3-MTP molecule generator.

    ``initialize_model`` downloads the model repository, loads its custom
    Python modules, registers them with transformers' Auto* factories and
    loads the model/tokenizer/config onto ``self.device``.
    ``generate_molecule`` is the UI callback.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.config = None
        # Use the GPU when torch can see one, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def download_and_setup_model(self, model_name="gbyuvd/ChemMiniQ3-SAbRLo", local_dir="./chemq3_model"):
        """Download a snapshot of the model repository into ``local_dir``.

        Args:
            model_name: Hugging Face Hub repo id.
            local_dir: Directory the snapshot is written to.

        Returns:
            ``Path`` to the downloaded snapshot, or ``None`` if the download
            failed (the error is printed).
        """
        print(f"πŸ“₯ Downloading model files from {model_name}...")
        try:
            # `resume_download` is deprecated in huggingface_hub (interrupted
            # downloads are resumed automatically), so it is no longer passed.
            model_path = huggingface_hub.snapshot_download(
                repo_id=model_name,
                local_dir=local_dir,
                local_files_only=False,
            )
            print(f"βœ… Model downloaded to: {model_path}")
            return Path(model_path)
        except Exception as e:
            print(f"❌ Download failed: {e}")
            return None

    def load_custom_modules(self, model_path):
        """Import the model's custom configuration/modeling/tokenizer files.

        Args:
            model_path: Directory containing the snapshot's ``.py`` files.

        Returns:
            Dict mapping module name to the loaded module object, or ``None``
            if a required file is missing or fails to import (printed).
        """
        model_path = Path(model_path)
        # The snapshot dir goes on sys.path so imports inside these files resolve.
        if str(model_path) not in sys.path:
            sys.path.insert(0, str(model_path))
        required_files = {
            'configuration_chemq3mtp.py': 'configuration_chemq3mtp',
            'modeling_chemq3mtp.py': 'modeling_chemq3mtp',
            'FastChemTokenizerHF.py': 'FastChemTokenizerHF'
        }
        loaded_modules = {}
        for filename, module_name in required_files.items():
            file_path = model_path / filename
            if not file_path.exists():
                print(f"❌ Required file not found: {filename}")
                return None
            try:
                spec = importlib.util.spec_from_file_location(module_name, file_path)
                if spec is None or spec.loader is None:
                    raise ImportError(f"cannot create an import spec for {file_path}")
                module = importlib.util.module_from_spec(spec)
                # Register before executing (importlib's "import a source file
                # directly" recipe). Otherwise a later file that does
                # `import configuration_chemq3mtp` would load a second copy from
                # sys.path whose classes are distinct objects from these.
                sys.modules[module_name] = module
                spec.loader.exec_module(module)
                loaded_modules[module_name] = module
                print(f" βœ… Loaded {filename}")
            except Exception as e:
                # Do not leave a half-initialised module importable.
                sys.modules.pop(module_name, None)
                print(f" ❌ Failed to load {filename}: {e}")
                return None
        return loaded_modules

    def register_model_components(self, loaded_modules):
        """Register the custom classes with transformers under ``"chemq3_mtp"``.

        Args:
            loaded_modules: Mapping returned by ``load_custom_modules``.

        Returns:
            ``(config_cls, model_cls, tokenizer_cls)`` on success, otherwise
            ``(None, None, None)`` (the error is printed).
        """
        try:
            ChemQ3MTPConfig = loaded_modules['configuration_chemq3mtp'].ChemQ3MTPConfig
            ChemQ3MTPForCausalLM = loaded_modules['modeling_chemq3mtp'].ChemQ3MTPForCausalLM
            FastChemTokenizerSelfies = loaded_modules['FastChemTokenizerHF'].FastChemTokenizerSelfies
            AutoConfig.register("chemq3_mtp", ChemQ3MTPConfig)
            AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)
            AutoTokenizer.register(ChemQ3MTPConfig, FastChemTokenizerSelfies)
            print("βœ… Model components registered successfully")
            return ChemQ3MTPConfig, ChemQ3MTPForCausalLM, FastChemTokenizerSelfies
        except Exception as e:
            print(f"❌ Registration failed: {e}")
            return None, None, None

    def load_model(self, model_path):
        """Load config, model and tokenizer through the registered Auto* classes.

        The model is loaded in float16 when CUDA is available, float32 otherwise.
        ``trust_remote_code`` stays False because the classes were registered
        locally by ``register_model_components``.

        Returns:
            ``(model, tokenizer, config)``, or ``(None, None, None)`` on failure.
        """
        try:
            config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=False)
            model = AutoModelForCausalLM.from_pretrained(
                str(model_path),
                config=config,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=False
            )
            tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=False)
            return model, tokenizer, config
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return None, None, None

    def initialize_model(self):
        """Run download -> module import -> registration -> load.

        On success, stores the model (moved to ``self.device``, eval mode),
        tokenizer and config on the instance.

        Returns:
            ``True`` on success, ``False`` as soon as any stage fails.
        """
        model_name = "gbyuvd/ChemMiniQ3-SAbRLo"
        local_dir = "./chemq3_model"
        model_path = self.download_and_setup_model(model_name, local_dir)
        if model_path is None:
            return False
        loaded_modules = self.load_custom_modules(model_path)
        if loaded_modules is None:
            return False
        config_class, model_class, tokenizer_class = self.register_model_components(loaded_modules)
        if config_class is None:
            return False
        self.model, self.tokenizer, self.config = self.load_model(model_path)
        if self.model is None:
            return False
        self.model = self.model.to(self.device)
        self.model.eval()
        return True

    def calculate_lipinski_properties(self, mol):
        """Compute descriptors and the Rule-of-Five violation count for ``mol``.

        Args:
            mol: An RDKit molecule, or ``None``.

        Returns:
            Dict of rounded descriptors plus ``lipinski_violations`` (0-4), or
            an empty dict when ``mol`` is ``None``.
        """
        if mol is None:
            return {}
        molecular_weight = Descriptors.MolWt(mol)
        # NOTE(review): CalcNumHBD/CalcNumHBA are RDKit's donor/acceptor
        # definitions; RDKit also offers CalcNumLipinskiHBD/HBA (Lipinski's
        # original NH/OH and N/O counts). Confirm which one is intended.
        h_bond_donors = rdMolDescriptors.CalcNumHBD(mol)
        h_bond_acceptors = rdMolDescriptors.CalcNumHBA(mol)
        logp = Descriptors.MolLogP(mol)  # octanol-water partition coefficient
        tpsa = Descriptors.TPSA(mol)  # topological polar surface area
        rotatable_bonds = rdMolDescriptors.CalcNumRotatableBonds(mol)
        heavy_atoms = mol.GetNumHeavyAtoms()
        # One violation per exceeded Rule-of-Five threshold.
        violations = 0
        if molecular_weight > 500:
            violations += 1
        if h_bond_donors > 5:
            violations += 1
        if h_bond_acceptors > 10:
            violations += 1
        if logp > 5:
            violations += 1
        return {
            'molecular_weight': round(molecular_weight, 2),
            'h_bond_donors': h_bond_donors,
            'h_bond_acceptors': h_bond_acceptors,
            'logp': round(logp, 2),
            'tpsa': round(tpsa, 2),
            'rotatable_bonds': rotatable_bonds,
            'heavy_atoms': heavy_atoms,
            'lipinski_violations': violations
        }

    @staticmethod
    def _format_property_text(props):
        """Render the dict from ``calculate_lipinski_properties`` as display text."""
        parts = [
            "πŸ§ͺ Molecular Properties (Lipinski's Rule of Five):\n",
            f"β€’ Molecular Weight: {props['molecular_weight']} g/mol\n",
            f"β€’ H-Bond Donors: {props['h_bond_donors']}\n",
            f"β€’ H-Bond Acceptors: {props['h_bond_acceptors']}\n",
            f"β€’ LogP: {props['logp']}\n",
            f"β€’ TPSA: {props['tpsa']} Γ…Β²\n",
            f"β€’ Rotatable Bonds: {props['rotatable_bonds']}\n",
            f"β€’ Heavy Atoms: {props['heavy_atoms']}\n",
            f"β€’ Lipinski Violations: {props['lipinski_violations']}/4\n",
        ]
        # At most one violation is treated as compliant.
        if props['lipinski_violations'] <= 1:
            parts.append("βœ… Drug-like molecule (Lipinski compliant)")
        else:
            parts.append(f"⚠️ May have poor bioavailability ({props['lipinski_violations']} violations)")
        return "".join(parts)

    def generate_molecule(self, temperature=1.0, max_length=30, top_k=50):
        """Sample one molecule from the model and describe it.

        Args:
            temperature: Sampling temperature.
            max_length: Number of new tokens to generate.
            top_k: Top-k sampling cutoff.

        Returns:
            ``(info_text, mol_image, property_text)``: a SELFIES/SMILES/status
            summary, a PIL image of the molecule (or ``None``), and the
            formatted property report. Errors are reported in the strings
            rather than raised.
        """
        if self.model is None:
            return "Model not loaded!", None, "❌ Model not loaded"
        try:
            # Gradio sliders can deliver floats (e.g. 30.0). Token counts and
            # top-k must be ints for the generation code (torch.topk rejects a float k).
            temperature = float(temperature)
            max_length = int(max_length)
            top_k = int(top_k)

            input_ids = self.tokenizer("<s>", return_tensors="pt").input_ids.to(self.device)
            # Inference only: don't let the custom generator build an autograd graph.
            with torch.no_grad():
                if hasattr(self.model, 'generate_with_logprobs'):
                    print("Using MTP-specific generation...")
                    outputs = self.model.generate_with_logprobs(
                        input_ids,
                        max_new_tokens=max_length,
                        temperature=temperature,
                        top_k=top_k,
                        do_sample=True,
                        return_probs=True,
                        tokenizer=self.tokenizer
                    )
                    # Token ids are at index 2 of the MTP output tuple.
                    gen_tokens = outputs[2]
                else:
                    print("Using standard generation...")
                    gen_tokens = self.model.generate(
                        input_ids,
                        max_length=input_ids.shape[1] + max_length,
                        temperature=temperature,
                        top_k=top_k,
                        do_sample=True,
                        pad_token_id=getattr(self.tokenizer, 'pad_token_id', 0)
                    )

            # The tokenizer decodes SELFIES tokens space-separated; join them.
            generatedmol = self.tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
            selfies_str = generatedmol.replace(' ', '')
            try:
                smiles = sf.decoder(selfies_str)
            except sf.DecoderError:
                # Malformed SELFIES: report it as undecodable, not as a crash.
                smiles = None

            info_text = f"Generated SELFIES: {selfies_str}\n"
            info_text += f"Decoded SMILES: {smiles}\n"

            if not smiles:
                return (
                    info_text + "⚠️ Could not decode to SMILES",
                    None,
                    "⚠️ Could not calculate properties - could not decode to SMILES",
                )
            mol = Chem.MolFromSmiles(smiles)
            if not mol:
                return (
                    info_text + "⚠️ Invalid SMILES structure",
                    None,
                    "⚠️ Could not calculate properties - invalid SMILES structure",
                )
            mol_image = Draw.MolToImage(mol, size=(400, 400))
            property_text = self._format_property_text(self.calculate_lipinski_properties(mol))
            return info_text + "βœ… Valid molecule generated!", mol_image, property_text
        except Exception as e:
            return f"Error generating molecule: {str(e)}", None, "❌ Error calculating properties"
def create_simple_interface():
    """Initialise the model and assemble the Gradio Blocks UI.

    Returns:
        The ``gr.Blocks`` app wired to ``SimpleMolecularApp.generate_molecule``,
        or ``None`` when model initialisation fails (a message is printed).
    """
    molecular_app = SimpleMolecularApp()

    print("Initializing model...")
    if not molecular_app.initialize_model():
        print("Failed to initialize model!")
        return None
    print("Model initialized successfully!")

    intro_markdown = (
        "\n"
        "# πŸ§ͺ Simple Molecular Generation with MTP\n"
        "Generate complete molecules using the ChemQ3-MTP model with Lipinski properties.\n"
    )
    # Preset rows: [temperature, max length, top-k].
    preset_examples = [
        [1.0, 30, 50],  # Default
        [0.8, 25, 40],  # More focused
        [1.5, 35, 60],  # More random
    ]

    with gr.Blocks(title="πŸ§ͺ Simple Molecular Generation", theme=gr.themes.Soft()) as ui:
        gr.Markdown(intro_markdown)

        with gr.Row():
            # Left column: sampling controls and the trigger button.
            with gr.Column():
                temperature_input = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more random",
                )
                length_input = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=30,
                    step=1,
                    precision=0,
                    label="Max Length",
                    info="Max tokens to generate",
                )
                topk_input = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=1,
                    precision=0,
                    label="Top-K",
                    info="Sampling diversity",
                )
                generate_button = gr.Button("πŸ§ͺ Generate Molecule", variant="primary")
            # Right column: generation summary and structure drawing.
            with gr.Column():
                info_box = gr.Textbox(label="Molecule Information", lines=5, interactive=False)
                structure_image = gr.Image(label="Generated Molecule", type="pil")

        # Separate row below for the property report.
        with gr.Row():
            with gr.Column():
                properties_box = gr.Textbox(
                    label="Molecular Properties (Lipinski's Rule of Five)",
                    lines=10,
                    interactive=False,
                )

        generation_inputs = [temperature_input, length_input, topk_input]
        generation_outputs = [info_box, structure_image, properties_box]

        generate_button.click(
            fn=molecular_app.generate_molecule,
            inputs=generation_inputs,
            outputs=generation_outputs,
        )
        # cache_examples=True asks Gradio to precompute and cache these presets' outputs.
        gr.Examples(
            examples=preset_examples,
            inputs=generation_inputs,
            fn=molecular_app.generate_molecule,
            outputs=generation_outputs,
            cache_examples=True,
        )

    return ui
if __name__ == "__main__":
    # Building the interface also downloads and loads the model.
    demo = create_simple_interface()
    if demo is None:
        print("Failed to create interface!")
        # Exit non-zero so the hosting process/supervisor registers the failure
        # (previously the script ended with status 0 here).
        raise SystemExit(1)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)