from typing import Dict, Any
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
class EndpointHandler:
    def __init__(self, path: str = ""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.add_bos_token = True

        # Use bfloat16 on GPU; fall back to float32 on CPU, where float16
        # inference is poorly supported.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
        ).to("cuda" if torch.cuda.is_available() else "cpu")

        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            return_full_text=False,
            torch_dtype=dtype,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "Missing 'inputs' field."}

        defaults = {
            "max_new_tokens": 100,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "eos_token_id": self.tokenizer.eos_token_id,  # stop at the model's EOS token (e.g. <|eot_id|>)
        }
        # Caller-supplied parameters override the defaults.
        generation_args = {**defaults, **data.get("parameters", {})}

        try:
            outputs = self.generator(prompt, **generation_args)
            output_text = outputs[0]["generated_text"].strip()

            # Approximate check: report "length" if the output likely hit the token budget.
            finish_reason = "stop"
            if len(self.tokenizer.encode(output_text)) >= generation_args["max_new_tokens"]:
                finish_reason = "length"

            return {
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": output_text
                    },
                    "finish_reason": finish_reason
                }]
            }
        except Exception as e:
            return {"error": str(e)}