samtuckervegan committed
Commit f61bc02 · verified · 1 Parent(s): f23126c

Update handler.py

Files changed (1)
  1. handler.py +32 -65
handler.py CHANGED
@@ -1,77 +1,44 @@
- from typing import Dict, Any, List
  import torch
- from transformers import pipeline, AutoTokenizer, LlamaForCausalLM
-
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16

  class EndpointHandler:
      def __init__(self, path: str = ""):
-         self.tokenizer = AutoTokenizer.from_pretrained(path, revision="main")
-         if self.tokenizer.pad_token is None:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-
-         self.model = LlamaForCausalLM.from_pretrained(path, revision="main", torch_dtype=dtype)
-
-         device = 0 if torch.cuda.is_available() else -1
          self.generator = pipeline(
              "text-generation",
              model=self.model,
              tokenizer=self.tokenizer,
-             device=device
          )

-         self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)
-
      def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-         # If using HF Inference Endpoint, wrap everything under "inputs"
-         data = data.get("inputs", data)
-
-         messages: List[Dict[str, str]] = data.get("messages", [])
-         if not messages:
-             return {"error": "Missing 'messages' array."}
-
-         prompt = self.format_chat_prompt(messages)
-
-         generation_args = data.get("parameters", {})
-         max_tokens = generation_args.setdefault("max_new_tokens", 300)
-         generation_args.setdefault("do_sample", True)
-         generation_args.setdefault("temperature", 0.4)
-         generation_args.setdefault("top_p", 0.9)
-         generation_args.setdefault("repetition_penalty", 1.2)
-         generation_args.setdefault("no_repeat_ngram_size", 6)
-         generation_args.setdefault("early_stopping", True)
-         generation_args.setdefault("return_full_text", False)
-         generation_args.setdefault("eos_token_id", self.eos_token_id)
-         generation_args.setdefault("pad_token_id", self.tokenizer.pad_token_id)
-
-         try:
-             result = self.generator(prompt, **generation_args)
-             output = result[0]["generated_text"].strip()
-             token_count = len(self.tokenizer.encode(output))
-
-             finish_reason = "stop"
-             if self.tokenizer.eos_token not in output and token_count >= max_tokens:
-                 finish_reason = "length"
-
-             return {
-                 "choices": [{
-                     "message": {
-                         "role": "assistant",
-                         "content": output
-                     },
-                     "finish_reason": finish_reason
-                 }]
-             }
-
-         except Exception as e:
-             import traceback
-             return {"error": str(e), "traceback": traceback.format_exc()}

-     def format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
-         prompt = ""
-         for msg in messages:
-             role = msg.get("role", "").strip().lower()
-             content = msg.get("content", "").strip()
-             if role in ["system", "user", "assistant", "ipython"]:
-                 prompt += f"{content}\n"
-         return prompt.strip()
 
 
+ from typing import Dict, Any
  import torch
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

  class EndpointHandler:
      def __init__(self, path: str = ""):
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+         self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16).to("cuda" if torch.cuda.is_available() else "cpu")
+
          self.generator = pipeline(
              "text-generation",
              model=self.model,
              tokenizer=self.tokenizer,
+             device=0 if torch.cuda.is_available() else -1,
+             return_full_text=False,
+             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
          )

      def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         prompt = data.get("inputs", "")
+         if not prompt:
+             return {"error": "Missing 'inputs' field."}
+
+         if not prompt.startswith("<|begin_of_text|>"):
+             prompt = f"<|begin_of_text|>{prompt}"
+
+         params = data.get("parameters", {})
+         outputs = self.generator(
+             prompt,
+             max_new_tokens=params.get("max_new_tokens", 100),
+             do_sample=params.get("do_sample", True),
+             temperature=params.get("temperature", 0.7),
+             top_p=params.get("top_p", 0.9)
+         )
+
+         return {
+             "choices": [{
+                 "message": {
+                     "role": "assistant",
+                     "content": outputs[0]["generated_text"].strip()
+                 },
+                 "finish_reason": "stop"
+             }]
+         }
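
For reference, a minimal sketch of how the updated handler could be exercised locally; the model path and prompt below are placeholders for illustration and are not part of this commit, but the payload and response shapes follow the new __call__ implementation above.

# Hypothetical smoke test for the updated EndpointHandler (illustrative only).
from handler import EndpointHandler

# Placeholder model path; substitute the directory the endpoint actually loads.
handler = EndpointHandler(path="path/to/model")

payload = {
    "inputs": "Write one sentence about composting.",
    "parameters": {"max_new_tokens": 64, "temperature": 0.7, "top_p": 0.9},
}

response = handler(payload)
# Expected shape:
# {"choices": [{"message": {"role": "assistant", "content": "..."}, "finish_reason": "stop"}]}
print(response["choices"][0]["message"]["content"])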