Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from pytorch_lightning import Callback | |
| import os | |
| import torch | |
| from lightning_fabric.utilities.cloud_io import get_filesystem | |
| from pytorch_lightning.cli import LightningArgumentParser | |
| from pytorch_lightning import LightningModule, Trainer | |
| from lightning_utilities.core.imports import RequirementCache | |
| from omegaconf import OmegaConf | |
# Truthy only when jsonargparse with the "signatures" extra, at version 4.17.0 or newer, is installed.
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.17.0")

if not _JSONARGPARSE_SIGNATURES_AVAILABLE:
    # jsonargparse is missing: bind placeholder names so the annotations used
    # further down in this module still resolve. Assigning through locals()
    # (which is the module namespace here) avoids a static "redefinition" error.
    locals()["ArgumentParser"] = object
    locals()["Namespace"] = object
else:
    import docstring_parser
    from jsonargparse import (
        ActionConfigFile,
        ArgumentParser,
        Namespace,
        class_from_function,
        register_unresolvable_import_paths,
        set_config_read_mode,
    )

    # Required until fix https://github.com/pytorch/pytorch/issues/74483
    register_unresolvable_import_paths(torch)
    # Allow config files to be read through fsspec (e.g. remote URLs).
    set_config_read_mode(fsspec_enabled=True)
class SaveConfigCallback(Callback):
    """Save a LightningCLI config to ``log_dir`` when training starts and log it as hyperparameters.

    Args:
        parser: The parser object used to parse the configuration.
        config: The parsed configuration that will be saved.
        log_dir: Directory the config file is written to. Any path accepted by
            ``get_filesystem`` works, including fsspec URLs.
        config_filename: Filename for the config file.
        overwrite: Whether to overwrite an existing config file.
        multifile: When input is multiple config files, saved config preserves this structure.

    Raises:
        RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
    """

    def __init__(
        self,
        parser: LightningArgumentParser,
        config: Namespace,
        log_dir: str,
        config_filename: str = "config.yaml",
        overwrite: bool = False,
        multifile: bool = False,
    ) -> None:
        self.parser = parser
        self.config = config
        self.config_filename = config_filename
        self.overwrite = overwrite
        self.multifile = multifile
        # Set once the config has been written; synced across ranks in setup().
        self.already_saved = False
        self.log_dir = log_dir

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Write the config on rank 0 (at most once) and log it to the trainer's loggers.

        Raises:
            RuntimeError: If ``overwrite`` is False and the config file already exists.
        """
        if self.already_saved:
            return

        log_dir = self.log_dir
        assert log_dir is not None
        config_path = os.path.join(log_dir, self.config_filename)
        fs = get_filesystem(log_dir)

        if not self.overwrite:
            # check if the file exists on rank 0
            file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
            # broadcast whether to fail to all ranks
            file_exists = trainer.strategy.broadcast(file_exists)
            if file_exists:
                raise RuntimeError(
                    f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
                    " results of a previous run. You can delete the previous config file,"
                    " set `LightningCLI(save_config_callback=None)` to disable config saving,"
                    ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.'
                )

        # save the file on rank 0
        if trainer.is_global_zero:
            # save only on rank zero to avoid race conditions.
            # the `log_dir` needs to be created as we rely on the logger to do it usually
            # but it hasn't logged anything at this point
            fs.makedirs(log_dir, exist_ok=True)
            self.parser.save(
                self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
            )
            self.already_saved = True
            self._log_config(trainer)

        # broadcast so that all ranks are in sync on future calls to .setup()
        self.already_saved = trainer.strategy.broadcast(self.already_saved)

    def _log_config(self, trainer: Trainer) -> None:
        """Log the full parsed config as hyperparameters on every logger attached to ``trainer``."""
        loggers = trainer.loggers
        if not loggers:
            # Trainer(logger=False): nothing to log to (trainer.logger would be None).
            return
        # Dump from the parser instead of re-reading the saved file. A local
        # re-read fails for remote (fsspec) log dirs, and with multifile=True the
        # main file only references the sub-config files.
        hparams = OmegaConf.create(self.parser.dump(self.config, skip_none=False))
        for logger in loggers:
            logger.log_hyperparams(hparams)