Spaces:
Running
Running
| # app.py | |
| import gradio as gr | |
| from utils import BEST_ACROSS, ROIS, load_roi_image, get_hypotheses_for_selection, load_images_and_views, load_roi_for_selection | |
# Model variants offered in the "Select Model" dropdown.
MODELS = [
    "SAE",
    "SAE_ICA",
    "ICA",
    "PCA",
    "NMF",
    "Voxels",
]

# Selection shown when the page first loads.
INITIAL_MODEL = "SAE_ICA"
INITIAL_ROI = "EBA"
INITIAL_HYP = "Airline livery"

# Pre-load everything for the initial selection so the components are
# populated before any callback has run.
INITIAL_FID_IMAGES, INITIAL_VIEW_IMAGES = load_images_and_views(
    INITIAL_MODEL,
    INITIAL_ROI,
    INITIAL_HYP,
)
# The "Viewing Angle" panel starts on the first brain view, if any exist.
if INITIAL_VIEW_IMAGES:
    INITIAL_VIEW_IMAGE = INITIAL_VIEW_IMAGES[0]
else:
    INITIAL_VIEW_IMAGE = None
INITIAL_ROI_IMAGE = load_roi_image(roi_name=INITIAL_ROI)
| # -------- CALLBACKS -------- | |
def update_hypotheses(model, roi):
    """Rebuild the hypothesis dropdown for a model/ROI pair.

    The dropdown's choices are replaced by the hypotheses available for the
    selection. The first one is pre-selected, or nothing when the list is empty.

    Args:
        model: Selected model name.
        roi: Selected ROI name.

    Returns:
        A ``gr.Dropdown`` update carrying the new choices and value.
    """
    choices = get_hypotheses_for_selection(model, roi)
    default = None
    if choices:
        default = choices[0]
    return gr.Dropdown(choices=choices, value=default)
def update_activating_images(model, roi, hyp):
    """Load the gallery images, brain views and ROI image for a selection.

    Args:
        model: Selected model name.
        roi: Selected ROI name. The value "Across" (compared case-insensitively)
            is not a concrete region: the ROI to display is looked up per
            model/hypothesis in ``BEST_ACROSS``.
        hyp: Selected hypothesis name.

    Returns:
        Tuple ``(fid_images, brain_images, initial_view, roi_img)``, in the
        order of the outputs ``[image_gallery, brain_state, view_image,
        roi_image]``. ``initial_view`` is the first brain view, or ``None``
        when there are none. ``roi_img`` is ``None`` when no concrete ROI name
        could be resolved.
    """
    fid_images, brain_images = load_images_and_views(model, roi, hyp)

    # For "Across", the region to show depends on the model/hypothesis pair.
    if roi and roi.lower() == "across":
        roi_name = BEST_ACROSS.get(model, {}).get(hyp, {}).get("roi")
    else:
        roi_name = roi

    # A missing BEST_ACROSS entry (or an empty ROI dropdown) clears the ROI
    # panel rather than calling load_roi_image with roi_name=None.
    roi_img = load_roi_image(roi_name=roi_name) if roi_name else None

    initial_view = brain_images[0] if brain_images else None
    return fid_images, brain_images, initial_view, roi_img
def update_view_on_click(evt: gr.SelectData, brain_images):
    """Return the brain view matching the clicked gallery item.

    Args:
        evt: Gradio selection event. ``evt.index`` is the position of the
            clicked gallery item.
        brain_images: Brain views stored in ``brain_state``, expected to be in
            the same order as the gallery images.

    Returns:
        The brain view at the clicked position. Returns ``None`` when there are
        no views, or when the index is missing or has no corresponding view.
    """
    idx = evt.index
    # Guard against the gallery and the stored views being out of sync (or a
    # non-integer index) instead of raising IndexError/TypeError in the handler.
    if not brain_images or not isinstance(idx, int):
        return None
    if not 0 <= idx < len(brain_images):
        return None
    return brain_images[idx]
| # -------- UI -------- | |
with gr.Blocks(fill_width=True) as demo:
    # Brain views for the images currently in the gallery, kept server-side so
    # a gallery click can be mapped to its matching view by index.
    brain_state = gr.State(value=INITIAL_VIEW_IMAGES)

    gr.Markdown("## BrainExplore: Visual Concept Explorer")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Select Model", choices=MODELS, value=INITIAL_MODEL, scale=1
        )
        roi_dropdown = gr.Dropdown(
            label="Select ROI", choices=ROIS, value=INITIAL_ROI, scale=1
        )
        hyp_dropdown = gr.Dropdown(
            label="Select Hypothesis",
            choices=get_hypotheses_for_selection(INITIAL_MODEL, INITIAL_ROI),
            value=INITIAL_HYP,
            scale=1,
        )

    with gr.Row(equal_height=True):
        # LEFT — ROI + VIEW image
        with gr.Column(scale=1):
            roi_image = gr.Image(
                label="Selected Brain Region (ROI)",
                value=INITIAL_ROI_IMAGE,  # pre-loaded for INITIAL_ROI
                interactive=False,
                height=300,
            )
            view_image = gr.Image(
                label="Viewing Angle",
                value=INITIAL_VIEW_IMAGE,  # first view of the initial selection
                interactive=False,
                height=300,
            )
        # RIGHT — gallery
        with gr.Column(scale=3):
            image_gallery = gr.Gallery(
                label="Top Activating Images",
                columns=None,
                value=INITIAL_FID_IMAGES,
                rows=1,
                object_fit="contain",
                preview=True,
                height=600,
                allow_preview=True,
                elem_id="gallery",
            )

    # -------- CALLBACK WIRING --------
    selection_inputs = [model_dropdown, roi_dropdown, hyp_dropdown]
    gallery_outputs = [image_gallery, brain_state, view_image, roi_image]

    # Model/ROI change: refresh the hypothesis list first, THEN reload the
    # images with the refreshed hypothesis. Running both in parallel loaded
    # images for a stale hypothesis that may not exist for the new selection,
    # and whichever request finished last overwrote the other. The reload also
    # sets roi_image, so no separate ROI-image handler is needed.
    for selector in (model_dropdown, roi_dropdown):
        selector.change(
            fn=update_hypotheses,
            inputs=[model_dropdown, roi_dropdown],
            outputs=[hyp_dropdown],
        ).then(
            fn=update_activating_images,
            inputs=selection_inputs,
            outputs=gallery_outputs,
        )

    # Hypothesis picked by the user. `.input` rather than `.change`, so the
    # programmatic update made by update_hypotheses above does not trigger a
    # second, duplicate reload.
    hyp_dropdown.input(
        fn=update_activating_images,
        inputs=selection_inputs,
        outputs=gallery_outputs,
    )

    # Update view image when clicking a different gallery item
    image_gallery.select(
        fn=update_view_on_click,
        inputs=[brain_state],
        outputs=[view_image],
    )

if __name__ == "__main__":
    demo.launch()