import gradio as gr
import pandas as pd
import plotly.express as px
import time
import os
import tempfile
import requests
import duckdb
import json
from datasets import load_dataset
from huggingface_hub import logout as hf_logout
from gradio_rangeslider import RangeSlider

# --- Constants ---
TOP_K_CHOICES = list(range(5, 51, 5))
HF_DATASET_ID = "evijit/paperverse_daily_data"
# Direct parquet file URL (public)
PARQUET_URL = "https://huggingface.co/datasets/evijit/paperverse_daily_data/resolve/main/papers_with_semantic_taxonomy.parquet"
TAXONOMY_JSON_PATH = "integrated_ml_taxonomy.json"

# Simple content filters derived from the new dataset
TAG_FILTER_CHOICES = [
    "None",
    "Has Code",
    "Has Media",
    "Has Organization",
]

# Load taxonomy from JSON file
def load_taxonomy():
    """Load the ML taxonomy from JSON file."""
    try:
        with open(TAXONOMY_JSON_PATH, 'r') as f:
            taxonomy = json.load(f)
        # Extract choices for dropdowns
        categories = sorted(taxonomy.keys())
        # Build subcategories and topics
        all_subcategories = set()
        all_topics = set()
        for category, subcats in taxonomy.items():
            for subcat, topics in subcats.items():
                all_subcategories.add(subcat)
                all_topics.update(topics)
        return {
            'categories': ["All"] + categories,
            'subcategories': ["All"] + sorted(all_subcategories),
            'topics': ["All"] + sorted(all_topics),
            'taxonomy': taxonomy
        }
    except Exception as e:
        print(f"Error loading taxonomy from JSON: {e}")
        return {
            'categories': ["All"],
            'subcategories': ["All"],
            'topics': ["All"],
            'taxonomy': {}
        }


TAXONOMY_DATA = load_taxonomy()
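
# Illustrative shape of the taxonomy JSON consumed above (the names here are
# hypothetical; the loops in load_taxonomy() only assume
# category -> subcategory -> list of topics):
# {
#   "Natural Language Processing": {
#     "Text Generation": ["Summarization", "Dialogue Systems"]
#   }
# }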

def _first_non_null(*values):
    for v in values:
        if v is None:
            continue
        # treat empty strings as null-ish
        if isinstance(v, str) and v.strip() == "":
            continue
        return v
    return None

def _get_nested(row, *paths):
    """Try multiple dotted paths in a row that may contain dicts; return first non-null."""
    for path in paths:
        cur = row
        ok = True
        for key in path.split('.'):
            if isinstance(cur, dict) and key in cur:
                cur = cur[key]
            else:
                ok = False
                break
        if ok and cur is not None:
            return cur
    return None
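
# Example usage of the two helpers above (row contents illustrative):
#   _first_non_null(None, "", "Qwen")                                       -> "Qwen"
#   _get_nested({'paper_organization': {'name': 'Qwen'}},
#               'paper_organization.name', 'organization.name')             -> "Qwen"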

def load_datasets_data():
    """Load the PaperVerse Daily dataset from the Hugging Face Hub and normalize columns used by the app."""
    start_time = time.time()
    print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}")
    try:
        # First try: direct parquet download (avoids any auth header issues)
        try:
            print(f"Trying direct parquet download: {PARQUET_URL}")
            with requests.get(PARQUET_URL, stream=True, timeout=120) as resp:
                resp.raise_for_status()
                with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmpf:
                    for chunk in resp.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            tmpf.write(chunk)
                    tmp_path = tmpf.name
            try:
                # Use DuckDB to read parquet to avoid pyarrow decoding issues
                df = duckdb.query(f"SELECT * FROM read_parquet('{tmp_path}')").df()
            finally:
                try:
                    os.remove(tmp_path)
                except Exception:
                    pass
            print("Loaded DataFrame from direct parquet download via DuckDB.")
        except Exception as direct_e:
            print(f"Direct parquet load failed: {direct_e}. Falling back to datasets loader...")
            # Force anonymous access in case an invalid cached token is present:
            # clear any token environment variables that could inject a bad Authorization header
            for env_key in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HF_HUB_TOKEN"):
                if os.environ.pop(env_key, None) is not None:
                    print(f"Cleared env var: {env_key}")
            # Prefer explicit train split when available
            try:
                dataset_obj = load_dataset(HF_DATASET_ID, split="train", token=None)
            except TypeError:
                dataset_obj = load_dataset(HF_DATASET_ID, split="train", use_auth_token=False)
            except Exception:
                # Fallback: load all splits and pick the first available
                try:
                    dataset_obj = load_dataset(HF_DATASET_ID, token=None)
                except TypeError:
                    dataset_obj = load_dataset(HF_DATASET_ID, use_auth_token=False)
            # Handle both Dataset and DatasetDict
            try:
                # If it's a Dataset (single split), this will work
                df = dataset_obj.to_pandas()
            except AttributeError:
                # Otherwise assume DatasetDict and take the first split
                first_split = list(dataset_obj.keys())[0]
                df = dataset_obj[first_split].to_pandas()
        # --- Normalize expected columns for the visualization ---
        # organization: prefer top-level organization_name, then paper_organization.name/fullname, else Unknown
        if 'organization_name' in df.columns:
            org_series = df['organization_name']
        else:
            # try nested dicts commonly produced by HF datasets
            org_series = df.apply(
                lambda r: _first_non_null(
                    _get_nested(r, 'paper_organization.name'),
                    _get_nested(r, 'paper_organization.fullname'),
                    _get_nested(r, 'organization.name'),
                    _get_nested(r, 'organization.fullname')
                ), axis=1
            )
        df['organization'] = org_series.fillna('Unknown')

        # Extract organization avatar/logo
        if 'organization_name' in df.columns:
            # Try to get avatar from paper_organization or organization struct
            def _get_avatar(row):
                for path in ['paper_organization.avatar', 'organization.avatar']:
                    av = _get_nested(row, path)
                    if av and isinstance(av, str) and av.strip():
                        return av
                return None
            org_avatar_series = df.apply(_get_avatar, axis=1)
        else:
            # index=df.index keeps the assignment aligned even for non-default indexes
            org_avatar_series = pd.Series([None] * len(df), index=df.index)
        df['organization_avatar'] = org_avatar_series
        # id for each paper row
        cand_cols = [
            'paper_id', 'paper_discussionId', 'key'
        ]
        id_val = None
        for c in cand_cols:
            if c in df.columns:
                id_val = df[c]
                break
        if id_val is None:
            # fallback to title + positional index
            if 'paper_title' in df.columns:
                df['id'] = df['paper_title'].astype(str) + '_' + df.reset_index().index.astype(str)
            elif 'title' in df.columns:
                df['id'] = df['title'].astype(str) + '_' + df.reset_index().index.astype(str)
            else:
                df['id'] = df.reset_index().index.astype(str)
        else:
            df['id'] = id_val.astype(str)

        # numeric metrics used for aggregation
        def _to_num(col_name):
            if col_name in df.columns:
                return pd.to_numeric(df[col_name], errors='coerce').fillna(0.0)
            return pd.Series([0.0] * len(df), index=df.index)
        df['paper_upvotes'] = _to_num('paper_upvotes')
        df['numComments'] = _to_num('numComments')
        df['paper_githubStars'] = _to_num('paper_githubStars')
        # computed boolean filters
        def _has_code(row):
            # Check for GitHub repo
            try:
                gh = row['paper_githubRepo'] if 'paper_githubRepo' in row and pd.notna(row['paper_githubRepo']) else None
                if isinstance(gh, str) and len(gh.strip()) > 0:
                    return True
            except Exception:
                pass
            # Check for project page
            try:
                pp = row.get('paper_projectPage', None)
                if isinstance(pp, str) and len(pp.strip()) > 0 and pp.strip().lower() != 'n/a':
                    return True
            except Exception:
                pass
            return False

        def _has_media(row):
            for c in ['paper_mediaUrls', 'mediaUrls']:
                try:
                    v = row[c]
                    if isinstance(v, list) and len(v) > 0:
                        return True
                    # some providers store arrays as strings like "[... ]"
                    if isinstance(v, str) and v.strip().startswith('[') and len(v.strip()) > 2:
                        return True
                except Exception:
                    continue
            return False

        df['has_code'] = df.apply(_has_code, axis=1)
        df['has_media'] = df.apply(_has_media, axis=1)
        df['has_organization'] = df['organization'].astype(str).str.strip().ne('Unknown')
        # Process publishedAt field for date filtering
        if 'publishedAt' in df.columns:
            df['publishedAt_dt'] = pd.to_datetime(df['publishedAt'], errors='coerce')
        else:
            df['publishedAt_dt'] = pd.NaT

        # Ensure topic hierarchy columns exist and are strings
        for col_name, default_val in [
            ('primary_category', 'Unknown'),
            ('primary_subcategory', 'Unknown'),
            ('primary_topic', 'Unknown'),
        ]:
            if col_name not in df.columns:
                df[col_name] = default_val
            else:
                df[col_name] = df[col_name].fillna(default_val).astype(str).replace({'': default_val})
        # Create a human-friendly paper label for treemap leaves from the best available title
        def _pick_title(row):
            try:
                t1 = row['paper_title'] if 'paper_title' in row and pd.notna(row['paper_title']) and str(row['paper_title']).strip() != '' else None
            except Exception:
                t1 = None
            if t1 is not None:
                return str(t1)
            try:
                t2 = row['title'] if 'title' in row and pd.notna(row['title']) and str(row['title']).strip() != '' else None
            except Exception:
                t2 = None
            return str(t2) if t2 is not None else 'Untitled'

        def _pick_topic(row):
            # Prefer primary_topic, else first of taxonomy_topics
            # (helper kept for reuse; the leaf label below uses the title only)
            try:
                pt = row['primary_topic'] if 'primary_topic' in row and pd.notna(row['primary_topic']) and str(row['primary_topic']).strip() != '' else None
            except Exception:
                pt = None
            if pt is not None:
                return str(pt)
            try:
                tt = row['taxonomy_topics'] if 'taxonomy_topics' in row else None
                if isinstance(tt, list) and len(tt) > 0:
                    return str(tt[0])
                # Sometimes arrays are serialized as strings like "[ ... ]"
                if isinstance(tt, str) and tt.strip().startswith('[') and len(tt.strip()) > 2:
                    # naive parse for first quoted token
                    inner = tt.strip().lstrip('[').rstrip(']')
                    first = inner.split(',')[0].strip().strip('"\'')
                    return first if first else 'No topic'
            except Exception:
                pass
            return 'No topic'

        titles = df.apply(_pick_title, axis=1)
        df['paper_label'] = titles.astype(str)

        # Build a Topic Chain for hover details
        df['topic_chain'] = (
            df['primary_category'].astype(str) + ' > ' +
            df['primary_subcategory'].astype(str) + ' > ' +
            df['primary_topic'].astype(str)
        )
        # Ensure link fields exist for hover details
        for link_col in ['paper_githubRepo', 'paper_projectPage']:
            if link_col not in df.columns:
                df[link_col] = 'N/A'
            else:
                df[link_col] = df[link_col].fillna('N/A').replace({'': 'N/A'})

        msg = f"Successfully loaded dataset in {time.time() - start_time:.2f}s."
        print(msg)
        return df, True, msg
    except Exception as e:
        # If we encountered invalid credentials, try logging out programmatically and retry once anonymously
        if "Invalid credentials" in str(e) or "401 Client Error" in str(e):
            try:
                print("Encountered auth error; attempting to clear cached token and retry anonymously...")
                hf_logout()
                try:
                    dataset_dict = load_dataset(HF_DATASET_ID, token=None)
                except TypeError:
                    dataset_dict = load_dataset(HF_DATASET_ID, use_auth_token=False)
                df = dataset_dict[list(dataset_dict.keys())[0]].to_pandas()
                msg = f"Successfully loaded dataset after clearing token in {time.time() - start_time:.2f}s."
                print(msg)
                return df, True, msg
            except Exception as e2:
                err_msg = f"Failed to load dataset after retry. Error: {e2} (initial: {e})"
                print(err_msg)
                return pd.DataFrame(), False, err_msg
        err_msg = f"Failed to load dataset. Error: {e}"
        print(err_msg)
        return pd.DataFrame(), False, err_msg
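
# Columns guaranteed by load_datasets_data() on the success path and relied on below:
# organization, organization_avatar, id, paper_upvotes, numComments, paper_githubStars,
# has_code, has_media, has_organization, publishedAt_dt, primary_category,
# primary_subcategory, primary_topic, paper_label, topic_chain, paper_githubRepo,
# paper_projectPage. (The auth-retry fallback returns the raw frame without these.)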

def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None, group_by='organization', date_range=None):
    """
    Filter data and prepare it for a multi-level treemap.
    - Preserves individual papers for the top K organizations.
    - Groups all other organizations into a single "Other" category.
    - date_range: tuple of (min_timestamp, max_timestamp) in seconds since epoch
    """
    if df is None or df.empty:
        return pd.DataFrame()
    filtered_df = df.copy()

    # Apply date range filter
    if date_range is not None and 'publishedAt_dt' in filtered_df.columns:
        min_ts, max_ts = date_range
        min_date = pd.to_datetime(min_ts, unit='s')
        max_date = pd.to_datetime(max_ts, unit='s')
        # Remove timezone info for comparison if publishedAt_dt is tz-naive
        if filtered_df['publishedAt_dt'].dt.tz is None:
            min_date = min_date.tz_localize(None)
            max_date = max_date.tz_localize(None)
        filtered_df = filtered_df[
            (filtered_df['publishedAt_dt'] >= min_date) &
            (filtered_df['publishedAt_dt'] <= max_date)
        ]

    col_map = {
        "Has Code": "has_code",
        "Has Media": "has_media",
        "Has Organization": "has_organization",
    }
    if tag_filter and tag_filter != "None" and tag_filter in col_map:
        if col_map[tag_filter] in filtered_df.columns:
            filtered_df = filtered_df[filtered_df[col_map[tag_filter]]]
    if filtered_df.empty:
        return pd.DataFrame()

    if count_by not in filtered_df.columns:
        filtered_df[count_by] = 0.0
    filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors='coerce').fillna(0.0)

    if group_by == 'organization':
        all_org_totals = filtered_df.groupby("organization")[count_by].sum()
        top_org_names = all_org_totals.nlargest(top_k, keep='first').index.tolist()
        top_orgs_df = filtered_df[filtered_df['organization'].isin(top_org_names)].copy()
        other_total = all_org_totals[~all_org_totals.index.isin(top_org_names)].sum()
        final_df_for_plot = top_orgs_df
        if other_total > 0:
            other_row = pd.DataFrame([{
                'organization': 'Other',
                'paper_label': 'Other',
                'primary_category': 'Other',
                'primary_subcategory': 'Other',
                'primary_topic': 'Other',
                'topic_chain': 'Other > Other > Other',
                'paper_githubRepo': 'N/A',
                'paper_projectPage': 'N/A',
                'organization_avatar': None,
                count_by: other_total
            }])
            final_df_for_plot = pd.concat([final_df_for_plot, other_row], ignore_index=True)
        if skip_cats and len(skip_cats) > 0:
            final_df_for_plot = final_df_for_plot[~final_df_for_plot['organization'].isin(skip_cats)]
        final_df_for_plot["root"] = "papers"
        return final_df_for_plot
    else:
        # Topic grouping: apply top-k to topic combinations and handle skip list
        topic_totals = filtered_df.groupby(['primary_category', 'primary_subcategory', 'primary_topic'])[count_by].sum()
        top_topics = topic_totals.nlargest(top_k, keep='first').index.tolist()
        # Filter to top topics
        top_topics_df = filtered_df[
            filtered_df.apply(
                lambda r: (r['primary_category'], r['primary_subcategory'], r['primary_topic']) in top_topics,
                axis=1
            )
        ].copy()
        # Apply skip filter (skip by primary_topic name)
        if skip_cats and len(skip_cats) > 0:
            top_topics_df = top_topics_df[~top_topics_df['primary_topic'].isin(skip_cats)]
        top_topics_df["root"] = "papers"
        return top_topics_df
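
# Example call (mirrors the defaults the UI controller below passes through):
#   make_treemap_data(df, "paper_upvotes", top_k=25, tag_filter=None,
#                     skip_cats=["unaffiliated", "Other"], group_by="organization")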

def create_treemap(treemap_data, count_by, title=None, path=None, metric_label=None):
    """Generate the Plotly treemap figure from the prepared data."""
    if treemap_data.empty or treemap_data[count_by].sum() <= 0:
        fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
        fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
        return fig
    if path is None:
        path = ["root", "organization", "paper_label"]

    # Add custom data columns as regular columns for Plotly to access.
    # This ensures all nodes (including intermediate hierarchy nodes) have these fields.
    # Ensure organization_avatar column exists (for search details, not hover)
    if 'organization_avatar' not in treemap_data.columns:
        treemap_data['organization_avatar'] = None

    fig = px.treemap(
        treemap_data,
        path=path,
        values=count_by,
        hover_data={
            'primary_category': True,
            'primary_subcategory': True,
            'primary_topic': True,
            'paper_githubRepo': True,
            'paper_projectPage': True,
        },
        title=title,
        color_discrete_sequence=px.colors.qualitative.Plotly
    )
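    # The hovertemplate below indexes customdata positionally: Plotly Express appends
    # the hover_data columns to customdata in dict order, so customdata[0..4] map to
    # primary_category, primary_subcategory, primary_topic, paper_githubRepo,
    # paper_projectPage. Keep hover_data and the template indices in sync when editing.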
    fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
    display_metric = metric_label if metric_label else count_by
    # Clean hover without organization avatar (images shown in search details instead)
    fig.update_traces(
        textinfo="label+value",
        hovertemplate=(
            "<b>%{label}</b><br>"
            + "%{value:,} " + display_metric +
            "<br><br><b>Topic Hierarchy:</b><br>"
            + "%{customdata[0]} > %{customdata[1]} > %{customdata[2]}<br>"
            + "<br><b>Links:</b><br>"
            + "GitHub: %{customdata[3]}<br>"
            + "Project: %{customdata[4]}"
            + "<extra></extra>"
        ),
    )
    return fig

# --- Gradio UI Blocks ---
with gr.Blocks(
    title="📄 PaperVerse Daily Explorer",
    fill_width=True,
    css="""
    /* Hide the timestamp numbers on the range slider */
    #date-range-slider-wrapper .head,
    #date-range-slider-wrapper div[data-testid="range-slider"] > span {
        display: none !important;
    }
    """
) as demo:
    datasets_data_state = gr.State(pd.DataFrame())
    loading_complete_state = gr.State(False)
    date_range_state = gr.State(None)  # Store min/max timestamps

    with gr.Row():
        gr.Markdown("# 📄 PaperVerse Daily Explorer")

    with gr.Tabs():
        with gr.Tab("📊 Treemap Visualization"):
            with gr.Row():
                with gr.Column(scale=1):
                    count_by_dropdown = gr.Dropdown(
                        label="Metric",
                        choices=[
                            ("Upvotes", "paper_upvotes"),
                            ("Comments", "numComments"),
                        ],
                        value="paper_upvotes",
                    )
                    group_by_dropdown = gr.Dropdown(
                        label="Group by",
                        choices=[("Organization", "organization"), ("Topic", "topic")],
                        value="organization",
                    )
                    gr.Markdown("**Filters**")
                    filter_code = gr.Checkbox(label="Has Code", value=False)
                    filter_media = gr.Checkbox(label="Has Media", value=False)
                    filter_org = gr.Checkbox(label="Has Organization", value=False)
                    gr.Markdown("**Date Range**")
                    date_range_slider = RangeSlider(
                        minimum=0,
                        maximum=100,
                        value=(0, 100),
                        label="Paper Release Date Range",
                        interactive=True,
                        elem_id="date-range-slider-wrapper"
                    )
                    date_range_display = gr.Markdown("Loading date range...")
                    top_k_dropdown = gr.Dropdown(label="Number of Top Organizations", choices=TOP_K_CHOICES, value=25)
                    category_filter_dropdown = gr.Dropdown(label="Primary Category", choices=["All"], value="All")
                    subcategory_filter_dropdown = gr.Dropdown(label="Primary Subcategory", choices=["All"], value="All")
                    topic_filter_dropdown = gr.Dropdown(label="Primary Topic", choices=["All"], value="All")
                    skip_cats_textbox = gr.Textbox(label="Organizations to Skip", value="unaffiliated, Other")
                    generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)
                with gr.Column(scale=3):
                    plot_output = gr.Plot()
                    status_message_md = gr.Markdown("Initializing...")
                    data_info_md = gr.Markdown("")
        with gr.Tab("🔍 Paper Search"):
            with gr.Column():
                gr.Markdown("### 🔍 Search Papers and Organizations")
                with gr.Row():
                    search_item = gr.Textbox(
                        label="Search Organization or Paper",
                        placeholder="Type organization name or paper title to see details...",
                        scale=4
                    )
                    search_button = gr.Button("Show Details", scale=1, variant="secondary")
                selected_info_html = gr.HTML(value="<p style='color: gray;'>Enter an organization name or paper title above to see details</p>")

    def _update_button_interactivity(is_loaded_flag):
        return gr.update(interactive=is_loaded_flag)

    def _format_date_range(date_range_tuple, date_range_value):
        """Convert slider values to readable date range text."""
        if date_range_tuple is None:
            return "Date range unavailable"
        min_ts, max_ts = date_range_tuple
        selected_min, selected_max = date_range_value
        # The slider values are already epoch timestamps in seconds
        min_date = pd.to_datetime(selected_min, unit='s')
        max_date = pd.to_datetime(selected_max, unit='s')
        return f"**Selected Range:** {min_date.strftime('%B %d, %Y')} to {max_date.strftime('%B %d, %Y')}"

    def _toggle_labels_by_grouping(group_by_value):
        # Update labels based on grouping mode
        if group_by_value == 'topic':
            top_k_label = "Number of Top Topics"
            skip_label = "Topics to Skip"
            skip_value = ""  # Clear skip box for topics
        else:
            top_k_label = "Number of Top Organizations"
            skip_label = "Organizations to Skip"
            skip_value = "unaffiliated, Other"  # Default orgs to skip
        return (
            gr.update(label=top_k_label),
            gr.update(label=skip_label, value=skip_value)
        )

    ## CHANGE: New combined function to load data and generate the initial plot on startup.
    def load_and_generate_initial_plot(progress=gr.Progress()):
        progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")
        # --- Part 1: Data Loading ---
        try:
            current_df, load_success_flag, status_msg_from_load = load_datasets_data()
            if load_success_flag:
                progress(0.5, desc="Processing data...")
                date_display = "Pre-processed (date unavailable)"
                if 'data_download_timestamp' in current_df.columns and pd.notna(current_df['data_download_timestamp'].iloc[0]):
                    ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
                    date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')
                # Calculate date range from publishedAt_dt
                min_ts = 0
                max_ts = 100
                date_range_text = "Date range unavailable"
                date_range_tuple = None
                if 'publishedAt_dt' in current_df.columns:
                    valid_dates = current_df['publishedAt_dt'].dropna()
                    if len(valid_dates) > 0:
                        min_date = valid_dates.min()
                        max_date = valid_dates.max()
                        min_ts = int(min_date.timestamp())
                        max_ts = int(max_date.timestamp())
                        date_range_tuple = (min_ts, max_ts)
                        date_range_text = f"**Full Range:** {min_date.strftime('%B %d, %Y')} to {max_date.strftime('%B %d, %Y')}"
                data_info_text = (f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
                                  f"- Status: {status_msg_from_load}\n"
                                  f"- Total records loaded: {len(current_df):,}\n"
                                  f"- Data as of: {date_display}\n")
            else:
                data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
                min_ts = 0
                max_ts = 100
                date_range_text = "Date range unavailable"
                date_range_tuple = None
        except Exception as e:
            status_msg_from_load = f"An unexpected error occurred: {str(e)}"
            data_info_text = f"### Critical Error\n- {status_msg_from_load}"
            load_success_flag = False
            current_df = pd.DataFrame()  # Ensure df is empty on failure
            min_ts = 0
            max_ts = 100
            date_range_text = "Date range unavailable"
            date_range_tuple = None
            print(f"Critical error in load_and_generate_initial_plot: {e}")

        # --- Part 2: Generate Initial Plot ---
        progress(0.6, desc="Generating initial plot...")
        # Defaults matching UI definitions
        default_metric = "paper_upvotes"
        default_tag = "None"
        default_k = 25
        default_group_by = "organization"
        default_skip_cats = "unaffiliated, Other"
        # Use taxonomy from JSON instead of calculating from dataset
        cat_choices = TAXONOMY_DATA['categories']
        subcat_choices = TAXONOMY_DATA['subcategories']
        topic_choices = TAXONOMY_DATA['topics']
        # Reuse the existing controller function for plotting (with date range set to None for initial load)
        initial_plot, initial_status = ui_generate_plot_controller(
            default_metric, False, False, False, default_k, default_group_by, "All", "All", "All", default_skip_cats, None, current_df, progress
        )
        # Also update taxonomy dropdown choices
        return (
            current_df,
            load_success_flag,
            data_info_text,
            initial_status,
            initial_plot,
            gr.update(choices=cat_choices, value="All"),
            gr.update(choices=subcat_choices, value="All"),
            gr.update(choices=topic_choices, value="All"),
            gr.update(minimum=min_ts, maximum=max_ts, value=(min_ts, max_ts)),
            date_range_text,
            date_range_tuple,
        )

    def ui_generate_plot_controller(metric_choice, has_code, has_media, has_org,
                                    k_orgs, group_by_choice,
                                    category_choice, subcategory_choice, topic_choice,
                                    skip_cats_input, date_range, df_current_datasets, progress=gr.Progress()):
        if df_current_datasets is None or df_current_datasets.empty:
            return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded. Cannot generate plot."
        progress(0.1, desc="Aggregating data...")
        cats_to_skip = [cat.strip() for cat in skip_cats_input.split(',') if cat.strip()]

        # Apply content filters (checkboxes)
        df_filtered = df_current_datasets.copy()
        if has_code:
            df_filtered = df_filtered[df_filtered['has_code']]
        if has_media:
            df_filtered = df_filtered[df_filtered['has_media']]
        if has_org:
            df_filtered = df_filtered[df_filtered['has_organization']]

        # Apply taxonomy filters
        if category_choice and category_choice != 'All':
            df_filtered = df_filtered[df_filtered['primary_category'] == category_choice]
        if subcategory_choice and subcategory_choice != 'All':
            df_filtered = df_filtered[df_filtered['primary_subcategory'] == subcategory_choice]
        if topic_choice and topic_choice != 'All':
            df_filtered = df_filtered[df_filtered['primary_topic'] == topic_choice]

        treemap_df = make_treemap_data(df_filtered, metric_choice, k_orgs, None, cats_to_skip, group_by_choice, date_range)
        progress(0.7, desc="Generating plot...")
        title_labels = {
            "paper_upvotes": "Upvotes",
            "numComments": "Comments",
        }
        if group_by_choice == "topic":
            chart_title = f"PaperVerse Daily - {title_labels.get(metric_choice, metric_choice)} by Topic"
            path = ["root", "primary_category", "primary_subcategory", "primary_topic", "paper_label"]
        else:
            chart_title = f"PaperVerse Daily - {title_labels.get(metric_choice, metric_choice)} by Organization"
            path = ["root", "organization", "paper_label"]
        plotly_fig = create_treemap(
            treemap_df,
            metric_choice,
            chart_title,
            path=path,
            metric_label=title_labels.get(metric_choice, metric_choice),
        )
        if treemap_df.empty:
            plot_stats_md = "No data matches the selected filters. Please try different options."
        else:
            total_value_in_plot = treemap_df[metric_choice].sum()
            total_items_in_plot = treemap_df[treemap_df['paper_label'] != 'Other']['paper_label'].nunique()
            if group_by_choice == "topic":
                group_count = treemap_df[["primary_category", "primary_subcategory", "primary_topic"]].drop_duplicates().shape[0]
                group_line = f"**Topics Shown**: {group_count:,} unique triplets"
            else:
                group_line = f"**Organizations Shown**: {treemap_df['organization'].nunique():,}"
            plot_stats_md = (
                f"## Plot Statistics\n- {group_line}\n"
                f"- **Individual Papers Shown**: {total_items_in_plot:,}\n"
                f"- **Total {title_labels.get(metric_choice, metric_choice)} in plot**: {int(total_value_in_plot):,}"
            )
        return plotly_fig, plot_stats_md

    # --- Event Wiring ---
    ## CHANGE: Updated demo.load to call the new function and to add plot_output to the outputs list.
    demo.load(
        fn=load_and_generate_initial_plot,
        inputs=[],
        outputs=[
            datasets_data_state,
            loading_complete_state,
            data_info_md,
            status_message_md,
            plot_output,
            category_filter_dropdown,
            subcategory_filter_dropdown,
            topic_filter_dropdown,
            date_range_slider,
            date_range_display,
            date_range_state,
        ]
    )

    loading_complete_state.change(
        fn=_update_button_interactivity,
        inputs=loading_complete_state,
        outputs=generate_plot_button
    )

    # Update labels based on grouping mode
    group_by_dropdown.change(
        fn=_toggle_labels_by_grouping,
        inputs=group_by_dropdown,
        outputs=[top_k_dropdown, skip_cats_textbox],
    )

    # Update date range display when slider changes
    date_range_slider.change(
        fn=_format_date_range,
        inputs=[date_range_state, date_range_slider],
        outputs=date_range_display,
        show_progress="hidden"
    )

    def handle_search_details(search_text, df_current):
        """Search for an organization or paper and show detailed information."""
        if not search_text or not search_text.strip():
            return "<p style='color: gray;'>Please enter a search term</p>"
        if df_current is None or df_current.empty:
            return "<p style='color: gray;'>No data available</p>"
        search_text = search_text.strip()
        try:
            # Try to find matching rows by organization or paper title (case-insensitive partial match)
            matching_rows = df_current[
                df_current['organization'].str.contains(search_text, case=False, na=False) |
                df_current['paper_label'].str.contains(search_text, case=False, na=False) |
                (df_current['paper_title'].str.contains(search_text, case=False, na=False) if 'paper_title' in df_current.columns else False)
            ]
            if matching_rows.empty:
                return f"<p style='color: orange;'>No results found for: <b>{search_text}</b></p><p style='color: gray;'>Try searching for an organization name (e.g., 'Qwen', 'Meta') or paper title keyword</p>"

            # Build the info panel HTML showing all matching results
            num_results = len(matching_rows)
            html_parts = [
                f"<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; background: #f9f9f9; max-height: 600px; overflow-y: auto;'>",
                f"<h3 style='margin: 0 0 15px 0; color: #333;'>🔍 Found {num_results} result{'s' if num_results > 1 else ''} for: <span style='color: #0366d6;'>{search_text}</span></h3>"
            ]
            # Limit to first 20 results to avoid too much content
            display_rows = matching_rows.head(20)
            for idx, (_, row) in enumerate(display_rows.iterrows()):
                # Add separator between results
                if idx > 0:
                    html_parts.append("<hr style='margin: 15px 0; border: none; border-top: 1px solid #ddd;'/>")
                html_parts.append("<div style='margin-bottom: 10px; overflow: auto;'>")
                # Get organization avatar from precomputed column
                org_avatar = row.get('organization_avatar')
                # Organization logo if available
                if org_avatar and isinstance(org_avatar, str) and org_avatar.strip() and org_avatar.strip().lower() not in ['none', 'null', 'n/a', '']:
                    html_parts.append(f"<img src='{org_avatar}' style='max-width: 60px; max-height: 60px; border-radius: 50%; margin-bottom: 8px; float: left; margin-right: 12px; border: 2px solid #ddd;' onerror=\"this.style.display='none'\"/>")
                # Get paper thumbnail (direct field from schema)
                paper_thumbnail = row.get('thumbnail')
                # Paper thumbnail if available
                if paper_thumbnail and isinstance(paper_thumbnail, str) and paper_thumbnail.strip() and paper_thumbnail.strip().lower() not in ['none', 'null', 'n/a', '']:
                    html_parts.append(f"<img src='{paper_thumbnail}' style='max-width: 120px; max-height: 120px; border-radius: 8px; margin-bottom: 8px; float: right; margin-left: 12px; border: 1px solid #ddd;' onerror=\"this.style.display='none'\"/>")
                # Organization name
                org_name = row.get('organization', 'Unknown')
                html_parts.append(f"<p style='margin: 0 0 5px 0; font-weight: bold; color: #333;'>🏢 {org_name}</p>")
                # Paper title
                paper_title = row.get('paper_title', row.get('title', 'Untitled'))
                html_parts.append(f"<p style='margin: 0 0 5px 0; color: #555; font-size: 0.95em;'>📄 {paper_title}</p>")
                # Topic hierarchy
                category = row.get('primary_category', 'Unknown')
                subcategory = row.get('primary_subcategory', 'Unknown')
                topic = row.get('primary_topic', 'Unknown')
                html_parts.append(f"<p style='margin: 0 0 5px 0; font-size: 0.9em; color: #666;'><b>Topics:</b> {category} → {subcategory} → {topic}</p>")
                # Metrics (cast to int so normalized floats don't render as e.g. "12.0")
                upvotes = int(row.get('paper_upvotes', 0))
                comments = int(row.get('numComments', 0))
                html_parts.append(f"<p style='margin: 0 0 5px 0; font-size: 0.9em;'><b>Metrics:</b> ⬆️ {upvotes:,} upvotes | 💬 {comments:,} comments</p>")
                # Links
                github = row.get('paper_githubRepo')
                project = row.get('paper_projectPage')
                links = []
                if github and isinstance(github, str) and github.strip() and github.strip().lower() not in ['n/a', 'none']:
                    links.append(f"<a href='{github}' target='_blank' style='color: #0366d6; margin-right: 15px;'>🔗 GitHub</a>")
                if project and isinstance(project, str) and project.strip() and project.strip().lower() not in ['n/a', 'none']:
                    links.append(f"<a href='{project}' target='_blank' style='color: #0366d6;'>🌐 Project</a>")
                if links:
                    html_parts.append(f"<p style='margin: 0; font-size: 0.9em;'>{' '.join(links)}</p>")
                html_parts.append("<div style='clear: both;'></div>")
                html_parts.append("</div>")
            if num_results > 20:
                html_parts.append(f"<p style='margin-top: 15px; color: #666; font-style: italic;'>Showing first 20 of {num_results} results. Refine your search for fewer results.</p>")
            html_parts.append("</div>")
            return "".join(html_parts)
        except Exception as e:
            return f"<p style='color: red;'>Error displaying details: {str(e)}</p>"

    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=[
            count_by_dropdown,
            filter_code,
            filter_media,
            filter_org,
            top_k_dropdown,
            group_by_dropdown,
            category_filter_dropdown,
            subcategory_filter_dropdown,
            topic_filter_dropdown,
            skip_cats_textbox,
            date_range_slider,
            datasets_data_state,
        ],
        outputs=[plot_output, status_message_md]
    )

    # Handle search button for showing details
    search_button.click(
        fn=handle_search_details,
        inputs=[search_item, datasets_data_state],
        outputs=[selected_info_html]
    )
    # Also trigger on Enter key in search box
    search_item.submit(
        fn=handle_search_details,
        inputs=[search_item, datasets_data_state],
        outputs=[selected_info_html]
    )


if __name__ == "__main__":
    print("Application starting...")
    demo.queue().launch()