Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import sys | |
| import os | |
| from pathlib import Path | |
| import importlib.util | |
| import huggingface_hub | |
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer | |
| import selfies as sf | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw | |
| from rdkit.Chem import Descriptors, rdMolDescriptors | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
class SimpleMolecularApp:
    """Backend for the Gradio app: loads ChemQ3-MTP, samples molecules, scores them.

    ``model``, ``tokenizer`` and ``config`` stay ``None`` until
    :meth:`initialize_model` succeeds; :meth:`generate_molecule` refuses to run
    before that.
    """

    # Hub repository and local snapshot directory used when none are given.
    DEFAULT_MODEL_NAME = "gbyuvd/ChemMiniQ3-SAbRLo"
    DEFAULT_LOCAL_DIR = "./chemq3_model"

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.config = None
        # GPU when available; load_model() also switches to float16 in that case.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def download_and_setup_model(self, model_name=DEFAULT_MODEL_NAME, local_dir=DEFAULT_LOCAL_DIR):
        """Download a snapshot of the model repository into ``local_dir``.

        Args:
            model_name: Hugging Face Hub repo id.
            local_dir: Directory the snapshot files are written to.

        Returns:
            ``Path`` of the downloaded snapshot, or ``None`` if the download failed.
        """
        print(f"📥 Downloading model files from {model_name}...")
        try:
            # ``resume_download`` is deprecated in huggingface_hub (interrupted
            # downloads are always resumed now), so it is no longer passed.
            model_path = huggingface_hub.snapshot_download(
                repo_id=model_name,
                local_dir=local_dir,
                local_files_only=False,
            )
        except Exception as e:
            print(f"❌ Download failed: {e}")
            return None
        print(f"✅ Model downloaded to: {model_path}")
        return Path(model_path)

    def load_custom_modules(self, model_path):
        """Import the model's custom Python files directly from the snapshot.

        Each module is put into ``sys.modules`` under its plain name *before* it
        is executed (the importlib "import a source file directly" recipe). A
        later file that does ``import configuration_chemq3mtp`` then reuses the
        same module object instead of loading a second copy whose classes would
        not be identical to the ones registered with transformers.

        Args:
            model_path: Directory containing the custom ``.py`` files.

        Returns:
            Dict of module name -> module, or ``None`` if a file is missing or
            fails to import.
        """
        model_path = Path(model_path)
        # Also make the snapshot importable by name for imports inside those files.
        if str(model_path) not in sys.path:
            sys.path.insert(0, str(model_path))

        required_files = {
            'configuration_chemq3mtp.py': 'configuration_chemq3mtp',
            'modeling_chemq3mtp.py': 'modeling_chemq3mtp',
            'FastChemTokenizerHF.py': 'FastChemTokenizerHF',
        }
        loaded_modules = {}
        for filename, module_name in required_files.items():
            file_path = model_path / filename
            if not file_path.exists():
                print(f"❌ Required file not found: {filename}")
                return None
            try:
                spec = importlib.util.spec_from_file_location(module_name, file_path)
                module = importlib.util.module_from_spec(spec)
                sys.modules[module_name] = module
                spec.loader.exec_module(module)
            except Exception as e:
                # Don't leave a half-initialised module behind for later imports.
                sys.modules.pop(module_name, None)
                print(f" ❌ Failed to load {filename}: {e}")
                return None
            loaded_modules[module_name] = module
            print(f" ✅ Loaded {filename}")
        return loaded_modules

    def register_model_components(self, loaded_modules):
        """Register the custom config/model/tokenizer classes with the transformers Auto* API.

        The config is registered under model type ``"chemq3_mtp"`` and then used
        as the key for the causal-LM and tokenizer mappings.

        Returns:
            ``(config_cls, model_cls, tokenizer_cls)`` on success, otherwise
            ``(None, None, None)``.
        """
        try:
            config_cls = loaded_modules['configuration_chemq3mtp'].ChemQ3MTPConfig
            model_cls = loaded_modules['modeling_chemq3mtp'].ChemQ3MTPForCausalLM
            tokenizer_cls = loaded_modules['FastChemTokenizerHF'].FastChemTokenizerSelfies
            AutoConfig.register("chemq3_mtp", config_cls)
            AutoModelForCausalLM.register(config_cls, model_cls)
            AutoTokenizer.register(config_cls, tokenizer_cls)
        except Exception as e:
            print(f"❌ Registration failed: {e}")
            return None, None, None
        print("✅ Model components registered successfully")
        return config_cls, model_cls, tokenizer_cls

    def load_model(self, model_path):
        """Load config, model and tokenizer from ``model_path`` via the registered classes.

        ``trust_remote_code`` stays ``False``: the custom classes were imported
        and registered locally by the previous steps.

        Returns:
            ``(model, tokenizer, config)``, or ``(None, None, None)`` on failure.
        """
        path_str = str(model_path)
        # Half precision only on GPU; CPU inference stays in float32.
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        try:
            config = AutoConfig.from_pretrained(path_str, trust_remote_code=False)
            model = AutoModelForCausalLM.from_pretrained(
                path_str,
                config=config,
                torch_dtype=dtype,
                trust_remote_code=False,
            )
            tokenizer = AutoTokenizer.from_pretrained(path_str, trust_remote_code=False)
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return None, None, None
        return model, tokenizer, config

    def initialize_model(self, model_name=DEFAULT_MODEL_NAME, local_dir=DEFAULT_LOCAL_DIR):
        """Download, import, register and load the model, then move it to ``self.device``.

        Args:
            model_name: Hugging Face Hub repo id (defaults to the ChemMiniQ3 model).
            local_dir: Directory for the downloaded snapshot.

        Returns:
            ``True`` when the model is ready for generation, ``False`` if any step failed.
        """
        model_path = self.download_and_setup_model(model_name, local_dir)
        if model_path is None:
            return False
        loaded_modules = self.load_custom_modules(model_path)
        if loaded_modules is None:
            return False
        config_class, _model_class, _tokenizer_class = self.register_model_components(loaded_modules)
        if config_class is None:
            return False
        self.model, self.tokenizer, self.config = self.load_model(model_path)
        if self.model is None:
            return False
        self.model = self.model.to(self.device)
        self.model.eval()
        return True

    def calculate_lipinski_properties(self, mol):
        """Compute Rule-of-Five related descriptors for an RDKit molecule.

        A violation is counted for each of: MW > 500, H-bond donors > 5,
        H-bond acceptors > 10, LogP > 5.

        Returns:
            Dict of rounded descriptors plus ``lipinski_violations``; an empty
            dict when ``mol`` is ``None``.
        """
        if mol is None:
            return {}
        molecular_weight = Descriptors.MolWt(mol)
        h_bond_donors = rdMolDescriptors.CalcNumHBD(mol)
        h_bond_acceptors = rdMolDescriptors.CalcNumHBA(mol)
        logp = Descriptors.MolLogP(mol)  # octanol-water partition coefficient
        tpsa = Descriptors.TPSA(mol)  # topological polar surface area
        violations = sum((
            molecular_weight > 500,
            h_bond_donors > 5,
            h_bond_acceptors > 10,
            logp > 5,
        ))
        return {
            'molecular_weight': round(molecular_weight, 2),
            'h_bond_donors': h_bond_donors,
            'h_bond_acceptors': h_bond_acceptors,
            'logp': round(logp, 2),
            'tpsa': round(tpsa, 2),
            'rotatable_bonds': rdMolDescriptors.CalcNumRotatableBonds(mol),
            'heavy_atoms': mol.GetNumHeavyAtoms(),
            'lipinski_violations': violations,
        }

    def _sample_tokens(self, temperature, max_length, top_k):
        """Sample one sequence starting from the ``<s>`` prompt; return the token batch."""
        input_ids = self.tokenizer("<s>", return_tensors="pt").input_ids.to(self.device)
        # Inference only: don't record an autograd graph. The custom MTP
        # generator is not guaranteed to disable gradients itself.
        with torch.no_grad():
            if hasattr(self.model, 'generate_with_logprobs'):
                print("Using MTP-specific generation...")
                outputs = self.model.generate_with_logprobs(
                    input_ids,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_k=top_k,
                    do_sample=True,
                    return_probs=True,
                    tokenizer=self.tokenizer,
                )
                # Token ids are element 2 of the MTP generator's output.
                return outputs[2]
            print("Using standard generation...")
            return self.model.generate(
                input_ids,
                max_length=input_ids.shape[1] + max_length,
                temperature=temperature,
                top_k=top_k,
                do_sample=True,
                pad_token_id=getattr(self.tokenizer, 'pad_token_id', 0),
            )

    @staticmethod
    def _format_properties(props):
        """Render a :meth:`calculate_lipinski_properties` dict as the UI report text."""
        violations = props['lipinski_violations']
        lines = [
            "🧪 Molecular Properties (Lipinski's Rule of Five):",
            f"• Molecular Weight: {props['molecular_weight']} g/mol",
            f"• H-Bond Donors: {props['h_bond_donors']}",
            f"• H-Bond Acceptors: {props['h_bond_acceptors']}",
            f"• LogP: {props['logp']}",
            f"• TPSA: {props['tpsa']} Å²",
            f"• Rotatable Bonds: {props['rotatable_bonds']}",
            f"• Heavy Atoms: {props['heavy_atoms']}",
            f"• Lipinski Violations: {violations}/4",
        ]
        if violations <= 1:
            lines.append("✅ Drug-like molecule (Lipinski compliant)")
        else:
            lines.append(f"⚠️ May have poor bioavailability ({violations} violations)")
        return "\n".join(lines)

    def generate_molecule(self, temperature=1.0, max_length=30, top_k=50):
        """Sample one molecule, decode SELFIES -> SMILES, draw it and compute its properties.

        Args:
            temperature: Sampling temperature.
            max_length: Maximum number of new tokens to generate.
            top_k: Top-k sampling cutoff.

        Returns:
            ``(info_text, PIL image or None, property_text)`` for the three UI outputs.
            Errors are reported in the returned text rather than raised.
        """
        if self.model is None or self.tokenizer is None:
            return "Model not loaded!", None, "❌ Model not loaded"
        try:
            # UI sliders can deliver floats (e.g. 30.0); token counts and top-k
            # must be ints for the generation code.
            gen_tokens = self._sample_tokens(float(temperature), int(max_length), int(top_k))
            generatedmol = self.tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
            selfies_str = generatedmol.replace(' ', '')
            try:
                smiles = sf.decoder(selfies_str)
            except sf.DecoderError:
                # An undecodable SELFIES is a "could not decode" result, not an app error.
                smiles = None
            info_text = f"Generated SELFIES: {selfies_str}\nDecoded SMILES: {smiles}\n"

            if not smiles:
                return (info_text + "⚠️ Could not decode to SMILES", None,
                        "⚠️ Could not calculate properties - could not decode to SMILES")
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return (info_text + "⚠️ Invalid SMILES structure", None,
                        "⚠️ Could not calculate properties - invalid SMILES structure")

            mol_image = Draw.MolToImage(mol, size=(400, 400))
            property_text = self._format_properties(self.calculate_lipinski_properties(mol))
            return info_text + "✅ Valid molecule generated!", mol_image, property_text
        except Exception as e:
            # UI callback boundary: surface the error in the output boxes.
            return f"Error generating molecule: {str(e)}", None, "❌ Error calculating properties"
def create_simple_interface():
    """Build the Gradio UI around a freshly initialised :class:`SimpleMolecularApp`.

    Returns:
        The ``gr.Blocks`` demo, or ``None`` if the model could not be initialised.
    """
    app = SimpleMolecularApp()
    print("Initializing model...")
    if not app.initialize_model():
        print("Failed to initialize model!")
        return None
    print("Model initialized successfully!")

    with gr.Blocks(title="🧪 Simple Molecular Generation", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# 🧪 Simple Molecular Generation with MTP\n"
            "Generate complete molecules using the ChemQ3-MTP model with Lipinski properties."
        )

        with gr.Row():
            with gr.Column():
                temp_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=1.0, step=0.1,
                    label="Temperature", info="Higher = more random",
                )
                # gr.Slider has no ``precision`` argument (that belongs to
                # gr.Number); an integer ``step`` keeps these values whole.
                length_slider = gr.Slider(
                    minimum=10, maximum=50, value=30, step=1,
                    label="Max Length", info="Max tokens to generate",
                )
                topk_slider = gr.Slider(
                    minimum=10, maximum=100, value=50, step=1,
                    label="Top-K", info="Sampling diversity",
                )
                generate_btn = gr.Button("🧪 Generate Molecule", variant="primary")
            with gr.Column():
                mol_info = gr.Textbox(label="Molecule Information", lines=5, interactive=False)
                mol_image = gr.Image(label="Generated Molecule", type="pil")

        # Separate full-width row for the molecular property report.
        with gr.Row():
            with gr.Column():
                property_info = gr.Textbox(
                    label="Molecular Properties (Lipinski's Rule of Five)",
                    lines=10,
                    interactive=False,
                )

        inputs = [temp_slider, length_slider, topk_slider]
        outputs = [mol_info, mol_image, property_info]
        generate_btn.click(fn=app.generate_molecule, inputs=inputs, outputs=outputs)

        gr.Examples(
            examples=[
                [1.0, 30, 50],  # Default
                [0.8, 25, 40],  # More focused
                [1.5, 35, 60],  # More random
            ],
            inputs=inputs,
            fn=app.generate_molecule,
            outputs=outputs,
            # Generation is random sampling: caching would replay one frozen
            # sample per example (and run the model for every example at startup).
            cache_examples=False,
        )
    return demo
if __name__ == "__main__":
    demo = create_simple_interface()
    if demo is None:
        # SystemExit with a message prints it to stderr and exits with status 1,
        # so a supervising process can tell start-up failed.
        raise SystemExit("Failed to create interface!")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)