File size: 10,984 Bytes
cd35cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""
Pip's Artist - Image generation with load balancing.
Distributes image generation across multiple providers.
"""

import asyncio
from typing import Optional, Literal
from dataclasses import dataclass
import random

from services.openai_client import OpenAIClient
from services.gemini_client import GeminiClient
from services.modal_flux import ModalFluxClient


@dataclass
class GeneratedImage:
    """
    Outcome of a single image-generation request.

    ``image_data`` carries either a URL or base64-encoded image data; which
    one is indicated by ``is_url``. ``provider`` names the backend that
    produced the image. ``error`` is populated only when generation failed.
    """

    # Either a URL (is_url=True) or a base64 payload (is_url=False).
    image_data: str
    # Name of the backend that produced the image.
    provider: str
    # Defaults to URL form; base64-returning backends set this to False.
    is_url: bool = True
    # Failure description; None on success.
    error: Optional[str] = None


# Identifiers of the image backends PipArtist can dispatch to.
# "sdxl_lightning" is handled by PipArtist._generate_with_provider but is not
# part of the default round-robin rotation (PipArtist.providers).
ImageProvider = Literal["openai", "gemini", "flux", "sdxl_lightning"]


class PipArtist:
    """
    Load-balanced image generation for Pip.
    Distributes requests across providers to avoid rate limits and utilize all credits.

    Providers are picked round-robin from ``self.providers``. A provider is
    temporarily skipped once it accumulates ``_max_failures`` consecutive
    failures. OpenAI and Gemini are disabled for the rest of the session once
    they report a quota / rate-limit error.
    """

    # Lower-cased substrings that identify a quota / rate-limit error for the
    # providers that can be disabled for the session. Whole phrases are used
    # (rather than the bare word "rate") so that ordinary messages such as
    # "failed to generate image" are not mistaken for rate limiting.
    _QUOTA_ERROR_MARKERS = {
        "openai": ("insufficient_quota", "429"),
        "gemini": (
            "429",
            "quota",
            "rate limit",
            "rate-limit",
            "rate_limit",
            "ratelimit",
            "resource_exhausted",
            "resource exhausted",
            "too many requests",
        ),
    }

    # Mood -> provider considered best suited to it (see generate_for_mood).
    _MOOD_PROVIDER_MAP = {
        "dreamy": "flux",
        "surreal": "flux",
        "artistic": "flux",
        "calm": "gemini",
        "peaceful": "gemini",
        "energetic": "openai",
        "photorealistic": "openai",
        "warm": "gemini",
    }

    # Pip action -> style hint (see generate_for_mood).
    _ACTION_STYLE_MAP = {
        "reflect": "natural",
        "celebrate": "vivid",
        "comfort": "natural",
        "calm": "natural",
        "energize": "vivid",
        "curiosity": "artistic",
        "intervene": "artistic",  # Mysterious, wonder-provoking
    }

    def __init__(self):
        self.openai = OpenAIClient()
        self.gemini = GeminiClient()
        self.modal_flux = ModalFluxClient()

        # Provider rotation index
        self._current_index = 0

        # Available providers in rotation order
        # Flux first (via HuggingFace router - most reliable)
        # Gemini has rate limits on free tier, OpenAI requires paid account
        self.providers: list[ImageProvider] = ["flux", "gemini", "openai"]

        # Provider health tracking
        self._provider_failures: dict[str, int] = {p: 0 for p in self.providers}
        self._max_failures = 3  # Temporarily skip after this many consecutive failures

        # Check if OpenAI is available (has credits)
        self._openai_available = True  # Will be set to False on first quota error

        # Gemini rate limit tracking
        self._gemini_available = True  # Will be set to False on rate-limit error

    def _is_available(self, provider: str) -> bool:
        """Return False if *provider* was disabled for this session by a quota error."""
        if provider == "openai":
            return self._openai_available
        if provider == "gemini":
            return self._gemini_available
        return True

    def _get_next_provider(self) -> ImageProvider:
        """
        Get next provider using round-robin with health awareness.

        Walks ``self.providers`` at most once, starting at ``_current_index``,
        skipping providers disabled for the session and providers with
        ``_max_failures`` or more consecutive failures. If every provider is
        skipped, all failure counters are reset and ``self.providers[0]``
        (Flux by default) is returned.
        """
        for _ in range(len(self.providers)):
            provider = self.providers[self._current_index]
            self._current_index = (self._current_index + 1) % len(self.providers)

            if not self._is_available(provider):
                continue
            if self._provider_failures.get(provider, 0) < self._max_failures:
                return provider

        # Every provider is disabled or failing: reset the temporary failure
        # counters. Session-level disables (quota errors) are intentionally kept.
        self._provider_failures = {p: 0 for p in self.providers}
        return self.providers[0]  # Default to Flux

    def _mark_success(self, provider: str):
        """Mark provider as successful, reset failure count."""
        self._provider_failures[provider] = 0

    def _mark_failure(self, provider: str):
        """
        Mark provider as failed.

        Uses ``.get`` so providers outside the rotation (e.g. "sdxl_lightning"
        requested explicitly) are tracked instead of raising KeyError.
        """
        self._provider_failures[provider] = self._provider_failures.get(provider, 0) + 1

    def _is_quota_error(self, provider: str, error_text: str) -> bool:
        """Return True if *error_text* looks like a quota / rate-limit error for *provider*."""
        text = error_text.lower()
        return any(marker in text for marker in self._QUOTA_ERROR_MARKERS.get(provider, ()))

    def _handle_provider_error(self, provider: str, exc: Exception, label: str) -> None:
        """
        Report a provider exception, disable the provider on quota errors, and
        count the failure.

        Args:
            provider: Provider that raised.
            exc: The exception raised by the provider call.
            label: Prefix for the log line (e.g. "Provider", "Fallback").
        """
        print(f"{label} {provider} failed: {exc}")

        if self._is_quota_error(provider, str(exc)):
            if provider == "openai":
                print("OpenAI quota exceeded - disabling for this session.")
                self._openai_available = False
            elif provider == "gemini":
                print("Gemini rate limited - disabling for this session.")
                self._gemini_available = False

        self._mark_failure(provider)

    async def _try_provider(
        self,
        prompt: str,
        provider: ImageProvider,
        style: str,
        label: str
    ) -> Optional[GeneratedImage]:
        """
        Attempt one generation with *provider*, updating its health state.

        Returns the image on success. Returns None if the provider raised or
        returned nothing; both cases count as a failure. Exceptions derived
        from ``Exception`` are reported and swallowed so the caller can move
        on to the next provider.
        """
        try:
            result = await self._generate_with_provider(prompt, provider, style)
        except Exception as e:
            self._handle_provider_error(provider, e, label)
            return None

        if result:
            self._mark_success(provider)
            return result

        # The provider answered but produced no image; count it against its health.
        self._mark_failure(provider)
        return None

    async def generate(
        self,
        prompt: str,
        style: str = "vivid",
        preferred_provider: Optional[ImageProvider] = None
    ) -> GeneratedImage:
        """
        Generate an image using load-balanced providers.

        Args:
            prompt: The image generation prompt
            style: Style hint ("vivid", "natural", "artistic", "dreamy")
            preferred_provider: Force a specific provider if needed. If that
                provider is disabled for the session, the next healthy one
                from the rotation is used instead.

        Returns:
            GeneratedImage with either URL or base64 data. If every provider
            fails, ``image_data`` is empty, ``provider`` is "none" and
            ``error`` is set.
        """
        provider = preferred_provider or self._get_next_provider()

        # Skip disabled providers
        if not self._is_available(provider):
            provider = self._get_next_provider()

        result = await self._try_provider(prompt, provider, style, "Provider")
        if result:
            return result

        # Try the remaining enabled providers in their configured order.
        for fallback in self.providers:
            if fallback == provider or not self._is_available(fallback):
                continue
            result = await self._try_provider(prompt, fallback, style, "Fallback")
            if result:
                return result

        # All providers failed
        return GeneratedImage(
            image_data="",
            provider="none",
            is_url=False,
            error="All image generation providers failed"
        )

    async def _generate_with_provider(
        self,
        prompt: str,
        provider: ImageProvider,
        style: str
    ) -> Optional[GeneratedImage]:
        """
        Generate image with a specific provider.

        Returns None when the provider's client returns a falsy result or the
        provider name is unknown. Client exceptions propagate to the caller.
        """
        if provider == "openai":
            # Map style to OpenAI style parameter
            openai_style = "vivid" if style in ["vivid", "bright", "energetic"] else "natural"
            result = await self.openai.generate_image(prompt, openai_style)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="openai",
                    is_url=True
                )

        elif provider == "gemini":
            result = await self.gemini.generate_image(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="gemini",
                    is_url=False  # Gemini returns base64
                )

        elif provider == "flux":
            result = await self.modal_flux.generate_artistic(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="flux",
                    is_url=False  # Returns base64
                )

        elif provider == "sdxl_lightning":
            result = await self.modal_flux.generate_fast(prompt)
            if result:
                return GeneratedImage(
                    image_data=result,
                    provider="sdxl_lightning",
                    is_url=False
                )

        return None

    async def generate_fast(self, prompt: str) -> GeneratedImage:
        """
        Generate image optimizing for speed over quality.
        Uses SDXL-Lightning when available.

        Tries SDXL-Lightning, Flux and OpenAI in that order, skipping any
        provider disabled for the session and updating health state. If none
        succeed, falls back to the regular load-balanced ``generate``.
        """
        for provider in ("sdxl_lightning", "flux", "openai"):
            if not self._is_available(provider):
                continue
            result = await self._try_provider(prompt, provider, "natural", "Fast generation with")
            if result:
                return result

        # Fallback to regular generation
        return await self.generate(prompt)

    async def generate_artistic(self, prompt: str) -> GeneratedImage:
        """
        Generate image optimizing for artistic quality.
        Prefers Flux for dreamlike results.
        """
        return await self.generate(prompt, style="artistic", preferred_provider="flux")

    async def generate_for_mood(
        self,
        prompt: str,
        mood: str,
        action: str
    ) -> GeneratedImage:
        """
        Generate image appropriate for the emotional context.

        Args:
            prompt: Enhanced image prompt
            mood: Detected mood/emotion (matched case-insensitively); unknown
                moods use the normal rotation.
            action: Pip's action (reflect, celebrate, comfort, etc.), matched
                case-insensitively; unknown actions use the "natural" style.
        """
        # Normalize so "Dreamy" / " celebrate " still hit the lookup tables.
        mood_key = (mood or "").strip().lower()
        action_key = (action or "").strip().lower()

        preferred_provider = self._MOOD_PROVIDER_MAP.get(mood_key)
        style = self._ACTION_STYLE_MAP.get(action_key, "natural")

        return await self.generate(prompt, style, preferred_provider)

    def get_provider_stats(self) -> dict:
        """
        Get current provider health stats.

        Returns a snapshot with the rotation index, per-provider consecutive
        failure counts, the rotation order, and the session-level availability
        flags for OpenAI and Gemini.
        """
        return {
            "current_index": self._current_index,
            "failures": self._provider_failures.copy(),
            "providers": self.providers.copy(),
            "openai_available": self._openai_available,
            "gemini_available": self._gemini_available,
        }


# Convenience function for quick image generation
async def generate_mood_image(
    prompt: str,
    mood: str = "neutral",
    action: str = "reflect"
) -> GeneratedImage:
    """
    Generate a mood-appropriate image in a single call.

    A new PipArtist is constructed on every call, so provider health and
    availability state is not carried over between calls; the work is
    delegated to PipArtist.generate_for_mood.
    """
    return await PipArtist().generate_for_mood(prompt, mood, action)