pipV1 / pip_artist.py
Itsjustamit's picture
files for v1
cd35cc5 verified
"""
Pip's Artist - Image generation with load balancing.
Distributes image generation across multiple providers.
"""
import asyncio
from typing import Optional, Literal
from dataclasses import dataclass
import random
from services.openai_client import OpenAIClient
from services.gemini_client import GeminiClient
from services.modal_flux import ModalFluxClient
@dataclass
class GeneratedImage:
    """
    Outcome of a single image-generation request.

    ``image_data`` carries either a URL or base64-encoded image content;
    ``is_url`` says which. When generation failed, ``error`` holds a message.
    """

    # URL or base64-encoded image content (see ``is_url``).
    image_data: str
    # Name of the backend that produced the result ("none" if all failed).
    provider: str
    # True when image_data is a URL, False when it is base64 data.
    is_url: bool = True
    # Failure description; stays None on success.
    error: Optional[str] = None


# Names of the image backends PipArtist can dispatch to.
ImageProvider = Literal[
    "openai",
    "gemini",
    "flux",
    "sdxl_lightning",
]
class PipArtist:
    """
    Load-balanced image generation for Pip.

    Distributes requests across providers to avoid rate limits and utilize all
    credits. Providers are picked round-robin from ``self.providers``; a
    provider with ``_max_failures`` or more consecutive failures is skipped,
    and OpenAI / Gemini are disabled for the rest of the session once a quota
    or rate-limit error is detected in their exception message.
    """

    # Lower-cased substrings that mark an OpenAI exception as a quota problem.
    _OPENAI_QUOTA_MARKERS = ("insufficient_quota", "429")
    # Lower-cased substrings that mark a Gemini exception as quota/rate-limit
    # related. A bare "rate" is deliberately NOT used: it also matches ordinary
    # words such as "generate" or "separate" and would disable Gemini on
    # unrelated errors.
    _GEMINI_RATE_LIMIT_MARKERS = (
        "429",
        "quota",
        "rate limit",
        "rate-limit",
        "rate_limit",
        "ratelimit",
        "resource_exhausted",
        "resource exhausted",
        "too many requests",
    )

    def __init__(self):
        self.openai = OpenAIClient()
        self.gemini = GeminiClient()
        self.modal_flux = ModalFluxClient()

        # Provider rotation index
        self._current_index = 0

        # Available providers in rotation order.
        # Flux first (via HuggingFace router - most reliable);
        # Gemini has rate limits on free tier, OpenAI requires paid account.
        self.providers: list[ImageProvider] = ["flux", "gemini", "openai"]

        # Consecutive-failure count per provider (reset to 0 on success).
        self._provider_failures: dict[str, int] = {p: 0 for p in self.providers}
        self._max_failures = 3  # Temporarily skip after this many consecutive failures

        # Session-wide switches, flipped to False when a quota / rate-limit
        # error is detected for that provider (see _handle_provider_error).
        self._openai_available = True
        self._gemini_available = True

    def _is_enabled(self, provider: str) -> bool:
        """Return False if *provider* has been disabled for this session."""
        if provider == "openai":
            return self._openai_available
        if provider == "gemini":
            return self._gemini_available
        return True

    def _get_next_provider(self) -> ImageProvider:
        """
        Get next provider using round-robin with health awareness.

        Skips session-disabled providers and providers whose consecutive
        failure count has reached ``_max_failures``. If every provider is
        skipped, all failure counts are reset and ``self.providers[0]`` is
        returned.
        """
        for _ in range(len(self.providers)):
            provider = self.providers[self._current_index]
            self._current_index = (self._current_index + 1) % len(self.providers)

            if not self._is_enabled(provider):
                continue
            if self._provider_failures.get(provider, 0) < self._max_failures:
                return provider

        # Reset failures if all providers are failing (permanent quota
        # disables are tracked separately and are not reset here).
        self._provider_failures = {p: 0 for p in self.providers}
        return self.providers[0]  # Default to Flux

    def _mark_success(self, provider: str):
        """Mark provider as successful, reset failure count."""
        self._provider_failures[provider] = 0

    def _mark_failure(self, provider: str):
        """Mark provider as failed.

        Uses ``.get`` so providers outside the rotation (e.g. "sdxl_lightning")
        can be tracked without raising KeyError.
        """
        self._provider_failures[provider] = self._provider_failures.get(provider, 0) + 1

    def _is_quota_error(self, provider: str, error_str: str) -> bool:
        """Return True if lower-cased *error_str* is a quota/rate-limit error for *provider*."""
        if provider == "openai":
            markers = self._OPENAI_QUOTA_MARKERS
        elif provider == "gemini":
            markers = self._GEMINI_RATE_LIMIT_MARKERS
        else:
            return False
        return any(marker in error_str for marker in markers)

    def _handle_provider_error(
        self,
        provider: str,
        error: Exception,
        log_prefix: str = "Provider"
    ) -> None:
        """
        Log a provider exception, disable the provider on quota errors, and
        record a failure.

        Args:
            provider: Provider that raised.
            error: The exception it raised.
            log_prefix: Text printed before the provider name in the log line.
        """
        print(f"{log_prefix} {provider} failed: {error}")
        if self._is_quota_error(provider, str(error).lower()):
            if provider == "openai":
                print("OpenAI quota exceeded - disabling for this session.")
                self._openai_available = False
            elif provider == "gemini":
                print("Gemini rate limited - disabling for this session.")
                self._gemini_available = False
        self._mark_failure(provider)

    async def _try_provider(
        self,
        prompt: str,
        provider: ImageProvider,
        style: str,
        log_prefix: str = "Provider"
    ) -> Optional[GeneratedImage]:
        """
        Attempt one provider, updating health tracking.

        Returns the image on success, or None if the provider raised or
        produced no result (both count as a failure).
        """
        try:
            result = await self._generate_with_provider(prompt, provider, style)
        except Exception as e:
            self._handle_provider_error(provider, e, log_prefix)
            return None
        if result:
            self._mark_success(provider)
            return result
        # Provider returned nothing without raising; still counts as a failure
        # so the rotation can deprioritise it.
        self._mark_failure(provider)
        return None

    async def generate(
        self,
        prompt: str,
        style: str = "vivid",
        preferred_provider: Optional[ImageProvider] = None
    ) -> GeneratedImage:
        """
        Generate an image using load-balanced providers.

        Args:
            prompt: The image generation prompt
            style: Style hint ("vivid", "natural", "artistic", "dreamy")
            preferred_provider: Force a specific provider if needed

        Returns:
            GeneratedImage with either URL or base64 data. If every provider
            fails, ``provider`` is "none" and ``error`` is set.
        """
        provider = preferred_provider or self._get_next_provider()
        # Skip disabled providers (_get_next_provider never hands one back
        # except its providers[0] fallback).
        if not self._is_enabled(provider):
            provider = self._get_next_provider()

        result = await self._try_provider(prompt, provider, style, "Provider")
        if result:
            return result

        # Try fallback providers
        for fallback in self.providers:
            if fallback == provider or not self._is_enabled(fallback):
                continue
            result = await self._try_provider(prompt, fallback, style, "Fallback")
            if result:
                return result

        # All providers failed
        return GeneratedImage(
            image_data="",
            provider="none",
            is_url=False,
            error="All image generation providers failed"
        )

    async def _generate_with_provider(
        self,
        prompt: str,
        provider: ImageProvider,
        style: str
    ) -> Optional[GeneratedImage]:
        """
        Generate image with a specific provider.

        Returns None when the provider returns a falsy result or the provider
        name is not recognised. Exceptions from the clients propagate.
        """
        if provider == "openai":
            # Map style to OpenAI style parameter
            openai_style = "vivid" if style in ["vivid", "bright", "energetic"] else "natural"
            result = await self.openai.generate_image(prompt, openai_style)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="openai",
                    is_url=True
                )
        elif provider == "gemini":
            result = await self.gemini.generate_image(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="gemini",
                    is_url=False  # Gemini returns base64
                )
        elif provider == "flux":
            result = await self.modal_flux.generate_artistic(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="flux",
                    is_url=False  # Returns base64
                )
        elif provider == "sdxl_lightning":
            result = await self.modal_flux.generate_fast(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="sdxl_lightning",
                    is_url=False
                )
        return None

    async def generate_fast(self, prompt: str) -> GeneratedImage:
        """
        Generate image optimizing for speed over quality.

        Tries SDXL-Lightning, then Flux, then OpenAI (skipping any provider
        disabled for this session), and falls back to :meth:`generate`.
        """
        # Try fast providers first
        fast_providers = ["sdxl_lightning", "flux", "openai"]
        for provider in fast_providers:
            # Don't spend a request on a provider already disabled for quota.
            if not self._is_enabled(provider):
                continue
            result = await self._try_provider(
                prompt, provider, "natural", "Fast generation with"
            )
            if result:
                return result

        # Fallback to regular generation
        return await self.generate(prompt)

    async def generate_artistic(self, prompt: str) -> GeneratedImage:
        """
        Generate image optimizing for artistic quality.

        Prefers Flux for dreamlike results.
        """
        return await self.generate(prompt, style="artistic", preferred_provider="flux")

    async def generate_for_mood(
        self,
        prompt: str,
        mood: str,
        action: str
    ) -> GeneratedImage:
        """
        Generate image appropriate for the emotional context.

        Args:
            prompt: Enhanced image prompt
            mood: Detected mood/emotion (unmapped moods use load balancing)
            action: Pip's action (reflect, celebrate, comfort, etc.);
                unmapped actions use the "natural" style
        """
        # Map moods/actions to best provider and style
        mood_provider_map = {
            "dreamy": "flux",
            "surreal": "flux",
            "artistic": "flux",
            "calm": "gemini",
            "peaceful": "gemini",
            "energetic": "openai",
            "photorealistic": "openai",
            "warm": "gemini",
        }
        action_style_map = {
            "reflect": "natural",
            "celebrate": "vivid",
            "comfort": "natural",
            "calm": "natural",
            "energize": "vivid",
            "curiosity": "artistic",
            "intervene": "artistic",  # Mysterious, wonder-provoking
        }
        preferred_provider = mood_provider_map.get(mood)
        style = action_style_map.get(action, "natural")
        return await self.generate(prompt, style, preferred_provider)

    def get_provider_stats(self) -> dict:
        """Get current provider health stats (rotation index, failures, availability)."""
        return {
            "current_index": self._current_index,
            "failures": self._provider_failures.copy(),
            "providers": self.providers.copy(),
            "openai_available": self._openai_available,
            "gemini_available": self._gemini_available,
        }
# Module-level shortcut: one-off mood image without managing a PipArtist.
async def generate_mood_image(
    prompt: str,
    mood: str = "neutral",
    action: str = "reflect"
) -> GeneratedImage:
    """
    Produce an image matched to *mood* and *action* in a single call.

    A brand-new PipArtist is constructed on every invocation, so provider
    rotation and failure/availability state are not carried between calls.
    Reuse a PipArtist instance directly if that state should persist.
    """
    return await PipArtist().generate_for_mood(prompt, mood, action)