ouclxy commited on
Commit
04f9c50
·
verified ·
1 Parent(s): 627b1b8

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +48 -7
gradio_app.py CHANGED
@@ -41,6 +41,14 @@ TRAINED_MODEL_REPO = os.getenv("TRAINED_MODEL_REPO", "")
41
  # 优先读取官方变量名,其次兼容 HF_TOKEN
42
  HF_AUTH_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
43
 
 
 
 
 
 
 
 
 
44
 
45
  # -----------------------------------------------------------------------------
46
  # Utilities
@@ -165,7 +173,11 @@ def _download_models() -> Tuple[Optional[str], Optional[str], Optional[str]]:
165
  token=HF_AUTH_TOKEN,
166
  )
167
  # Symlink to ./trained_model so downstream code can load from there
168
- _ = _ensure_symlink(tm_snap, os.path.abspath("trained_model"))
 
 
 
 
169
 
170
  return sd15_path, hairmapper_dir, ffhq_dir
171
 
@@ -201,13 +213,42 @@ SD15_PATH, _, _ = _download_models()
201
  # -----------------------------------------------------------------------------
202
  # Global model loading (CPU) so GPU task only does inference
203
  # -----------------------------------------------------------------------------
 
 
 
 
204
  def _resolve_trained_model_dir() -> str:
205
- tm_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
206
- if tm_dir is None and os.path.isdir("pretrain"):
207
- tm_dir = os.path.abspath("pretrain")
208
- if tm_dir is None:
209
- raise RuntimeError("Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.")
210
- return tm_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
  # Lazy globals
 
41
  # 优先读取官方变量名,其次兼容 HF_TOKEN
42
  HF_AUTH_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
43
 
44
+ # 需要的权重文件清单
45
+ REQUIRED_WEIGHT_FILENAMES = [
46
+ "pytorch_model.bin",
47
+ "motion_module-4140000.pth",
48
+ "pytorch_model_1.bin",
49
+ "pytorch_model_2.bin",
50
+ ]
51
+
52
 
53
  # -----------------------------------------------------------------------------
54
  # Utilities
 
173
  token=HF_AUTH_TOKEN,
174
  )
175
  # Symlink to ./trained_model so downstream code can load from there
176
+ tm_linked = _ensure_symlink(tm_snap, os.path.abspath("trained_model"))
177
+ # If the repo contains a nested pretrain/ folder, also expose it at ./pretrain
178
+ nested_pretrain = os.path.join(tm_linked, "pretrain")
179
+ if os.path.isdir(nested_pretrain):
180
+ _ensure_symlink(nested_pretrain, os.path.abspath("pretrain"))
181
 
182
  return sd15_path, hairmapper_dir, ffhq_dir
183
 
 
213
  # -----------------------------------------------------------------------------
214
  # Global model loading (CPU) so GPU task only does inference
215
  # -----------------------------------------------------------------------------
216
+ def _has_all_weights(dir_path: str) -> bool:
217
+ return all(os.path.isfile(os.path.join(dir_path, name)) for name in REQUIRED_WEIGHT_FILENAMES)
218
+
219
+
220
  def _resolve_trained_model_dir() -> str:
221
+ pretrain_dir = os.path.abspath("pretrain") if os.path.isdir("pretrain") else None
222
+ trained_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
223
+ trained_dir_nested = os.path.join(trained_dir, "pretrain") if trained_dir else None
224
+
225
+ # 优先使用 pretrain(你已说明文件在此),并校验文件齐全
226
+ if pretrain_dir and _has_all_weights(pretrain_dir):
227
+ return pretrain_dir
228
+
229
+ # 其次尝试 trained_model,并校验文件齐全
230
+ if trained_dir and _has_all_weights(trained_dir):
231
+ return trained_dir
232
+ # 再尝试 trained_model/pretrain 子目录
233
+ if trained_dir_nested and os.path.isdir(trained_dir_nested) and _has_all_weights(trained_dir_nested):
234
+ return trained_dir_nested
235
+
236
+ # 构造更友好的报错信息
237
+ def _missing_list(dir_path: str) -> str:
238
+ if not dir_path:
239
+ return "目录不存在"
240
+ missing = [n for n in REQUIRED_WEIGHT_FILENAMES if not os.path.isfile(os.path.join(dir_path, n))]
241
+ if not missing:
242
+ return "文件齐全"
243
+ return "缺少: " + ", ".join(missing)
244
+
245
+ msg = (
246
+ "Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.\n"
247
+ f"pretrain 状态: {_missing_list(pretrain_dir)}\n"
248
+ f"trained_model 状态: {_missing_list(trained_dir)}\n"
249
+ f"trained_model/pretrain 状态: {_missing_list(trained_dir_nested)}"
250
+ )
251
+ raise RuntimeError(msg)
252
 
253
 
254
  # Lazy globals