from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.add_bos_token = True

        # bfloat16 on GPU; float32 on CPU (float16 is poorly supported for
        # CPU inference and can silently degrade output quality).
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=dtype
        ).to(device)

        # The model is already instantiated and placed on a device, so the
        # pipeline only needs the matching device index; the dtype is
        # inherited from the model itself.
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            return_full_text=False,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "Missing 'inputs' field."}

        defaults = {
            "max_new_tokens": 100,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            # Stop generation at the tokenizer's EOS token.
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Caller-supplied parameters override the defaults.
        generation_args = {**defaults, **data.get("parameters", {})}

        try:
            outputs = self.generator(prompt, **generation_args)
            output_text = outputs[0]["generated_text"].strip()

            # Approximate finish reason: if the completion used up the full
            # token budget, generation was most likely truncated.  Encode
            # without special tokens so the BOS token doesn't skew the count.
            output_len = len(
                self.tokenizer.encode(output_text, add_special_tokens=False)
            )
            finish_reason = (
                "length"
                if output_len >= generation_args["max_new_tokens"]
                else "stop"
            )

            return {
                "choices": [
                    {
                        "message": {"role": "assistant", "content": output_text},
                        "finish_reason": finish_reason,
                    }
                ]
            }
        except Exception as e:
            return {"error": str(e)}
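
# --- Usage sketch (not part of the handler) --------------------------------
# A minimal local smoke test, assuming model weights exist at the
# hypothetical path "./model".  On a Hugging Face Inference Endpoint this is
# not needed: the serving toolkit constructs EndpointHandler itself and
# invokes it with the deserialized request payload.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # "./model" is a placeholder
    response = handler(
        {
            "inputs": "Write a haiku about GPUs.",
            "parameters": {"max_new_tokens": 64, "temperature": 0.8},
        }
    )
    print(response)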