File size: 3,762 Bytes
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Run inference with the multitask model."""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import List, cast

from src.data.tokenization import TokenizerConfig
from src.inference import EmotionPrediction, TopicPrediction, create_inference_pipeline


def _load_texts(positional: List[str], file_path: Path | None) -> List[str]:
    texts = [text for text in positional if text]
    if file_path is not None:
        if not file_path.exists():
            raise FileNotFoundError(file_path)
        with file_path.open("r", encoding="utf-8") as handle:
            texts.extend([line.strip() for line in handle if line.strip()])
    if not texts:
        raise ValueError("No input texts provided. Pass text arguments or use --file.")
    return texts


def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Build the LexiMind inference CLI parser and parse arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing an explicit list
            enables programmatic invocation and testing.

    Returns:
        A namespace with ``text``, ``file``, ``checkpoint``, ``labels``,
        ``tokenizer``, ``model_config``, ``device`` and ``summary_max_length``.
    """
    parser = argparse.ArgumentParser(description="Run LexiMind multitask inference.")
    parser.add_argument("text", nargs="*", help="Input text(s) to analyse.")
    parser.add_argument("--file", type=Path, help="Path to a file containing one text per line.")
    parser.add_argument(
        "--checkpoint",
        type=Path,
        default=Path("checkpoints/best.pt"),
        help="Path to the model checkpoint produced during training.",
    )
    parser.add_argument(
        "--labels",
        type=Path,
        default=Path("artifacts/labels.json"),
        help="JSON file containing emotion/topic label vocabularies.",
    )
    parser.add_argument(
        "--tokenizer",
        type=Path,
        default=None,
        help="Optional path to a tokenizer directory exported during training.",
    )
    parser.add_argument(
        "--model-config",
        type=Path,
        default=Path("configs/model/base.yaml"),
        help="Model architecture config used to rebuild the transformer stack.",
    )
    parser.add_argument("--device", default="cpu", help="Device to run inference on (cpu or cuda).")
    parser.add_argument(
        "--summary-max-length",
        type=int,
        default=None,
        help="Optional maximum length for generated summaries.",
    )
    return parser.parse_args(argv)


def _resolve_tokenizer_config(tokenizer_path: Path | None) -> TokenizerConfig | None:
    """Return the tokenizer config from ``--tokenizer``, else the local export, else ``None``."""
    if tokenizer_path is None:
        fallback_dir = Path("artifacts/hf_tokenizer")
        if not fallback_dir.exists():
            return None
        tokenizer_path = fallback_dir
    return TokenizerConfig(pretrained_model_name=str(tokenizer_path))


def _package_prediction(
    text: str,
    emotion: EmotionPrediction,
    topic: TopicPrediction,
    summary: str,
) -> dict:
    """Assemble the JSON-ready record for a single input text."""
    return {
        "text": text,
        "summary": summary,
        "emotion": {
            "labels": emotion.labels,
            "scores": emotion.scores,
        },
        "topic": {
            "label": topic.label,
            "confidence": topic.confidence,
        },
    }


def main() -> None:
    """CLI entry point: run multitask inference and print the results as JSON.

    Texts come from positional arguments and/or ``--file``. One JSON object is
    printed per input text, in input order. Each object holds the text's
    summary, emotion labels and scores, and topic label and confidence.
    """
    cli = parse_args()
    inputs = _load_texts(cli.text, cli.file)
    tok_cfg = _resolve_tokenizer_config(cli.tokenizer)

    pipeline, _ = create_inference_pipeline(
        checkpoint_path=cli.checkpoint,
        labels_path=cli.labels,
        tokenizer_config=tok_cfg,
        model_config_path=cli.model_config,
        device=cli.device,
        summary_max_length=cli.summary_max_length,
    )

    outputs = pipeline.batch_predict(inputs)
    summary_list = cast(List[str], outputs["summaries"])
    emotion_list = cast(List[EmotionPrediction], outputs["emotion"])
    topic_list = cast(List[TopicPrediction], outputs["topic"])

    # Keyword order mirrors the original lookup order: emotion, topic, summary.
    records = [
        _package_prediction(
            text=item,
            emotion=emotion_list[pos],
            topic=topic_list[pos],
            summary=summary_list[pos],
        )
        for pos, item in enumerate(inputs)
    ]

    print(json.dumps(records, indent=2, ensure_ascii=False))


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()