LipSyncAI / app.py
AlserFurma's picture
Update app.py
7802c36 verified
import gradio as gr
import os
from PIL import Image
import tempfile
from gradio_client import Client, handle_file
import torch
from transformers import VitsModel, AutoTokenizer, pipeline
import scipy.io.wavfile as wavfile
import traceback
import random
import time
import numpy as np
from pydub import AudioSegment
# =========================
# Параметры
# =========================
TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# =========================
# Загрузка моделей
# =========================
try:
# TTS модель (казахский)
tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device)
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
# Настройка конфигурации для более приятного и выразительного голоса
tts_model.config.noise_scale = 0.5 # Меньше шума для чище голоса
tts_model.config.noise_scale_duration = 0.8 # Вариация в длительности
tts_model.config.speaking_rate = 0.9 # Чуть медленнее для выразительности
# Перевод ru -> kk
translator = pipeline(
"translation",
model="facebook/nllb-200-distilled-600M",
device=0 if device == "cuda" else -1
)
# Модель для генерации вопросов
qa_model = pipeline(
"text2text-generation",
model="google/flan-t5-small",
device=0 if device == "cuda" else -1
)
print("✅ Все модели успешно загружены!")
except Exception as e:
raise RuntimeError(f"❌ Ошибка загрузки моделей: {str(e)}")
# =========================
# Вспомогательные функции
# =========================
def generate_quiz(text: str):
""" Генерирует один вопрос и два варианта ответа на основе текста.
Алгоритмы:
1. Базовый: случайное предложение и первые слова.
2. Пропуск ключевого слова.
3. Вопрос о числе/дате.
"""
try:
sentences = [s.strip() for s in text.replace("!", ".").replace("?", ".").split(".") if s.strip()]
if len(sentences) < 1:
raise ValueError("Текст слишком короткий")
algo = random.choice([1, 2, 3])
# ------------------------
if algo == 1: # Базовый алгоритм
question_sentence = random.choice(sentences)
words = question_sentence.split()
if len(words) <= 3:
correct_answer = question_sentence
question = "Что сказано в этом предложении?"
else:
question = "Что сказано в тексте?"
correct_answer = " ".join(words[:6]) + ("..." if len(words) > 6 else "")
wrong_sentence = random.choice([s for s in sentences if s != question_sentence] or ["Другая информация"])
wrong_words = wrong_sentence.split()
wrong_answer = " ".join(wrong_words[:6]) + ("..." if len(wrong_words) > 6 else "")
# ------------------------
elif algo == 2: # Пропуск ключевого слова
question_sentence = random.choice(sentences)
words = question_sentence.split()
if len(words) > 2:
key_word = random.choice(words)
question = question_sentence.replace(key_word, "_____")
correct_answer = key_word
wrong_answer = random.choice([w for w in words if w != key_word] or ["другое"])
else:
# fallback
return generate_quiz(text)
# ------------------------
elif algo == 3: # Вопрос о числе или дате
import re
question_sentence = random.choice(sentences)
numbers = re.findall(r'\d+', question_sentence)
if numbers:
number = random.choice(numbers)
question = question_sentence.replace(number, "_____")
correct_answer = number
wrong_answer = str(int(number)+random.randint(1,5))
else:
# fallback к базовому
return generate_quiz(text)
options = [correct_answer, wrong_answer]
random.shuffle(options)
return question, options, correct_answer
except Exception as e:
raise ValueError(f"Ошибка генерации вопроса: {str(e)}")
def synthesize_audio(text_ru: str):
"""Переводит русскую строку на казахский, синтезирует аудио и возвращает путь к файлу .wav"""
translation = translator(text_ru, src_lang="rus_Cyrl", tgt_lang="kaz_Cyrl")
text_kk = translation[0]["translation_text"]
inputs = tts_tokenizer(text_kk, return_tensors="pt").to(device)
with torch.no_grad():
output = tts_model(**inputs)
waveform = output.waveform.squeeze().cpu().numpy()
waveform /= np.max(np.abs(waveform)) + 1e-8 # Нормализация для лучшего качества
audio = (waveform * 32767).astype('int16')
sampling_rate = getattr(tts_model.config, 'sampling_rate', 22050)
tmpf = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
wavfile.write(tmpf.name, sampling_rate, audio)
tmpf.close()
return tmpf.name
def concatenate_audio_files(audio_files):
"""Объединяет несколько аудио файлов в один с паузами между ними"""
combined = AudioSegment.empty()
pause = AudioSegment.silent(duration=1000) # 1 секунда паузы
for i, audio_file in enumerate(audio_files):
audio = AudioSegment.from_wav(audio_file)
combined += audio
if i < len(audio_files) - 1: # Не добавляем паузу после последнего файла
combined += pause
output_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
combined.export(output_file.name, format='wav')
output_file.close()
return output_file.name
def make_talking_head(image_path: str, audio_path: str, max_retries=3):
"""Вызывает SkyReels/Talking Head space и возвращает путь или URL видео."""
for attempt in range(max_retries):
try:
client = Client(TALKING_HEAD_SPACE)
result = client.predict(
image_path=handle_file(image_path),
audio_path=handle_file(audio_path),
guidance_scale=3.0,
steps=10,
api_name="/process_image_audio"
)
print(f"Result type: {type(result)}")
print(f"Result content: {result}")
if isinstance(result, tuple):
video_path = result[0]
if isinstance(video_path, dict) and "video" in video_path:
return video_path["video"]
elif isinstance(video_path, str):
return video_path
else:
for item in result:
if isinstance(item, str) and (item.endswith('.mp4') or item.endswith('.webm') or os.path.exists(str(item))):
return item
raise ValueError(f"Не удалось найти видео в результате: {result}")
elif isinstance(result, dict) and "video" in result:
return result["video"]
elif isinstance(result, str):
return result
else:
raise ValueError(f"Unexpected talking head result: {type(result)}, value: {result}")
except Exception as e:
if attempt < max_retries - 1:
print(f"Попытка {attempt + 1} не удалась: {e}. Повторяю через 2 секунды...")
time.sleep(2)
else:
raise Exception(f"Ошибка после {max_retries} попыток: {str(e)}")
# =========================
# Основные обработчики для Gradio
# =========================
def start_lesson(image: Image.Image, text: str, state):
"""Создает одно видео: текст лекции + вопрос с вариантами ответа"""
if image is None or not text.strip() or len(text) > 500:
return None, "Пожалуйста, загрузите фото и введите текст лекции (до 500 символов)", gr.update(visible=False), gr.update(visible=False), state
try:
# Сохраняем изображение
tmpimg = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
if image.mode != 'RGB':
image = image.convert('RGB')
image.save(tmpimg.name)
tmpimg.close()
image_path = tmpimg.name
# Генерируем вопрос
question, options, correct = generate_quiz(text)
# Создаем три аудио файла
audio_files = []
# 1. Текст лекции
audio1 = synthesize_audio(text)
audio_files.append(audio1)
# 2. Вопрос
question_text = f"А теперь вопрос: {question}"
audio2 = synthesize_audio(question_text)
audio_files.append(audio2)
# 3. Варианты ответа
options_text = f"Первый вариант: {options[0]}. Второй вариант: {options[1]}"
audio3 = synthesize_audio(options_text)
audio_files.append(audio3)
# Объединяем все аудио в одно
combined_audio = concatenate_audio_files(audio_files)
# Создаем одно видео с полным содержанием
video_path = make_talking_head(image_path, combined_audio)
# Сохраняем состояние
state_data = {
'image_path': image_path,
'correct': correct,
'options': options,
'question': question
}
# Удаляем временные аудио файлы
for audio_file in audio_files:
try:
os.remove(audio_file)
except:
pass
try:
os.remove(combined_audio)
except:
pass
question_display = f"**Вопрос:** {question}"
return (
video_path,
question_display,
gr.update(value=options[0], visible=True),
gr.update(value=options[1], visible=True),
state_data
)
except Exception as e:
traceback.print_exc()
return None, f"❌ Ошибка: {e}", gr.update(visible=False), gr.update(visible=False), state
def answer_selected(selected_option: str, state):
"""Генерирует реакцию лектора и показывает в том же окне"""
if not state:
return None, "❌ Ошибка: отсутствует состояние урока"
try:
correct = state.get('correct')
image_path = state.get('image_path')
if selected_option == correct:
reaction_ru = "Правильно! Отлично справились!"
display_message = "✅ **Дұрыс! Жарайсың!**"
else:
reaction_ru = f"К сожалению неправильно. Правильный ответ был: {correct}"
display_message = f"❌ **Қате!** Дұрыс жауап: **{correct}**"
# Создаем аудио с реакцией
audio_path = synthesize_audio(reaction_ru)
# Создаем видео с реакцией
reaction_video = make_talking_head(image_path, audio_path)
try:
os.remove(audio_path)
except:
pass
return reaction_video, display_message
except Exception as e:
traceback.print_exc()
return None, f"❌ Ошибка: {e}"
# =========================
# Gradio UI
# =========================
title = "🎓 Интерактивті Бейне Мұғалім TiлГен"
description = (
"**Қалай жұмыс істейді:**\n"
"1. Мұғалімнің суретін жүктеп, дәріс мәтінін енгізіңіз (орыс, 500 таңбаға дейін)\n"
"2. 'Сабақты бастау' түймесін басыңыз-мұғалім мәтінді оқып, сұрақ қояды\n"
"3. Дұрыс жауапты таңдаңыз-мұғалім сіздің жауабыңызға жауап береді"
)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(f"# {title}\n{description}")
with gr.Row():
with gr.Column(scale=1):
inp_image = gr.Image(type='pil', label='📸 Мұғалімнің суреті')
inp_text = gr.Textbox(
lines=5,
label='📝 Дәріс мәтіні (орыс.)',
placeholder='Дәріс мәтінін енгізіңіз...',
info="Ең көбі 500 таңба"
)
btn_start = gr.Button("🚀 Сабақты бастау", variant="primary", size="lg")
with gr.Column(scale=1):
out_video = gr.Video(label='🎬 Мұғалімнің видеосы')
out_question = gr.Markdown("")
with gr.Row():
btn_opt1 = gr.Button("Вариант 1", visible=False, size="lg", variant="secondary")
btn_opt2 = gr.Button("Вариант 2", visible=False, size="lg", variant="secondary")
out_result = gr.Markdown("")
lesson_state = gr.State({})
# Запуск урока
btn_start.click(
fn=start_lesson,
inputs=[inp_image, inp_text, lesson_state],
outputs=[out_video, out_question, btn_opt1, btn_opt2, lesson_state]
)
# Обработка ответов
def handle_answer_1(state):
option = state.get('options', [''])[0] if state else ''
return answer_selected(option, state)
def handle_answer_2(state):
option = state.get('options', [''])[1] if state and len(state.get('options', [])) > 1 else ''
return answer_selected(option, state)
btn_opt1.click(
fn=handle_answer_1,
inputs=[lesson_state],
outputs=[out_video, out_result]
)
btn_opt2.click(
fn=handle_answer_2,
inputs=[lesson_state],
outputs=[out_video, out_result]
)
if __name__ == '__main__':
demo.launch(server_name="0.0.0.0", server_port=7860)