Osher committed (verified)
Commit e839598 · 1 Parent(s): 70a6fd7

Update chat.py

Files changed (1)
  1. chat.py +37 -37
chat.py CHANGED
@@ -1,37 +1,37 @@
-import torch
-from model import TransformerModel
-from tokenizer import SimpleTokenizer
-
-# Load tokenizer
-tokenizer = SimpleTokenizer("vocab.pth")
-
-# Use same values from train.py
-vocab_size = len(tokenizer.char_to_idx)
-embed_size = 64
-num_heads = 2
-hidden_dim = 128
-num_layers = 2
-max_len = 32
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Create the same model and load weights
-model = TransformerModel(vocab_size, embed_size, num_heads, hidden_dim, num_layers, max_len).to(device)
-model.load_state_dict(torch.load("model.pth", map_location=device))
-model.eval()
-
-# Chat loop
-while True:
-    user_input = input("You: ")
-    if user_input.lower() in ["quit", "exit"]:
-        break
-
-    input_ids = tokenizer.encode(user_input)
-    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
-
-    with torch.no_grad():
-        output = model(input_tensor)[0]  # shape: [seq_len, vocab_size]
-        prediction = torch.argmax(output, dim=-1).squeeze().tolist()
-
-    response = tokenizer.decode(prediction)
-    print("AI:", response)
+import torch
+from model import TransformerModel
+from tokenizer import SimpleTokenizer
+
+# Load tokenizer
+tokenizer = SimpleTokenizer("vocab_path")
+
+# Use same values from train.py
+vocab_size = len(tokenizer.char_to_idx)
+embed_size = 64
+num_heads = 2
+hidden_dim = 128
+num_layers = 2
+max_len = 32
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Create the same model and load weights
+model = TransformerModel(vocab_size, embed_size, num_heads, hidden_dim, num_layers, max_len).to(device)
+model.load_state_dict(torch.load("model.pth", map_location=device))
+model.eval()
+
+# Chat loop
+while True:
+    user_input = input("You: ")
+    if user_input.lower() in ["quit", "exit"]:
+        break
+
+    input_ids = tokenizer.encode(user_input)
+    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
+
+    with torch.no_grad():
+        output = model(input_tensor)[0]  # shape: [seq_len, vocab_size]
+        prediction = torch.argmax(output, dim=-1).squeeze().tolist()
+
+    response = tokenizer.decode(prediction)
+    print("AI:", response)
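
chat.py assumes a character-level tokenizer exposing a char_to_idx mapping plus encode and decode methods, but tokenizer.py is not part of this commit. As context, here is a minimal sketch of the interface the script relies on; the vocab file format and every implementation detail below are assumptions inferred from the calls in chat.py, not the repository's actual code.

import torch

class SimpleTokenizer:
    def __init__(self, vocab_path):
        # Assumed format: a torch-saved dict mapping each character to an index
        self.char_to_idx = torch.load(vocab_path)
        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}

    def encode(self, text):
        # Map each known character to its index; silently drop unknown characters
        return [self.char_to_idx[c] for c in text if c in self.char_to_idx]

    def decode(self, ids):
        # Map indices back to characters and join them into a string
        return "".join(self.idx_to_char.get(i, "") for i in ids)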
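
model.py is likewise outside this diff. The constructor call and the [seq_len, vocab_size] comment imply a model that returns one row of logits per input position, which chat.py then argmaxes into a one-shot greedy prediction for every position (rather than step-by-step autoregressive generation). A minimal architecture consistent with that call signature might look like the following; every layer choice here is an assumption for illustration only.

import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers, max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_embed = nn.Embedding(max_len, embed_size)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_size, nhead=num_heads,
            dim_feedforward=hidden_dim, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.out = nn.Linear(embed_size, vocab_size)

    def forward(self, x):
        # x: [batch, seq_len] token ids; add learned positional embeddings
        positions = torch.arange(x.size(1), device=x.device)
        h = self.embed(x) + self.pos_embed(positions)
        h = self.encoder(h)
        return self.out(h)  # logits: [batch, seq_len, vocab_size]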