Spaces:
Running
Running
| """ | |
| Pip's Artist - Image generation with load balancing. | |
| Distributes image generation across multiple providers. | |
| """ | |
| import asyncio | |
| from typing import Optional, Literal | |
| from dataclasses import dataclass | |
| import random | |
| from services.openai_client import OpenAIClient | |
| from services.gemini_client import GeminiClient | |
| from services.modal_flux import ModalFluxClient | |
@dataclass
class GeneratedImage:
    """
    Result from image generation.

    ``image_data`` holds either a URL or a base64 payload; ``is_url`` says which.
    On total failure the generator returns an instance with empty ``image_data``,
    ``provider="none"`` and a message in ``error``.
    """

    image_data: str  # URL or base64
    provider: str  # name of the provider that produced the image
    is_url: bool = True  # False when image_data is base64
    error: Optional[str] = None  # failure description, None on success
# Identifiers for the image backends PipArtist knows how to call.
ImageProvider = Literal[
    "openai",
    "gemini",
    "flux",
    "sdxl_lightning",
]
class PipArtist:
    """
    Load-balanced image generation for Pip.

    Distributes requests across providers to avoid rate limits and utilize all
    credits. Providers are picked round-robin from ``self.providers``. A provider
    with ``_max_failures`` consecutive failures is skipped. OpenAI and Gemini are
    disabled for the lifetime of this instance once an error looks like a quota
    or rate-limit error. If the chosen provider fails, every other enabled
    provider in ``self.providers`` is tried in order.
    """

    # Lower-cased substrings that mark an exception as a quota / rate-limit error.
    # A bare "rate" is deliberately NOT used: it also matches "generate".
    _OPENAI_QUOTA_MARKERS = ("insufficient_quota", "429")
    _GEMINI_QUOTA_MARKERS = (
        "429",
        "quota",
        "rate limit",
        "rate_limit",
        "ratelimit",
        "resource_exhausted",
    )

    def __init__(self):
        self.openai = OpenAIClient()
        self.gemini = GeminiClient()
        self.modal_flux = ModalFluxClient()

        # Provider rotation index
        self._current_index = 0

        # Available providers in rotation order.
        # Flux first (via HuggingFace router - most reliable).
        # Gemini has rate limits on free tier, OpenAI requires paid account.
        # "sdxl_lightning" is not in the rotation; it is used by generate_fast
        # or when requested explicitly.
        self.providers: list[ImageProvider] = ["flux", "gemini", "openai"]

        # Provider health tracking: consecutive failures per provider name.
        self._provider_failures: dict[str, int] = {p: 0 for p in self.providers}
        self._max_failures = 3  # Temporarily skip after this many consecutive failures

        # Kill-switches flipped to False on a quota / rate-limit error and never
        # re-enabled by this class.
        self._openai_available = True
        self._gemini_available = True

    def _is_available(self, provider: str) -> bool:
        """Return False if *provider* was disabled by a quota / rate-limit error."""
        if provider == "openai":
            return self._openai_available
        if provider == "gemini":
            return self._gemini_available
        return True

    def _get_next_provider(self) -> "ImageProvider":
        """
        Get next provider using round-robin with health awareness.

        Advances the rotation index once per provider inspected and returns the
        first provider that is enabled and below ``_max_failures``. If none
        qualifies, all failure counts are reset and ``self.providers[0]``
        (Flux in the default order) is returned.
        """
        for _ in range(len(self.providers)):
            provider = self.providers[self._current_index]
            self._current_index = (self._current_index + 1) % len(self.providers)
            if not self._is_available(provider):
                continue
            if self._provider_failures.get(provider, 0) < self._max_failures:
                return provider

        # Every provider is disabled or unhealthy: clear the failure-based skips
        # (quota kill-switches stay off) and default to the first provider.
        self._provider_failures = {p: 0 for p in self.providers}
        return self.providers[0]

    def _mark_success(self, provider: str) -> None:
        """Mark provider as successful, reset failure count."""
        self._provider_failures[provider] = 0

    def _mark_failure(self, provider: str) -> None:
        """
        Record one more consecutive failure for *provider*.

        Uses ``.get`` because providers outside the rotation (e.g.
        "sdxl_lightning") have no entry until they are first tracked.
        """
        self._provider_failures[provider] = self._provider_failures.get(provider, 0) + 1

    def _disable_on_quota_error(self, provider: str, error: Exception) -> None:
        """Disable OpenAI or Gemini for this instance if *error* looks like a quota/rate limit."""
        error_str = str(error).lower()
        if provider == "openai" and any(m in error_str for m in self._OPENAI_QUOTA_MARKERS):
            print("OpenAI quota exceeded - disabling for this session.")
            self._openai_available = False
        elif provider == "gemini" and any(m in error_str for m in self._GEMINI_QUOTA_MARKERS):
            print("Gemini rate limited - disabling for this session.")
            self._gemini_available = False

    async def _attempt(
        self,
        prompt: str,
        provider: "ImageProvider",
        style: str,
        label: str,
    ) -> Optional["GeneratedImage"]:
        """
        Try a single provider and update health bookkeeping.

        Args:
            prompt: The image generation prompt
            provider: Provider to call
            style: Style hint forwarded to the provider
            label: Prefix for the failure log line ("Provider", "Fallback", ...)

        Returns:
            The generated image, or None if the provider raised or produced
            nothing. Both cases count as a failure.
        """
        try:
            result = await self._generate_with_provider(prompt, provider, style)
        except Exception as e:  # load balancer boundary: any error means "try another"
            print(f"{label} {provider} failed: {e}")
            self._disable_on_quota_error(provider, e)
            self._mark_failure(provider)
            return None
        if result:
            self._mark_success(provider)
            return result
        # An empty result is also a failure for health-tracking purposes.
        self._mark_failure(provider)
        return None

    async def generate(
        self,
        prompt: str,
        style: str = "vivid",
        preferred_provider: Optional["ImageProvider"] = None
    ) -> "GeneratedImage":
        """
        Generate an image using load-balanced providers.

        Args:
            prompt: The image generation prompt
            style: Style hint ("vivid", "natural", "artistic", "dreamy")
            preferred_provider: Force a specific provider if needed; ignored if
                that provider has been disabled by a quota error

        Returns:
            GeneratedImage with either URL or base64 data. If every provider
            fails, image_data is empty, provider is "none" and error is set.
        """
        if preferred_provider and self._is_available(preferred_provider):
            provider = preferred_provider
        else:
            provider = self._get_next_provider()

        result = await self._attempt(prompt, provider, style, "Provider")
        if result:
            return result

        # Try every other provider that is still enabled, in rotation order.
        for fallback in self.providers:
            if fallback == provider or not self._is_available(fallback):
                continue
            result = await self._attempt(prompt, fallback, style, "Fallback")
            if result:
                return result

        # All providers failed
        return GeneratedImage(
            image_data="",
            provider="none",
            is_url=False,
            error="All image generation providers failed"
        )

    async def _generate_with_provider(
        self,
        prompt: str,
        provider: "ImageProvider",
        style: str
    ) -> Optional["GeneratedImage"]:
        """
        Generate image with a specific provider.

        Returns None when the provider yields nothing or the name is unknown;
        provider exceptions propagate to the caller.
        """
        if provider == "openai":
            # Map style to OpenAI style parameter
            openai_style = "vivid" if style in ["vivid", "bright", "energetic"] else "natural"
            result = await self.openai.generate_image(prompt, openai_style)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="openai",
                    is_url=True
                )
        elif provider == "gemini":
            result = await self.gemini.generate_image(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="gemini",
                    is_url=False  # Gemini returns base64
                )
        elif provider == "flux":
            result = await self.modal_flux.generate_artistic(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="flux",
                    is_url=False  # Returns base64
                )
        elif provider == "sdxl_lightning":
            result = await self.modal_flux.generate_fast(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="sdxl_lightning",
                    is_url=False
                )
        return None

    async def generate_fast(self, prompt: str) -> "GeneratedImage":
        """
        Generate image optimizing for speed over quality.

        Tries SDXL-Lightning, then Flux, then OpenAI, skipping providers that
        were disabled by quota errors. Falls back to generate() if none succeed.
        """
        for provider in ("sdxl_lightning", "flux", "openai"):
            if not self._is_available(provider):
                continue
            result = await self._attempt(prompt, provider, "natural", "Fast generation with")
            if result:
                return result
        # Fallback to regular generation
        return await self.generate(prompt)

    async def generate_artistic(self, prompt: str) -> "GeneratedImage":
        """
        Generate image optimizing for artistic quality.
        Prefers Flux for dreamlike results.
        """
        return await self.generate(prompt, style="artistic", preferred_provider="flux")

    async def generate_for_mood(
        self,
        prompt: str,
        mood: str,
        action: str
    ) -> "GeneratedImage":
        """
        Generate image appropriate for the emotional context.

        Args:
            prompt: Enhanced image prompt
            mood: Detected mood/emotion; picks a preferred provider if mapped
            action: Pip's action (reflect, celebrate, comfort, etc.); picks the
                style hint, defaulting to "natural"
        """
        # Map moods/actions to best provider and style
        mood_provider_map = {
            "dreamy": "flux",
            "surreal": "flux",
            "artistic": "flux",
            "calm": "gemini",
            "peaceful": "gemini",
            "energetic": "openai",
            "photorealistic": "openai",
            "warm": "gemini",
        }
        action_style_map = {
            "reflect": "natural",
            "celebrate": "vivid",
            "comfort": "natural",
            "calm": "natural",
            "energize": "vivid",
            "curiosity": "artistic",
            "intervene": "artistic",  # Mysterious, wonder-provoking
        }
        preferred_provider = mood_provider_map.get(mood)
        style = action_style_map.get(action, "natural")
        return await self.generate(prompt, style, preferred_provider)

    def get_provider_stats(self) -> dict:
        """Get current provider health stats (copies, safe to mutate)."""
        return {
            "current_index": self._current_index,
            "failures": self._provider_failures.copy(),
            "providers": self.providers.copy()
        }
# Convenience function for quick image generation
async def generate_mood_image(
    prompt: str,
    mood: str = "neutral",
    action: str = "reflect",
    *,
    artist: Optional["PipArtist"] = None,
) -> "GeneratedImage":
    """
    Quick function to generate a mood-appropriate image.

    Args:
        prompt: Enhanced image prompt
        mood: Detected mood/emotion
        action: Pip's action (reflect, celebrate, comfort, etc.)
        artist: Optional long-lived PipArtist to use. When omitted, a fresh
            PipArtist is created for this call, so rotation, failure counts and
            quota kill-switches are NOT carried between calls. Pass a shared
            instance to get real load balancing across calls.
    """
    if artist is None:
        artist = PipArtist()
    return await artist.generate_for_mood(prompt, mood, action)