"""
Pip's Artist - Image generation with load balancing.
Distributes image generation across multiple providers.
"""

import asyncio
from typing import Optional, Literal
from dataclasses import dataclass
import random

from services.openai_client import OpenAIClient
from services.gemini_client import GeminiClient
from services.modal_flux import ModalFluxClient


@dataclass
class GeneratedImage:
    """Result from image generation."""
    image_data: str  # URL or base64
    provider: str
    is_url: bool = True
    error: Optional[str] = None


ImageProvider = Literal["openai", "gemini", "flux", "sdxl_lightning"]


class PipArtist:
    """
    Load-balanced image generation for Pip.
    Distributes requests across providers to avoid rate limits and utilize all credits.
    """

    # Lower-cased substrings of an OpenAI error message that cause OpenAI to be
    # disabled for the rest of the session (quota exhausted / HTTP 429).
    _OPENAI_LIMIT_MARKERS = ("insufficient_quota", "429")

    # Lower-cased substrings of a Gemini error message that indicate a quota or
    # rate-limit error. Deliberately specific: a bare "rate" substring also
    # matches "generate"/"generation", which would disable Gemini on unrelated
    # errors.
    _GEMINI_LIMIT_MARKERS = (
        "429",
        "quota",
        "rate limit",
        "rate-limit",
        "rate_limit",
        "ratelimit",
        "resource_exhausted",
        "too many requests",
    )

    def __init__(self):
        self.openai = OpenAIClient()
        self.gemini = GeminiClient()
        self.modal_flux = ModalFluxClient()

        # Provider rotation index
        self._current_index = 0

        # Available providers in rotation order
        # Flux first (via HuggingFace router - most reliable)
        # Gemini has rate limits on free tier, OpenAI requires paid account
        self.providers: list[ImageProvider] = ["flux", "gemini", "openai"]

        # Provider health tracking: consecutive failure count per provider
        self._provider_failures: dict[str, int] = {p: 0 for p in self.providers}
        self._max_failures = 3  # Temporarily skip after this many consecutive failures

        # Check if OpenAI is available (has credits)
        self._openai_available = True  # Will be set to False on first 429 error

        # Gemini rate limit tracking
        self._gemini_available = True  # Will be set to False on 429 error

    def _is_available(self, provider: str) -> bool:
        """Return False if ``provider`` was disabled for this session by a quota error."""
        if provider == "openai":
            return self._openai_available
        if provider == "gemini":
            return self._gemini_available
        return True

    def _get_next_provider(self) -> ImageProvider:
        """
        Get next provider using round-robin with health awareness.

        Skips providers disabled by quota errors and providers that have
        reached ``_max_failures`` consecutive failures. If every provider is
        skipped, resets all failure counts and returns ``self.providers[0]``.
        """
        for _ in range(len(self.providers)):
            provider = self.providers[self._current_index]
            self._current_index = (self._current_index + 1) % len(self.providers)

            if not self._is_available(provider):
                continue
            if self._provider_failures.get(provider, 0) < self._max_failures:
                return provider

        # Reset failures if all providers are failing (except permanent quota issues)
        self._provider_failures = {p: 0 for p in self.providers}
        return self.providers[0]  # Default to Flux

    def _mark_success(self, provider: str):
        """Mark provider as successful, reset failure count."""
        self._provider_failures[provider] = 0

    def _mark_failure(self, provider: str):
        """
        Mark provider as failed.

        Uses ``.get`` so providers outside the rotation list (e.g.
        "sdxl_lightning") are counted instead of raising KeyError.
        """
        self._provider_failures[provider] = self._provider_failures.get(provider, 0) + 1

    def _record_error(self, provider: str, error: Exception) -> None:
        """
        Disable ``provider`` for the rest of the session if ``error`` looks
        like a quota / rate-limit error for it.

        Only OpenAI and Gemini can be disabled this way; other errors are left
        to the consecutive-failure counter.
        """
        error_str = str(error).lower()
        if provider == "openai" and self._openai_available:
            if any(marker in error_str for marker in self._OPENAI_LIMIT_MARKERS):
                print("OpenAI quota exceeded - disabling for this session.")
                self._openai_available = False
        elif provider == "gemini" and self._gemini_available:
            if any(marker in error_str for marker in self._GEMINI_LIMIT_MARKERS):
                print("Gemini rate limited - disabling for this session.")
                self._gemini_available = False

    async def _try_provider(
        self,
        prompt: str,
        provider: ImageProvider,
        style: str,
        failure_label: str,
    ) -> Optional[GeneratedImage]:
        """
        Attempt one provider and update its health bookkeeping.

        Args:
            prompt: The image generation prompt
            provider: Provider to try
            style: Style hint forwarded to the provider
            failure_label: Prefix for the log line printed when the provider raises

        Returns:
            The GeneratedImage on success, otherwise None. Exceptions from the
            provider are caught, logged, and classified; both exceptions and
            empty (None) results count as a failure.
        """
        try:
            result = await self._generate_with_provider(prompt, provider, style)
        except Exception as e:
            print(f"{failure_label} failed: {e}")
            self._record_error(provider, e)
            self._mark_failure(provider)
            return None

        if result is None:
            # Provider produced nothing without raising - still a failure,
            # otherwise health tracking would keep selecting it.
            self._mark_failure(provider)
            return None

        self._mark_success(provider)
        return result

    async def generate(
        self,
        prompt: str,
        style: str = "vivid",
        preferred_provider: Optional[ImageProvider] = None
    ) -> GeneratedImage:
        """
        Generate an image using load-balanced providers.

        Args:
            prompt: The image generation prompt
            style: Style hint ("vivid", "natural", "artistic", "dreamy")
            preferred_provider: Force a specific provider if needed

        Returns:
            GeneratedImage with either URL or base64 data. If every provider
            fails, a GeneratedImage with provider="none" and ``error`` set.
        """
        provider = preferred_provider or self._get_next_provider()

        # Skip disabled providers (the rotation never returns a disabled one)
        if not self._is_available(provider):
            provider = self._get_next_provider()

        result = await self._try_provider(prompt, provider, style, f"Provider {provider}")
        if result is not None:
            return result

        # Try fallback providers in rotation order, skipping the one just tried
        for fallback in self.providers:
            if fallback == provider or not self._is_available(fallback):
                continue
            result = await self._try_provider(prompt, fallback, style, f"Fallback {fallback}")
            if result is not None:
                return result

        # All providers failed
        return GeneratedImage(
            image_data="",
            provider="none",
            is_url=False,
            error="All image generation providers failed"
        )

    async def _generate_with_provider(
        self,
        prompt: str,
        provider: ImageProvider,
        style: str
    ) -> Optional[GeneratedImage]:
        """
        Generate image with a specific provider.

        Returns None when the provider returns an empty result or the provider
        name is not recognized. Exceptions from the client are propagated.
        """
        if provider == "openai":
            # Map style to OpenAI style parameter
            openai_style = "vivid" if style in ["vivid", "bright", "energetic"] else "natural"
            result = await self.openai.generate_image(prompt, openai_style)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="openai",
                    is_url=True
                )

        elif provider == "gemini":
            result = await self.gemini.generate_image(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="gemini",
                    is_url=False  # Gemini returns base64
                )

        elif provider == "flux":
            result = await self.modal_flux.generate_artistic(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="flux",
                    is_url=False  # Returns base64
                )

        elif provider == "sdxl_lightning":
            result = await self.modal_flux.generate_fast(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="sdxl_lightning",
                    is_url=False
                )

        return None

    async def generate_fast(self, prompt: str) -> GeneratedImage:
        """
        Generate image optimizing for speed over quality.

        Tries SDXL-Lightning, then Flux, then OpenAI (skipped when disabled
        for this session), recording success/failure for each attempt. Falls
        back to ``generate`` if none of them produces an image.
        """
        # Try fast providers first
        fast_providers: list[ImageProvider] = ["sdxl_lightning", "flux", "openai"]

        for provider in fast_providers:
            if not self._is_available(provider):
                continue
            result = await self._try_provider(
                prompt, provider, "natural", f"Fast generation with {provider}"
            )
            if result is not None:
                return result

        # Fallback to regular generation
        return await self.generate(prompt)

    async def generate_artistic(self, prompt: str) -> GeneratedImage:
        """
        Generate image optimizing for artistic quality.
        Prefers Flux for dreamlike results.
        """
        return await self.generate(prompt, style="artistic", preferred_provider="flux")

    async def generate_for_mood(
        self,
        prompt: str,
        mood: str,
        action: str
    ) -> GeneratedImage:
        """
        Generate image appropriate for the emotional context.

        Args:
            prompt: Enhanced image prompt
            mood: Detected mood/emotion; selects a preferred provider when it
                appears in the mapping below, otherwise normal rotation is used
            action: Pip's action (reflect, celebrate, comfort, etc.); selects
                the style hint, defaulting to "natural"
        """
        # Map moods/actions to best provider and style
        mood_provider_map = {
            "dreamy": "flux",
            "surreal": "flux",
            "artistic": "flux",
            "calm": "gemini",
            "peaceful": "gemini",
            "energetic": "openai",
            "photorealistic": "openai",
            "warm": "gemini",
        }

        action_style_map = {
            "reflect": "natural",
            "celebrate": "vivid",
            "comfort": "natural",
            "calm": "natural",
            "energize": "vivid",
            "curiosity": "artistic",
            "intervene": "artistic",  # Mysterious, wonder-provoking
        }

        preferred_provider = mood_provider_map.get(mood)
        style = action_style_map.get(action, "natural")

        return await self.generate(prompt, style, preferred_provider)

    def get_provider_stats(self) -> dict:
        """
        Get current provider health stats.

        Returns a snapshot with the rotation index, per-provider consecutive
        failure counts, the rotation list, and whether OpenAI / Gemini are
        still enabled for this session.
        """
        return {
            "current_index": self._current_index,
            "failures": self._provider_failures.copy(),
            "providers": self.providers.copy(),
            "openai_available": self._openai_available,
            "gemini_available": self._gemini_available,
        }


# Convenience function for quick image generation
async def generate_mood_image(
    prompt: str,
    mood: str = "neutral",
    action: str = "reflect"
) -> GeneratedImage:
    """
    Quick function to generate a mood-appropriate image.

    Creates a fresh PipArtist per call, so provider health/rotation state is
    not carried over between calls.
    """
    artist = PipArtist()
    return await artist.generate_for_mood(prompt, mood, action)