XiangpengYang commited on
Commit
42e8438
·
1 Parent(s): c2d6fd3

no unmerging

Browse files
Files changed (2) hide show
  1. app.py +32 -17
  2. videox_fun/ui/ui.py +1 -2
app.py CHANGED
@@ -162,6 +162,12 @@ class VideoCoF_Controller(Wan_Controller):
162
  if self.lora_model_path != lora_model_dropdown:
163
  self.update_lora_model(lora_model_dropdown)
164
 
 
 
 
 
 
 
165
  # Scheduler setup
166
  scheduler_config = self.pipeline.scheduler.config
167
  if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
@@ -171,18 +177,34 @@ class VideoCoF_Controller(Wan_Controller):
171
  # LoRA merging
172
  # 1. Merge VideoCoF LoRA
173
  if self.lora_model_path != "none":
174
- print(f"Merge VideoCoF Lora: {self.lora_model_path}")
175
- self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
 
 
 
 
 
 
 
 
176
 
177
  # 2. Merge Acceleration LoRA (FusionX) if enabled
178
  acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
179
  if enable_acceleration:
180
  if os.path.exists(acc_lora_path):
181
- print(f"Merge Acceleration LoRA: {acc_lora_path}")
182
- # FusionX LoRA generally uses multiplier 1.0
183
- self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
 
 
184
  else:
185
  print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")
 
 
 
 
 
 
186
 
187
  # Seed
188
  if int(seed_textbox) != -1 and seed_textbox != "":
@@ -247,24 +269,17 @@ class VideoCoF_Controller(Wan_Controller):
247
  except Exception as e:
248
  print(f"Error: {e}")
249
  # Unmerge in case of error (LIFO order)
250
- if enable_acceleration and os.path.exists(acc_lora_path):
251
  print("Unmerging Acceleration LoRA (due to error)")
252
  self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
 
253
 
254
- if self.lora_model_path != "none":
255
  print("Unmerging VideoCoF LoRA (due to error)")
256
- self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
 
257
  return gr.update(), gr.update(), f"Error: {str(e)}"
258
 
259
- # Unmerge LoRAs (LIFO order)
260
- if enable_acceleration and os.path.exists(acc_lora_path):
261
- print("Unmerging Acceleration LoRA")
262
- self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
263
-
264
- if self.lora_model_path != "none":
265
- print("Unmerging VideoCoF LoRA")
266
- self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
267
-
268
  # Save output
269
  save_sample_path = self.save_outputs(
270
  False, length_slider, final_video, fps=fps
 
162
  if self.lora_model_path != lora_model_dropdown:
163
  self.update_lora_model(lora_model_dropdown)
164
 
165
+ # Track whether LoRAs are already merged to avoid repeat merges/unmerges.
166
+ if not hasattr(self, "_active_lora_path"):
167
+ self._active_lora_path = None
168
+ if not hasattr(self, "_acc_lora_active"):
169
+ self._acc_lora_active = False
170
+
171
  # Scheduler setup
172
  scheduler_config = self.pipeline.scheduler.config
173
  if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
 
177
  # LoRA merging
178
  # 1. Merge VideoCoF LoRA
179
  if self.lora_model_path != "none":
180
+ # If a different LoRA was previously merged, unmerge it first.
181
+ if self._active_lora_path and self._active_lora_path != self.lora_model_path:
182
+ print(f"Unmerging previous VideoCoF LoRA: {self._active_lora_path}")
183
+ self.pipeline = unmerge_lora(self.pipeline, self._active_lora_path, multiplier=lora_alpha_slider, device=self.device)
184
+ self._active_lora_path = None
185
+
186
+ if self._active_lora_path != self.lora_model_path:
187
+ print(f"Merge VideoCoF LoRA: {self.lora_model_path}")
188
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
189
+ self._active_lora_path = self.lora_model_path
190
 
191
  # 2. Merge Acceleration LoRA (FusionX) if enabled
192
  acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
193
  if enable_acceleration:
194
  if os.path.exists(acc_lora_path):
195
+ if not self._acc_lora_active:
196
+ print(f"Merge Acceleration LoRA: {acc_lora_path}")
197
+ # FusionX LoRA generally uses multiplier 1.0
198
+ self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
199
+ self._acc_lora_active = True
200
  else:
201
  print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")
202
+ else:
203
+ # If it was previously merged but now disabled, unmerge once.
204
+ if self._acc_lora_active and os.path.exists(acc_lora_path):
205
+ print("Unmerging Acceleration LoRA (disabled)")
206
+ self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
207
+ self._acc_lora_active = False
208
 
209
  # Seed
210
  if int(seed_textbox) != -1 and seed_textbox != "":
 
269
  except Exception as e:
270
  print(f"Error: {e}")
271
  # Unmerge in case of error (LIFO order)
272
+ if self._acc_lora_active and os.path.exists(acc_lora_path):
273
  print("Unmerging Acceleration LoRA (due to error)")
274
  self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
275
+ self._acc_lora_active = False
276
 
277
+ if self._active_lora_path:
278
  print("Unmerging VideoCoF LoRA (due to error)")
279
+ self.pipeline = unmerge_lora(self.pipeline, self._active_lora_path, multiplier=lora_alpha_slider, device=self.device)
280
+ self._active_lora_path = None
281
  return gr.update(), gr.update(), f"Error: {str(e)}"
282
 
 
 
 
 
 
 
 
 
 
283
  # Save output
284
  save_sample_path = self.save_outputs(
285
  False, length_slider, final_video, fps=fps
videox_fun/ui/ui.py CHANGED
@@ -291,8 +291,7 @@ def create_generation_method(source_method_options, prompt_textbox, support_end_
291
  gr.Examples(
292
  examples=video_examples,
293
  inputs=[validation_video, prompt_textbox] if len(video_examples[0]) > 1 else validation_video,
294
- label="Video Examples",
295
- examples_per_page=6,
296
  )
297
 
298
  # Removed Mask Accordion entirely per request or hidden. User said "mask这个不需要"
 
291
  gr.Examples(
292
  examples=video_examples,
293
  inputs=[validation_video, prompt_textbox] if len(video_examples[0]) > 1 else validation_video,
294
+ label="Video Examples"
 
295
  )
296
 
297
  # Removed Mask Accordion entirely per request or hidden. User said "mask这个不需要"