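"""Open Ghibli Studio.

A Gradio app that captions an uploaded photo with Florence-2 and re-renders it
in a Ghibli-inspired style using FLUX.1-dev with the
openfree/flux-chatgpt-ghibli-lora adapter. Intended to run as a Hugging Face
Space (it relies on the `spaces` GPU decorator).
"""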
import os
import random
import time

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image, ImageOps

# Patch gradio_client's JSON-schema parsing (see patched_json_schema below);
# keep a reference to the original function so we can fall back to it.
import gradio_client.utils

original_json_schema = gradio_client.utils._json_schema_to_python_type


def preprocess_image(image):
    """Normalize an uploaded photo: fix EXIF orientation, cap size, force RGB."""
    try:
        # Apply the EXIF orientation tag so phone photos are not shown rotated.
        image = ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"EXIF transpose error: {e}")

    # Downscale very large images to keep memory use and latency in check.
    if max(image.width, image.height) > 1024:
        image.thumbnail((1024, 1024), Image.LANCZOS)

    if image.mode != "RGB":
        image = image.convert("RGB")

    return image
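
# Note: preprocess_image is defined but not wired into the Gradio callbacks below.
# A hypothetical use would be to normalize uploads before captioning, e.g.:
#   image = preprocess_image(image)  # e.g. at the top of process_uploaded_image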


def patched_json_schema(schema, defs=None):
    """Defensive wrapper around gradio_client's schema-to-type conversion.

    Some schemas (notably boolean `additionalProperties`) crash older
    gradio_client versions, so coerce them and fall back to "any" on error.
    """
    # Boolean schemas ("true"/"false") have no properties to inspect.
    if isinstance(schema, bool):
        return "bool"

    try:
        if "additionalProperties" in schema and isinstance(schema["additionalProperties"], bool):
            schema["additionalProperties"] = {"type": "any"}
    except (TypeError, KeyError):
        pass

    try:
        return original_json_schema(schema, defs)
    except Exception:
        return "any"


gradio_client.utils._json_schema_to_python_type = patched_json_schema

device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "black-forest-labs/FLUX.1-dev"
adapter_id = "openfree/flux-chatgpt-ghibli-lora"


def load_model_with_retry(max_retries=5):
    """Load FLUX.1-dev plus the Ghibli LoRA, retrying with a linear backoff."""
    for attempt in range(max_retries):
        try:
            print(f"Loading model attempt {attempt + 1}/{max_retries}...")
            pipeline = DiffusionPipeline.from_pretrained(
                repo_id,
                torch_dtype=torch.bfloat16,
                use_safetensors=True,
                resume_download=True,
            )
            print("Base model loaded successfully, now loading LoRA weights...")
            pipeline.load_lora_weights(adapter_id)
            pipeline = pipeline.to(device)
            print("Pipeline is ready!")
            return pipeline
        except Exception as e:
            if attempt < max_retries - 1:
                wait_time = 10 * (attempt + 1)
                print(f"Error loading model: {e}. Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                raise Exception(f"Failed to load model after {max_retries} attempts: {e}")


pipeline = load_model_with_retry()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU(duration=120)
def inference(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    lora_scale: float,
):
    """Run the FLUX + Ghibli-LoRA pipeline and return (image, seed)."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    print(f"Running inference with prompt: {prompt}")

    try:
        image = pipeline(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            # joint_attention_kwargs carries the LoRA scale through the FLUX transformer.
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        # Return a solid red placeholder so the UI still receives an image.
        error_img = Image.new("RGB", (width, height), color="red")
        return error_img, seed

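# Illustrative, hypothetical direct call (outside the Gradio UI); the prompt and
# parameter values here are examples only:
#   img, used_seed = inference(
#       "a quiet seaside village, ghibli style",
#       seed=0, randomize_seed=True, width=768, height=768,
#       guidance_scale=3.5, num_inference_steps=30, lora_scale=1.0,
#   )
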
# Best-effort install of flash-attn for Florence-2. Skipping its CUDA build keeps
# the install from failing on machines without a full CUDA toolchain.
import subprocess

try:
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        # Keep the existing environment (PATH, etc.) and only add the skip flag;
        # passing a bare env dict would hide pip from the shell.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )
except Exception as e:
    print(f"Warning: Could not install flash-attn: {e}")

from transformers import AutoProcessor, AutoModelForCausalLM


def load_caption_model(model_name):
    """Load a Florence-2 captioning model and its processor, or (None, None) on failure."""
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        ).eval()
        processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True
        )
        return model, processor
    except Exception as e:
        print(f"Error loading caption model {model_name}: {e}")
        return None, None

print("Loading captioning models...")
default_caption_model = 'microsoft/Florence-2-large'
models = {}
processors = {}

# Try the default captioner first; fall back to an alternative Florence-2 checkpoint.
default_model, default_processor = load_caption_model(default_caption_model)
if default_model is not None and default_processor is not None:
    models[default_caption_model] = default_model
    processors[default_caption_model] = default_processor
    print(f"Successfully loaded default caption model: {default_caption_model}")
else:
    fallback_model = 'gokaygokay/Florence-2-Flux'
    fallback_model_obj, fallback_processor = load_caption_model(fallback_model)
    if fallback_model_obj is not None and fallback_processor is not None:
        models[fallback_model] = fallback_model_obj
        processors[fallback_model] = fallback_processor
        default_caption_model = fallback_model
        print(f"Loaded fallback caption model: {fallback_model}")
    else:
        print("WARNING: Failed to load any caption model!")


@spaces.GPU
def caption_image(image, model_name=default_caption_model):
    """Run the selected Florence-2 model to generate a detailed caption."""
    print(f"Starting caption generation with model: {model_name}")

    # Gradio may hand us either a PIL image or a numpy array.
    if isinstance(image, Image.Image):
        pil_image = image
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        print(f"Unexpected image type: {type(image)}")
        return "Error: Unsupported image type"

    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # If the requested captioner is missing, use whichever one did load.
    if model_name not in models or model_name not in processors:
        available_models = list(models.keys())
        if available_models:
            model_name = available_models[0]
            print(f"Requested model not available, using: {model_name}")
        else:
            return "Error: No caption models available"

    model = models[model_name]
    processor = processors[model_name]

    task_prompt = "<DESCRIPTION>"
    user_prompt = task_prompt + "Describe this image in great detail."

    try:
        inputs = processor(text=user_prompt, images=pil_image, return_tensors="pt")

        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            repetition_penalty=1.10,
        )

        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=task_prompt, image_size=(pil_image.width, pil_image.height)
        )

        caption = parsed_answer.get("<DESCRIPTION>", "")
        print(f"Generated caption: {caption}")
        return caption
    except Exception as e:
        print(f"Error during captioning: {e}")
        return f"Error generating caption: {str(e)}"

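# Illustrative, hypothetical standalone use of the captioner (the path is an example only):
#   caption = caption_image(Image.open("photo.jpg"))
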

@spaces.GPU(duration=120)
def process_uploaded_image(
    image,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lora_scale,
):
    """Caption the uploaded image, then re-render it in Ghibli style."""
    if image is None:
        print("No image provided")
        return None, None, "No image provided", "No image provided"

    print("Starting image processing workflow")

    # Step 1: describe the uploaded photo; fall back to a generic prompt on failure.
    try:
        caption = caption_image(image)
        if caption.startswith("Error:"):
            print(f"Captioning failed: {caption}")
            caption = "A beautiful scene"
    except Exception as e:
        print(f"Exception during captioning: {e}")
        caption = "A beautiful scene"

    # Step 2: turn the caption into a Ghibli-style generation prompt.
    ghibli_prompt = f"{caption}, ghibli style"
    print(f"Final prompt for Ghibli generation: {ghibli_prompt}")

    # Step 3: generate the stylized image.
    try:
        generated_image, used_seed = inference(
            prompt=ghibli_prompt,
            seed=seed,
            randomize_seed=randomize_seed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            lora_scale=lora_scale,
        )
        print(f"Image generation complete with seed: {used_seed}")
        return generated_image, used_seed, caption, ghibli_prompt
    except Exception as e:
        print(f"Error generating image: {e}")
        error_img = Image.new("RGB", (width, height), color="red")
        return error_img, seed, caption, ghibli_prompt


# Soft, Ghibli-inspired Gradio theme.
ghibli_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Nunito"), "ui-sans-serif", "sans-serif"],
    radius_size=gr.themes.sizes.radius_sm,
).set(
    body_background_fill="#f0f9ff",
    body_background_fill_dark="#0f172a",
    button_primary_background_fill="#6366f1",
    button_primary_background_fill_hover="#4f46e5",
    button_primary_text_color="#ffffff",
    block_title_text_weight="600",
    block_border_width="1px",
    block_shadow="0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1)",
)

custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
.main-header {
    text-align: center;
    margin-bottom: 1rem;
    font-weight: 800;
    font-size: 2.5rem;
    background: linear-gradient(90deg, #4338ca, #3b82f6);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    padding: 0.5rem;
}
.tagline {
    text-align: center;
    font-size: 1.2rem;
    margin-bottom: 2rem;
    color: #4b5563;
}
.image-preview {
    border-radius: 12px;
    overflow: hidden;
    box-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1), 0 4px 6px -4px rgb(0 0 0 / 0.1);
}
.panel-box {
    border-radius: 12px;
    background-color: rgba(255, 255, 255, 0.8);
    padding: 1rem;
    box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
.control-panel {
    padding: 1rem;
    border-radius: 12px;
    background-color: rgba(255, 255, 255, 0.9);
    margin-bottom: 1rem;
    border: 1px solid #e2e8f0;
}
.section-header {
    font-weight: 600;
    font-size: 1.1rem;
    margin-bottom: 0.5rem;
    color: #4338ca;
}
.transform-button {
    font-weight: 600 !important;
    margin-top: 1rem !important;
}
.footer {
    text-align: center;
    color: #6b7280;
    margin-top: 2rem;
    font-size: 0.9rem;
}
.output-panel {
    background: linear-gradient(135deg, #f0f9ff, #e0f2fe);
    border-radius: 12px;
    padding: 1rem;
    border: 1px solid #bfdbfe;
}
"""


with gr.Blocks(analytics_enabled=False, theme=ghibli_theme, css=custom_css) as demo:
    gr.HTML(
        """
        <div class="main-header">Open Ghibli Studio</div>
        <div class="tagline">Transform your photos into magical Ghibli-inspired artwork</div>
        """
    )

    # Page-wide background image injected as inline CSS.
    gr.HTML(
        """
        <style>
        body {
            background-image: url('https://i.imgur.com/LxPQPR1.jpg');
            background-size: cover;
            background-position: center;
            background-attachment: fixed;
            background-repeat: no-repeat;
            background-color: #f0f9ff;
        }
        @media (max-width: 768px) {
            body {
                background-size: contain;
            }
        }
        </style>
        """
    )

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-box"):
                gr.HTML('<div class="section-header">Upload Image</div>')
                upload_img = gr.Image(
                    label="Drop your image here",
                    type="pil",
                    elem_classes="image-preview",
                    height=400,
                )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Group(elem_classes="control-panel"):
                    gr.HTML('<div class="section-header">Generation Controls</div>')
                    with gr.Row():
                        img2img_seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                            info="Set a specific seed for reproducible results",
                        )
                        img2img_randomize_seed = gr.Checkbox(
                            label="Randomize Seed",
                            value=True,
                            info="Enable to get different results each time",
                        )

                    with gr.Group():
                        gr.HTML('<div class="section-header">Image Size</div>')
                        with gr.Row():
                            img2img_width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=MAX_IMAGE_SIZE,
                                step=32,
                                value=1024,
                                info="Image width in pixels",
                            )
                            img2img_height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=MAX_IMAGE_SIZE,
                                step=32,
                                value=1024,
                                info="Image height in pixels",
                            )

                    with gr.Group():
                        gr.HTML('<div class="section-header">Generation Parameters</div>')
                        with gr.Row():
                            img2img_guidance_scale = gr.Slider(
                                label="Guidance Scale",
                                minimum=0.0,
                                maximum=10.0,
                                step=0.1,
                                value=3.5,
                                info="Higher values follow the prompt more closely",
                            )
                            img2img_steps = gr.Slider(
                                label="Steps",
                                minimum=1,
                                maximum=50,
                                step=1,
                                value=30,
                                info="More steps = more detailed but slower generation",
                            )

                        img2img_lora_scale = gr.Slider(
                            label="Ghibli Style Strength",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=1.0,
                            info="Controls the intensity of the Ghibli style",
                        )

            transform_button = gr.Button(
                "Transform to Ghibli Style", variant="primary", elem_classes="transform-button"
            )

        with gr.Column(scale=1):
            with gr.Group(elem_classes="output-panel"):
                gr.HTML('<div class="section-header">Ghibli Magic Result</div>')
                ghibli_output_image = gr.Image(
                    label="Generated Ghibli Style Image",
                    elem_classes="image-preview",
                    height=400,
                )
                ghibli_output_seed = gr.Number(label="Seed Used", interactive=False)

            with gr.Accordion("Image Details", open=False):
                extracted_caption = gr.Textbox(
                    label="Detected Image Content",
                    placeholder="The AI will analyze your image and describe it here...",
                    info="AI-generated description of your uploaded image",
                )
                ghibli_prompt = gr.Textbox(
                    label="Generation Prompt",
                    placeholder="The prompt used to create your Ghibli image will appear here...",
                    info="Final prompt used for the Ghibli transformation",
                )

    gr.HTML(
        """
        <div class="footer">
            <p>Open Ghibli Studio uses AI to transform your images into Ghibli-inspired artwork.</p>
            <p>Powered by FLUX.1 and Florence-2 models.</p>
        </div>
        """
    )

    # Run the full caption -> Ghibli pipeline both on upload and on button click.
    upload_img.upload(
        process_uploaded_image,
        inputs=[
            upload_img,
            img2img_seed,
            img2img_randomize_seed,
            img2img_width,
            img2img_height,
            img2img_guidance_scale,
            img2img_steps,
            img2img_lora_scale,
        ],
        outputs=[
            ghibli_output_image,
            ghibli_output_seed,
            extracted_caption,
            ghibli_prompt,
        ],
    )

    transform_button.click(
        process_uploaded_image,
        inputs=[
            upload_img,
            img2img_seed,
            img2img_randomize_seed,
            img2img_width,
            img2img_height,
            img2img_guidance_scale,
            img2img_steps,
            img2img_lora_scale,
        ],
        outputs=[
            ghibli_output_image,
            ghibli_output_seed,
            extracted_caption,
            ghibli_prompt,
        ],
    )

demo.launch(debug=True)