"""Rebuild and export the trained multitask model for downstream use.""" from __future__ import annotations import argparse from pathlib import Path import torch from src.data.tokenization import Tokenizer, TokenizerConfig from src.models.factory import build_multitask_model, load_model_config from src.utils.config import load_yaml from src.utils.labels import load_label_metadata def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Export LexiMind model weights") parser.add_argument("--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint.") parser.add_argument("--output", default="outputs/model.pt", help="Output path for the exported state dict.") parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON produced after training.") parser.add_argument("--model-config", default="configs/model/base.yaml", help="Model architecture configuration.") parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Data configuration (for tokenizer settings).") return parser.parse_args() def main() -> None: """Export multitask model weights from a training checkpoint to a standalone state dict.""" args = parse_args() checkpoint = Path(args.checkpoint) if not checkpoint.exists(): raise FileNotFoundError(checkpoint) labels = load_label_metadata(args.labels) data_cfg = load_yaml(args.data_config).data tokenizer_section = data_cfg.get("tokenizer", {}) tokenizer_config = TokenizerConfig( pretrained_model_name=tokenizer_section.get("pretrained_model_name", "facebook/bart-base"), max_length=int(tokenizer_section.get("max_length", 512)), lower=bool(tokenizer_section.get("lower", False)), ) tokenizer = Tokenizer(tokenizer_config) model = build_multitask_model( tokenizer, num_emotions=labels.emotion_size, num_topics=labels.topic_size, config=load_model_config(args.model_config), ) raw_state = torch.load(checkpoint, map_location="cpu") if isinstance(raw_state, dict): if "model_state_dict" in raw_state and isinstance(raw_state["model_state_dict"], dict): state_dict = raw_state["model_state_dict"] elif "state_dict" in raw_state and isinstance(raw_state["state_dict"], dict): state_dict = raw_state["state_dict"] else: state_dict = raw_state else: raise TypeError(f"Unsupported checkpoint format: expected dict, got {type(raw_state)!r}") model.load_state_dict(state_dict) output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) torch.save(model.state_dict(), output_path) print(f"Model exported to {output_path}") if __name__ == "__main__": main()