import gradio as gr
import numpy as np
import torch
import transformers
from pathlib import Path
from transformers import pipeline
from transformers.utils import logging
# Log
#logging.set_verbosity_debug()
logger = logging.get_logger("transformers")
# Pipelines
device = 0 if torch.cuda.is_available() else "cpu"
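# Note: transformers.pipeline accepts either a CUDA device index (0 = first GPU)
# or the string "cpu", so the line above picks the GPU when available, else the CPU.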
## Automatic Speech Recognition
## https://huggingface.co/docs/transformers/task_summary#automatic-speech-recognition
## Requires ffmpeg to be installed
asr_model = "openai/whisper-tiny"
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    # torch_dtype=torch.float16,
    device=device
)
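# Hedged usage sketch (not exercised by the app itself; "sample.wav" is hypothetical):
# the ASR pipeline accepts a local file path or URL (decoded via ffmpeg) or a dict of
# raw samples, e.g.
#   asr("sample.wav")["text"]
#   asr({"sampling_rate": 16_000, "raw": np.zeros(16_000, dtype=np.float32)})["text"]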
## Token Classification / Named Entity Recognition
## https://huggingface.co/docs/transformers/task_summary#token-classification
tc_model = "dslim/distilbert-NER"
tc = pipeline(
    "token-classification",  # ner
    model=tc_model,
    device=device
)
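# Hedged output sketch: the token-classification pipeline returns one dict per detected
# token, roughly {"entity": "B-PER", "score": 0.99, "word": "Sarah", "start": 11, "end": 16},
# a shape gr.HighlightedText can consume through the {"text": ..., "entities": [...]} dict
# returned by transcribe() below.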
# ---
# Transformers
# https://www.gradio.app/main/docs/gradio/audio#behavior
# As input component, gr.Audio passes audio data to the function in one of these formats:
# - a str or pathlib.Path filepath
# - or URL to an audio file,
# - or a bytes object (recommended for streaming),
# - or a tuple of (sample rate in Hz, audio data as numpy array)
def transcribe(audio: str | Path | bytes | tuple[int, np.ndarray] | None):
    if audio is None:
        # Return a value gr.HighlightedText can render
        return {"text": "...", "entities": []}
    # TODO Manage str/Path
    logger.debug("Transcribe")
    text = ""
    # https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__
    # Whisper's expected input format for raw samples differs from the tuple provided by the Gradio audio component
    if asr_model.startswith("openai/whisper"):
        inputs = {"sampling_rate": audio[0], "raw": audio[1]} if isinstance(audio, tuple) else audio
        transcript = asr(inputs)
        text = transcript["text"]
    logger.debug("Tokenize:[" + text + "]")
    entities = tc(text)
    #logger.debug("Classify:[" + entities + "]")
    # TODO Add Text Classification for sentiment analysis
    return {"text": text, "entities": entities}
# ---
# Gradio
## Interfaces
# https://www.gradio.app/main/docs/gradio/audio
input_audio = gr.Audio(
    sources=["upload", "microphone"],
    show_share_button=False
)
## App
gradio_app = gr.Interface(
    transcribe,
    inputs=[
        input_audio
    ],
    outputs=[
        gr.HighlightedText()
    ],
    title="ASRNERSBX",
    description=(
        "Transcribe, Tokenize, Classify"
    ),
    flagging_mode="never"
)
## Start!
gradio_app.launch()