import gradio as gr
import torch
import sys
import os
from pathlib import Path
import importlib.util
import huggingface_hub
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import Descriptors, rdMolDescriptors
import numpy as np
from PIL import Image
import io


class SimpleMolecularApp:
    """Download, load and sample from ChemMiniQ3-SAbRLo checkpoints.

    The custom model/tokenizer Python modules and the tokenizer files are
    fetched once from the main Hub repo into ``./shared_modules`` and reused
    for every checkpoint; each checkpoint only contributes its config and
    weights.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.config = None
        # Key into SUPPORTED_MODELS of the checkpoint currently held in self.model.
        self.current_model_name = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Custom modules are executed once and then reused across model switches.
        self._loaded_modules = None

        # Shared modules path
        self.SHARED_MODULES_DIR = Path("./shared_modules")
        self.SHARED_MODULES_DIR.mkdir(exist_ok=True)

        # Download shared modules and tokenizer files once
        self._ensure_shared_modules()

        # Supported models
        self.SUPPORTED_MODELS = {
            "Non-RL Pretrained": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo",
                "subfolder": None,
                "local_dir": "./chemq3_non_rl_model",
            },
            "RL Finetuned – Step 9000": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo",
                "subfolder": "ppo_checkpoints/model_step_9000",
                "local_dir": "./chemq3_rlnp_step9000",
            },
            "RL Pareto Finetuned – Step 2250": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo-RL-checkpoints",
                "subfolder": "checkpoints-pareto/model_step_2250",
                "local_dir": "./chemq3_rlp_step2250",
            },
            "RL Pareto Finetuned – Step 4500": {
                "repo_id": "gbyuvd/ChemMiniQ3-SAbRLo-RL-checkpoints",
                "subfolder": "checkpoints-pareto/model_step_4500",
                "local_dir": "./chemq3_rlp_step4500",
            },
        }

    def _ensure_shared_modules(self):
        """Download the shared Python modules and tokenizer files from the main repo.

        Files land in ``self.SHARED_MODULES_DIR``; only ``*.py`` and
        tokenizer-related files are requested.
        """
        print("📦 Downloading shared modules and tokenizer files from main repo...")
        huggingface_hub.snapshot_download(
            repo_id="gbyuvd/ChemMiniQ3-SAbRLo",
            local_dir=str(self.SHARED_MODULES_DIR),
            allow_patterns=[
                "*.py",
                "tokenizer*",
                "vocab*",
                "merges*",
                "special_tokens*",
                "tokenizer_config*",
            ],
        )
        print("✅ Shared modules and tokenizer files ready!")

    def _download_checkpoint(self, repo_id, subfolder, local_dir):
        """Fetch one checkpoint from the Hub and return its local model directory.

        Args:
            repo_id: Hub repository id.
            subfolder: Path of the checkpoint inside the repo, or None when
                the checkpoint lives at the repo root.
            local_dir: Local directory to download into.

        Returns:
            ``Path`` of the directory holding ``config.json`` and the weights.
        """
        if subfolder:
            # Only config/weights are needed; the tokenizer comes from the shared modules.
            allow_patterns = [
                f"{subfolder}/config.json",
                f"{subfolder}/pytorch_model.bin",
                f"{subfolder}/model.safetensors",
                f"{subfolder}/generation_config.json",
            ]
            huggingface_hub.snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                allow_patterns=allow_patterns,
            )
            return Path(local_dir) / subfolder

        # Non-RL: no subfolder, so the whole repository is downloaded.
        huggingface_hub.snapshot_download(repo_id=repo_id, local_dir=local_dir)
        return Path(local_dir)

    def load_model_by_name(self, model_key):
        """Download (if needed) and activate the checkpoint named ``model_key``.

        The new model, tokenizer and config are only installed on ``self``
        after every step has succeeded, so a failed switch leaves the
        previously loaded model in place and usable.

        Args:
            model_key: A key of ``self.SUPPORTED_MODELS``.

        Returns:
            True if the requested model is loaded and active, False otherwise.
        """
        if model_key not in self.SUPPORTED_MODELS:
            print(f"❌ Unknown model: {model_key}")
            return False

        entry = self.SUPPORTED_MODELS[model_key]
        repo_id = entry["repo_id"]
        subfolder = entry["subfolder"]
        local_dir = entry["local_dir"]

        print(f"🔄 Loading model: {model_key} from {repo_id}")

        try:
            model_path = self._download_checkpoint(repo_id, subfolder, local_dir)
        except Exception as e:
            # Network / Hub errors are reported like every other failure: via False.
            print(f"❌ Download failed for {model_key}: {e}")
            return False

        if not model_path.exists():
            print(f"❌ Model path not found: {model_path}")
            return False

        # Load custom modules from shared path
        loaded_modules = self.load_custom_modules_from_shared()
        if not loaded_modules:
            return False

        # Register model components
        config_class, _, _ = self.register_model_components(loaded_modules)
        if not config_class:
            return False

        # Load into locals first so a failure cannot clobber the active model.
        model, tokenizer, config = self.load_model_with_shared_tokenizer(model_path)
        if model is None:
            return False

        try:
            model = model.to(self.device)
        except RuntimeError as e:
            # e.g. CUDA out-of-memory; keep the previous model active.
            print(f"❌ Could not move model to {self.device}: {e}")
            return False
        model.eval()

        self.model, self.tokenizer, self.config = model, tokenizer, config
        self.current_model_name = model_key
        print(f"✅ Successfully loaded: {model_key}")
        return True

    def load_custom_modules_from_shared(self):
        """Import the custom config/model/tokenizer modules from the shared directory.

        The modules are executed only once; later calls return the cached
        result so every checkpoint shares the same class objects.

        Returns:
            Dict mapping module name to module object, or None if a required
            file is missing or fails to import.
        """
        cached = getattr(self, "_loaded_modules", None)
        if cached is not None:
            return cached

        shared_dir = str(self.SHARED_MODULES_DIR)
        if shared_dir not in sys.path:
            sys.path.insert(0, shared_dir)

        required_files = {
            "configuration_chemq3mtp.py": "configuration_chemq3mtp",
            "modeling_chemq3mtp.py": "modeling_chemq3mtp",
            "FastChemTokenizerHF.py": "FastChemTokenizerHF",
        }

        loaded_modules = {}
        for filename, module_name in required_files.items():
            file_path = self.SHARED_MODULES_DIR / filename
            if not file_path.exists():
                print(f"❌ Required file not found in shared modules: {filename}")
                return None

            try:
                spec = importlib.util.spec_from_file_location(module_name, file_path)
                module = importlib.util.module_from_spec(spec)
                # Register before executing (importlib "import a source file
                # directly" recipe) so a sibling module doing
                # ``import <module_name>`` receives this same module object
                # instead of a second copy with distinct classes.
                sys.modules[module_name] = module
                spec.loader.exec_module(module)
            except Exception as e:
                sys.modules.pop(module_name, None)
                print(f"   ❌ Failed to load {filename}: {e}")
                return None

            loaded_modules[module_name] = module
            print(f"   ✅ Loaded {filename} from shared modules")

        self._loaded_modules = loaded_modules
        return loaded_modules

    def register_model_components(self, loaded_modules):
        """Register the custom config, model and tokenizer with the transformers Auto classes.

        Args:
            loaded_modules: Mapping returned by ``load_custom_modules_from_shared``.

        Returns:
            ``(config_class, model_class, tokenizer_class)`` on success, or
            ``(None, None, None)`` if registration failed.
        """
        try:
            ChemQ3MTPConfig = loaded_modules["configuration_chemq3mtp"].ChemQ3MTPConfig
            ChemQ3MTPForCausalLM = loaded_modules["modeling_chemq3mtp"].ChemQ3MTPForCausalLM
            FastChemTokenizerSelfies = loaded_modules["FastChemTokenizerHF"].FastChemTokenizerSelfies

            AutoConfig.register("chemq3_mtp", ChemQ3MTPConfig)
            AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)
            AutoTokenizer.register(ChemQ3MTPConfig, FastChemTokenizerSelfies)

            print("✅ Model components registered successfully")
            return ChemQ3MTPConfig, ChemQ3MTPForCausalLM, FastChemTokenizerSelfies
        except Exception as e:
            print(f"❌ Registration failed: {e}")
            return None, None, None

    def load_model_with_shared_tokenizer(self, model_path):
        """Load config and weights from ``model_path`` plus the shared tokenizer.

        Args:
            model_path: Directory containing the checkpoint's ``config.json``
                and weights.

        Returns:
            ``(model, tokenizer, config)``, or ``(None, None, None)`` on failure.
        """
        try:
            config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=False)
            model = AutoModelForCausalLM.from_pretrained(
                str(model_path),
                config=config,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=False,
            )

            # Reuse the already-imported tokenizer class (cached modules).
            loaded_modules = self.load_custom_modules_from_shared()
            if not loaded_modules:
                raise RuntimeError("custom tokenizer module is unavailable")
            tokenizer_class = loaded_modules["FastChemTokenizerHF"].FastChemTokenizerSelfies
            tokenizer = tokenizer_class.from_pretrained(str(self.SHARED_MODULES_DIR))

            return model, tokenizer, config
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return None, None, None

    def calculate_lipinski_properties(self, mol):
        """Compute Lipinski Rule-of-Five descriptors plus TPSA, rotatable bonds and heavy atoms.

        Args:
            mol: An RDKit ``Mol`` or None.

        Returns:
            Dict of rounded descriptors and the number of violated Lipinski
            rules (0-4); an empty dict when ``mol`` is None.
        """
        if mol is None:
            return {}

        molecular_weight = Descriptors.MolWt(mol)
        h_bond_donors = rdMolDescriptors.CalcNumHBD(mol)
        h_bond_acceptors = rdMolDescriptors.CalcNumHBA(mol)
        logp = Descriptors.MolLogP(mol)  # Octanol-water partition coefficient
        tpsa = Descriptors.TPSA(mol)  # Topological Polar Surface Area
        rotatable_bonds = rdMolDescriptors.CalcNumRotatableBonds(mol)
        heavy_atoms = mol.GetNumHeavyAtoms()

        # One point per violated Rule-of-Five threshold.
        violations = sum((
            molecular_weight > 500,
            h_bond_donors > 5,
            h_bond_acceptors > 10,
            logp > 5,
        ))

        return {
            "molecular_weight": round(molecular_weight, 2),
            "h_bond_donors": h_bond_donors,
            "h_bond_acceptors": h_bond_acceptors,
            "logp": round(logp, 2),
            "tpsa": round(tpsa, 2),
            "rotatable_bonds": rotatable_bonds,
            "heavy_atoms": heavy_atoms,
            "lipinski_violations": violations,
        }

    @staticmethod
    def _format_properties(props):
        """Render the dict from ``calculate_lipinski_properties`` as display text."""
        lines = [
            "🧪 Molecular Properties (Lipinski's Rule of Five):",
            f"• Molecular Weight: {props['molecular_weight']} g/mol",
            f"• H-Bond Donors: {props['h_bond_donors']}",
            f"• H-Bond Acceptors: {props['h_bond_acceptors']}",
            f"• LogP: {props['logp']}",
            f"• TPSA: {props['tpsa']} Å²",
            f"• Rotatable Bonds: {props['rotatable_bonds']}",
            f"• Heavy Atoms: {props['heavy_atoms']}",
            f"• Lipinski Violations: {props['lipinski_violations']}/4",
        ]
        text = "\n".join(lines) + "\n"

        # Rule of Five: at most one violation is still considered drug-like.
        if props["lipinski_violations"] <= 1:
            text += "✅ Drug-like molecule (Lipinski compliant)"
        else:
            text += f"⚠️ May have poor bioavailability ({props['lipinski_violations']} violations)"
        return text

    def _sample_tokens(self, input_ids, temperature, max_length, top_k):
        """Sample from the active model and return generated token ids (row 0 is decoded).

        Uses the model's MTP-specific ``generate_with_logprobs`` when present,
        otherwise the standard ``generate``.
        """
        if hasattr(self.model, "generate_with_logprobs"):
            print("Using MTP-specific generation...")
            outputs = self.model.generate_with_logprobs(
                input_ids,
                max_new_tokens=max_length,
                temperature=temperature,
                top_k=top_k,
                do_sample=True,
                return_probs=True,
                tokenizer=self.tokenizer,
            )
            # Extract tokens from MTP output (index 2)
            return outputs[2]

        print("Using standard generation...")
        return self.model.generate(
            input_ids,
            max_length=input_ids.shape[1] + max_length,
            temperature=temperature,
            top_k=top_k,
            do_sample=True,
            pad_token_id=getattr(self.tokenizer, "pad_token_id", 0),
        )

    def generate_molecule(self, temperature=1.0, max_length=30, top_k=50):
        """Sample one molecule and describe it for the three Gradio outputs.

        Args:
            temperature: Sampling temperature; higher is more random.
            max_length: Maximum number of new tokens to generate.
            top_k: Number of highest-probability tokens kept when sampling.

        Returns:
            ``(info_text, pil_image_or_None, property_text)``.
        """
        if self.model is None:
            return "Model not loaded!", None, "❌ Model not loaded"

        try:
            # Gradio sliders deliver floats; HF top-k and token budgets must be ints.
            temperature = float(temperature)
            max_length = int(max_length)
            top_k = int(top_k)

            # Generation starts from the tokenizer's encoding of an empty prompt.
            input_ids = self.tokenizer("", return_tensors="pt").input_ids.to(self.device)

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                gen_tokens = self._sample_tokens(input_ids, temperature, max_length, top_k)

            generated = self.tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
            # The decoded text is space-separated SELFIES tokens; join them.
            selfies_str = generated.replace(" ", "")
            smiles = sf.decoder(selfies_str)

            info_text = f"Generated SELFIES: {selfies_str}\n"
            info_text += f"Decoded SMILES: {smiles}\n"

            if not smiles:
                return (
                    info_text + "⚠️ Could not decode to SMILES",
                    None,
                    "⚠️ Could not calculate properties - could not decode to SMILES",
                )

            mol = Chem.MolFromSmiles(smiles)
            if not mol:
                return (
                    info_text + "⚠️ Invalid SMILES structure",
                    None,
                    "⚠️ Could not calculate properties - invalid SMILES structure",
                )

            mol_image = Draw.MolToImage(mol, size=(400, 400))
            property_text = self._format_properties(self.calculate_lipinski_properties(mol))
            info_text += "✅ Valid molecule generated!"
            return info_text, mol_image, property_text

        except Exception as e:
            return f"Error generating molecule: {str(e)}", None, "❌ Error calculating properties"


def create_simple_interface():
    """Build the Gradio Blocks UI with the default model preloaded.

    Returns:
        The ``gr.Blocks`` demo, or None if the default model failed to load.
    """
    app = SimpleMolecularApp()

    # Preload default model (Non-RL)
    default_model = "Non-RL Pretrained"
    print(f"Initializing default model: {default_model}")
    if not app.load_model_by_name(default_model):
        print("Failed to initialize default model!")
        return None
    print("Model initialized successfully!")

    with gr.Blocks(title="🧪 ChemMiniQ3-SAbRLo Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🧪 ChemMiniQ3-SAbRLo Demo

            Generate molecules using either the **Non-RL pretrained model** or
            **RL-finetuned checkpoints** optimized with a **ParetoRewards controller**.
            """
        )

        with gr.Row():
            model_choice = gr.Dropdown(
                choices=list(app.SUPPORTED_MODELS.keys()),
                value=default_model,
                label="Select Model",
            )
            load_btn = gr.Button("🔁 Load Selected Model", variant="secondary")

        # Model status indicator
        model_status = gr.Textbox(
            label="Model Status",
            value=f"✅ Current Model: {default_model}",
            interactive=False,
            show_copy_button=True,
        )

        # Generation controls
        with gr.Row():
            with gr.Column():
                temp_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    label="Temperature",
                    info="Higher = more random",
                    step=0.1,
                )
                length_slider = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=30,
                    label="Max Length",
                    info="Max tokens to generate",
                    step=1,
                )
                topk_slider = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=50,
                    label="Top-K",
                    info="Sampling diversity",
                    step=1,
                )
                generate_btn = gr.Button("🧪 Generate Molecule", variant="primary")

            with gr.Column():
                mol_info = gr.Textbox(
                    label="Molecule Information",
                    lines=5,
                    interactive=False,
                )
                mol_image = gr.Image(
                    label="Generated Molecule",
                    type="pil",
                )

        # Molecular properties section
        property_info = gr.Textbox(
            label="Molecular Properties (Lipinski's Rule of Five)",
            lines=10,
            interactive=False,
        )

        def load_model_wrapper(model_name):
            """Switch models and report the outcome in the status box."""
            if app.load_model_by_name(model_name):
                return f"✅ Current Model: {model_name} (Ready to use!)"
            status = f"❌ Failed to load: {model_name}"
            if app.current_model_name:
                # A failed switch keeps the previous model active.
                status += f" (still using: {app.current_model_name})"
            return status

        load_btn.click(
            fn=load_model_wrapper,
            inputs=model_choice,
            outputs=model_status,
        )

        # Generate molecule
        generate_btn.click(
            fn=app.generate_molecule,
            inputs=[temp_slider, length_slider, topk_slider],
            outputs=[mol_info, mol_image, property_info],
        )

        gr.Examples(
            examples=[
                [1.0, 30, 50],
                [0.8, 25, 40],
                [1.5, 35, 60],
            ],
            inputs=[temp_slider, length_slider, topk_slider],
            fn=app.generate_molecule,
            outputs=[mol_info, mol_image, property_info],
            # Never cache: outputs are sampled and depend on the currently loaded model.
            cache_examples=False,
        )

    return demo


if __name__ == "__main__":
    demo = create_simple_interface()
    if demo:
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
    else:
        print("Failed to create interface!")