| import gradio as gr | |
| import torch | |
| import re | |
# Module-level model/tokenizer handles, populated once by init() at startup.
# They start as None so the heavy HF downloads happen in init(), not at import.
model = None
tokenizer = None
def init():
    """Load the fine-tuned mT5 spell-correction model and its tokenizer.

    Populates the module-level ``model`` and ``tokenizer`` globals.
    Reads the ``HF_TOKEN`` environment variable for (optionally gated)
    Hugging Face Hub access; ``None`` is passed when unset, which works
    for public repos.
    """
    from transformers import MT5ForConditionalGeneration, T5TokenizerFast
    import os
    global model, tokenizer
    # Hub auth token; may be None for anonymous access to public models.
    hf_token = os.environ.get("HF_TOKEN")
    model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
    model.eval()  # inference only: disables dropout etc.
    # Base mT5 tokenizer; the fine-tuned checkpoint above shares its vocab.
    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
    # '<ZWJ>' is a sentinel for U+200D (zero-width joiner) so it survives
    # tokenization; correct() substitutes it in and back out.
    # NOTE(review): the token is added here but model.resize_token_embeddings
    # is never called — presumably the fine-tuned checkpoint already reserves
    # an embedding for it (mT5 has spare vocab slots); confirm against the
    # training setup.
    tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
def correct(text):
    """Spell-correct *text* with the fine-tuned mT5 model.

    Zero-width joiners (U+200D, significant in complex scripts such as
    Sinhala) are swapped for the '<ZWJ>' sentinel token before encoding so
    they survive tokenization, then restored after decoding.

    Args:
        text: Raw input string to correct.

    Returns:
        The corrected string, with special tokens stripped (except the
        restored zero-width joiners) and newlines removed.
    """
    # Protect ZWJ characters with the sentinel added in init().
    text = re.sub(r'\u200d', '<ZWJ>', text)
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding='do_not_pad',
        # BUG FIX: without an active truncation strategy, transformers
        # ignores max_length at encode time, so over-long inputs reached
        # generate() untruncated.
        truncation=True,
        max_length=1024,
    )
    # Greedy decoding (num_beams=1, do_sample=False) for determinism.
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=1,
            do_sample=False,
        )
    prediction = outputs[0]
    # Drop all special tokens (pad/eos/unk, ...) EXCEPT the ZWJ sentinel,
    # which we must keep to restore the original joiners below.
    special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
    all_special_ids = set(tokenizer.all_special_ids)
    filtered_tokens = [
        token for token in prediction.cpu().tolist()
        if token == special_token_id_to_keep or token not in all_special_ids
    ]
    # skip_special_tokens=False so the kept '<ZWJ>' sentinel is decoded as text.
    prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
    # Restore ZWJ; '\s?' also swallows the space the decoder may emit after
    # a special token.
    return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
# Load model/tokenizer once at startup, then serve a simple text-in/text-out UI.
init()
# `demo` at module level is the conventional entry point name for Gradio apps
# (e.g. on Hugging Face Spaces) — do not rename.
demo = gr.Interface(fn=correct, inputs="text", outputs="text")
demo.launch()