kfirbria committed on
Commit
fd176a7
·
verified ·
1 Parent(s): 7d3c3ee

Update fibo_vlm_prompt_to_json.py

Browse files
Files changed (1) hide show
  1. fibo_vlm_prompt_to_json.py +43 -5
fibo_vlm_prompt_to_json.py CHANGED
@@ -4,6 +4,7 @@ import textwrap
4
  from typing import Any, Dict, Iterable, List, Optional
5
 
6
  import torch
 
7
  from boltons.iterutils import remap
8
  from PIL import Image
9
  from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditionalGeneration
@@ -11,6 +12,13 @@ from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditio
11
  from diffusers.modular_pipelines import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState
12
 
13
 
 
 
 
 
 
 
 
14
  def parse_aesthetic_score(record: dict) -> str:
15
  ae = record["aesthetic_score"]
16
  if ae < 5.5:
@@ -57,7 +65,24 @@ def prepare_clean_caption(record: dict) -> str:
57
  if "aesthetic_score" in record:
58
  scores["aesthetic_score"] = parse_aesthetic_score(record)
59
 
60
- clean_caption_dict = remap(record, visit=keep)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # Set aesthetics scores
63
  if "aesthetics" not in clean_caption_dict:
@@ -67,7 +92,7 @@ def prepare_clean_caption(record: dict) -> str:
67
  clean_caption_dict["aesthetics"].update(scores)
68
 
69
  # Dump clean structured caption as a minimal JSON string (i.e. no newline/whitespace separators)
70
- clean_caption_str = json.dumps(clean_caption_dict)
71
  return clean_caption_str
72
  except Exception as ex:
73
  print("Error: ", ex)
@@ -221,6 +246,7 @@ def generate_json_prompt(
221
  prompt: Optional[str] = None,
222
  structured_prompt: Optional[str] = None,
223
  ):
 
224
  if image is None and structured_prompt is None:
225
  # only got prompt
226
  task = "generate"
@@ -233,6 +259,7 @@ def generate_json_prompt(
233
  # got image and prompt
234
  task = "refine"
235
  editing_instructions = prompt
 
236
  elif image is not None and structured_prompt is None and prompt is None:
237
  # only got image
238
  task = "inspire"
@@ -244,6 +271,7 @@ def generate_json_prompt(
244
  task,
245
  image=image,
246
  prompt=prompt,
 
247
  structured_prompt=structured_prompt,
248
  editing_instructions=editing_instructions,
249
  )
@@ -277,12 +305,22 @@ def build_messages(
277
  if refine_image is None:
278
  base_prompt = (structured_prompt or "").strip()
279
  edits = (editing_instructions or "").strip()
280
- formatted = textwrap.dedent(f"""<refine> Input: {base_prompt} Editing instructions: {edits}""").strip()
 
 
 
 
 
 
281
  user_content.append({"type": "text", "text": formatted})
282
  else:
283
  user_content.append({"type": "image", "image": refine_image})
284
  edits = (editing_instructions or "").strip()
285
- formatted = textwrap.dedent(f"""<refine> Editing instructions: {edits}""").strip()
 
 
 
 
286
  user_content.append({"type": "text", "text": formatted})
287
 
288
  messages: List[Dict[str, Any]] = []
@@ -293,7 +331,7 @@ def build_messages(
293
  class BriaFiboVLMPromptToJson(ModularPipelineBlocks):
294
  model_name = "BriaFibo"
295
 
296
- def __init__(self, model_id = "briaai/vlm-processor-new"):
297
  super().__init__()
298
  self.engine = TransformersEngine(model_id)
299
  self.engine.model.to("cuda")
 
4
  from typing import Any, Dict, Iterable, List, Optional
5
 
6
  import torch
7
+ import ujson
8
  from boltons.iterutils import remap
9
  from PIL import Image
10
  from transformers import AutoModelForCausalLM, AutoProcessor, Qwen3VLForConditionalGeneration
 
12
  from diffusers.modular_pipelines import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState
13
 
14
 
15
+ def clean_json(caption):
16
+ caption["pickascore"] = 1.0
17
+ caption["aesthetic_score"] = 10.0
18
+ caption = prepare_clean_caption(caption)
19
+ return caption
20
+
21
+
22
  def parse_aesthetic_score(record: dict) -> str:
23
  ae = record["aesthetic_score"]
24
  if ae < 5.5:
 
65
  if "aesthetic_score" in record:
66
  scores["aesthetic_score"] = parse_aesthetic_score(record)
67
 
68
+ # Create structured caption dict of original values
69
+ fields = [
70
+ "short_description",
71
+ "objects",
72
+ "background_setting",
73
+ "lighting",
74
+ "aesthetics",
75
+ "photographic_characteristics",
76
+ "style_medium",
77
+ "text_render",
78
+ "context",
79
+ "artistic_style",
80
+ ]
81
+
82
+ original_caption_dict = {f: record[f] for f in fields if f in record}
83
+
84
+ # filter empty values recursively (i.e. None, "", {}, [], float("nan"))
85
+ clean_caption_dict = remap(original_caption_dict, visit=keep)
86
 
87
  # Set aesthetics scores
88
  if "aesthetics" not in clean_caption_dict:
 
92
  clean_caption_dict["aesthetics"].update(scores)
93
 
94
  # Dump clean structured caption as a minimal JSON string (i.e. no newline/whitespace separators)
95
+ clean_caption_str = ujson.dumps(clean_caption_dict, escape_forward_slashes=False)
96
  return clean_caption_str
97
  except Exception as ex:
98
  print("Error: ", ex)
 
246
  prompt: Optional[str] = None,
247
  structured_prompt: Optional[str] = None,
248
  ):
249
+ refine_image = None
250
  if image is None and structured_prompt is None:
251
  # only got prompt
252
  task = "generate"
 
259
  # got image and prompt
260
  task = "refine"
261
  editing_instructions = prompt
262
+ refine_image = image
263
  elif image is not None and structured_prompt is None and prompt is None:
264
  # only got image
265
  task = "inspire"
 
271
  task,
272
  image=image,
273
  prompt=prompt,
274
+ refine_image=refine_image,
275
  structured_prompt=structured_prompt,
276
  editing_instructions=editing_instructions,
277
  )
 
305
  if refine_image is None:
306
  base_prompt = (structured_prompt or "").strip()
307
  edits = (editing_instructions or "").strip()
308
+ formatted = textwrap.dedent(
309
+ f"""<refine>
310
+ Input:
311
+ {base_prompt}
312
+ Editing instructions:
313
+ {edits}"""
314
+ ).strip()
315
  user_content.append({"type": "text", "text": formatted})
316
  else:
317
  user_content.append({"type": "image", "image": refine_image})
318
  edits = (editing_instructions or "").strip()
319
+ formatted = textwrap.dedent(
320
+ f"""<refine>
321
+ Editing instructions:
322
+ {edits}"""
323
+ ).strip()
324
  user_content.append({"type": "text", "text": formatted})
325
 
326
  messages: List[Dict[str, Any]] = []
 
331
  class BriaFiboVLMPromptToJson(ModularPipelineBlocks):
332
  model_name = "BriaFibo"
333
 
334
+ def __init__(self, model_id="briaai/vlm-processor-new"):
335
  super().__init__()
336
  self.engine = TransformersEngine(model_id)
337
  self.engine.model.to("cuda")