Spaces:
Runtime error
Runtime error
Update gradio_app.py
Browse files- gradio_app.py +48 -7
gradio_app.py
CHANGED
|
@@ -41,6 +41,14 @@ TRAINED_MODEL_REPO = os.getenv("TRAINED_MODEL_REPO", "")
|
|
| 41 |
# 优先读取官方变量名,其次兼容 HF_TOKEN
|
| 42 |
HF_AUTH_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# -----------------------------------------------------------------------------
|
| 46 |
# Utilities
|
|
@@ -165,7 +173,11 @@ def _download_models() -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
| 165 |
token=HF_AUTH_TOKEN,
|
| 166 |
)
|
| 167 |
# Symlink to ./trained_model so downstream code can load from there
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
return sd15_path, hairmapper_dir, ffhq_dir
|
| 171 |
|
|
@@ -201,13 +213,42 @@ SD15_PATH, _, _ = _download_models()
|
|
| 201 |
# -----------------------------------------------------------------------------
|
| 202 |
# Global model loading (CPU) so GPU task only does inference
|
| 203 |
# -----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
def _resolve_trained_model_dir() -> str:
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
# Lazy globals
|
|
|
|
| 41 |
# 优先读取官方变量名,其次兼容 HF_TOKEN
|
| 42 |
HF_AUTH_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
|
| 43 |
|
| 44 |
+
# 需要的权重文件清单
|
| 45 |
+
REQUIRED_WEIGHT_FILENAMES = [
|
| 46 |
+
"pytorch_model.bin",
|
| 47 |
+
"motion_module-4140000.pth",
|
| 48 |
+
"pytorch_model_1.bin",
|
| 49 |
+
"pytorch_model_2.bin",
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
|
| 53 |
# -----------------------------------------------------------------------------
|
| 54 |
# Utilities
|
|
|
|
| 173 |
token=HF_AUTH_TOKEN,
|
| 174 |
)
|
| 175 |
# Symlink to ./trained_model so downstream code can load from there
|
| 176 |
+
tm_linked = _ensure_symlink(tm_snap, os.path.abspath("trained_model"))
|
| 177 |
+
# If the repo contains a nested pretrain/ folder, also expose it at ./pretrain
|
| 178 |
+
nested_pretrain = os.path.join(tm_linked, "pretrain")
|
| 179 |
+
if os.path.isdir(nested_pretrain):
|
| 180 |
+
_ensure_symlink(nested_pretrain, os.path.abspath("pretrain"))
|
| 181 |
|
| 182 |
return sd15_path, hairmapper_dir, ffhq_dir
|
| 183 |
|
|
|
|
| 213 |
# -----------------------------------------------------------------------------
|
| 214 |
# Global model loading (CPU) so GPU task only does inference
|
| 215 |
# -----------------------------------------------------------------------------
|
| 216 |
+
def _has_all_weights(dir_path: str) -> bool:
|
| 217 |
+
return all(os.path.isfile(os.path.join(dir_path, name)) for name in REQUIRED_WEIGHT_FILENAMES)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
def _resolve_trained_model_dir() -> str:
|
| 221 |
+
pretrain_dir = os.path.abspath("pretrain") if os.path.isdir("pretrain") else None
|
| 222 |
+
trained_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
|
| 223 |
+
trained_dir_nested = os.path.join(trained_dir, "pretrain") if trained_dir else None
|
| 224 |
+
|
| 225 |
+
# 优先使用 pretrain(你已说明文件在此),并校验文件齐全
|
| 226 |
+
if pretrain_dir and _has_all_weights(pretrain_dir):
|
| 227 |
+
return pretrain_dir
|
| 228 |
+
|
| 229 |
+
# 其次尝试 trained_model,并校验文件齐全
|
| 230 |
+
if trained_dir and _has_all_weights(trained_dir):
|
| 231 |
+
return trained_dir
|
| 232 |
+
# 再尝试 trained_model/pretrain 子目录
|
| 233 |
+
if trained_dir_nested and os.path.isdir(trained_dir_nested) and _has_all_weights(trained_dir_nested):
|
| 234 |
+
return trained_dir_nested
|
| 235 |
+
|
| 236 |
+
# 构造更友好的报错信息
|
| 237 |
+
def _missing_list(dir_path: str) -> str:
|
| 238 |
+
if not dir_path:
|
| 239 |
+
return "目录不存在"
|
| 240 |
+
missing = [n for n in REQUIRED_WEIGHT_FILENAMES if not os.path.isfile(os.path.join(dir_path, n))]
|
| 241 |
+
if not missing:
|
| 242 |
+
return "文件齐全"
|
| 243 |
+
return "缺少: " + ", ".join(missing)
|
| 244 |
+
|
| 245 |
+
msg = (
|
| 246 |
+
"Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.\n"
|
| 247 |
+
f"pretrain 状态: {_missing_list(pretrain_dir)}\n"
|
| 248 |
+
f"trained_model 状态: {_missing_list(trained_dir)}\n"
|
| 249 |
+
f"trained_model/pretrain 状态: {_missing_list(trained_dir_nested)}"
|
| 250 |
+
)
|
| 251 |
+
raise RuntimeError(msg)
|
| 252 |
|
| 253 |
|
| 254 |
# Lazy globals
|