Osher
committed on
Upload 14 files

- .gitattributes +1 -0
- chat.py +37 -0
- data.txt +3 -0
- feedforward.py +14 -0
- model.pth +3 -0
- model.py +42 -0
- multi_head_attention.py +42 -0
- positional_encoding.py +19 -0
- tiny_llama_model.pth +3 -0
- tokenizer.py +29 -0
- train.log +1 -0
- train.py +111 -0
- transformer_block.py +26 -0
- transformer_model.py +37 -0
- vocab.pth +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data.txt filter=lfs diff=lfs merge=lfs -text
chat.py
ADDED
@@ -0,0 +1,37 @@
+import torch
+from model import TransformerModel
+from tokenizer import SimpleTokenizer
+
+# Load tokenizer
+tokenizer = SimpleTokenizer("vocab.pth")
+
+# Use same values from train.py
+vocab_size = len(tokenizer.char_to_idx)
+embed_size = 64
+num_heads = 2
+hidden_dim = 128
+num_layers = 2
+max_len = 32
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Create the same model and load weights
+model = TransformerModel(vocab_size, embed_size, num_heads, hidden_dim, num_layers, max_len).to(device)
+model.load_state_dict(torch.load("model.pth", map_location=device))
+model.eval()
+
+# Chat loop
+while True:
+    user_input = input("You: ")
+    if user_input.lower() in ["quit", "exit"]:
+        break
+
+    input_ids = tokenizer.encode(user_input)
+    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
+
+    with torch.no_grad():
+        output = model(input_tensor)[0]  # shape: [seq_len, vocab_size]
+        prediction = torch.argmax(output, dim=-1).squeeze().tolist()
+
+    response = tokenizer.decode(prediction)
+    print("AI:", response)
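Note: the chat loop above runs a single forward pass and takes the argmax at every input position, so "AI:" prints the model's next-character guess for each character of the prompt rather than a continuation. A minimal autoregressive generation sketch, assuming the same model and tokenizer objects defined above (the generate helper and its max_new_tokens parameter are illustrative, not part of this upload):

def generate(model, tokenizer, prompt, max_new_tokens=64, max_len=32, device="cpu"):
    # Append one predicted character at a time, feeding the growing context back in.
    ids = tokenizer.encode(prompt)
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Trim to the last max_len tokens so the positional table is never exceeded.
            context = torch.tensor([ids[-max_len:]], dtype=torch.long, device=device)
            logits = model(context)[0, -1]          # logits for the next character only
            ids.append(int(torch.argmax(logits)))   # greedy pick; sampling would also work
    return tokenizer.decode(ids)

# e.g. print("AI:", generate(model, tokenizer, user_input, device=device))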
data.txt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9feec5f917db5f960b189c720479a443afcd4a6c51f5f98cb370f747f4a7b6b
+size 401944371
feedforward.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FeedForward(nn.Module):
+    def __init__(self, d_model, ff_dim=2048):
+        super(FeedForward, self).__init__()
+        self.linear1 = nn.Linear(d_model, ff_dim)
+        self.linear2 = nn.Linear(ff_dim, d_model)
+
+    def forward(self, x):
+        x = F.relu(self.linear1(x))
+        x = self.linear2(x)
+        return x
model.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:398b4803f6b42d46f2da3ec4d07dfcf0349da443e5555321d7c85e1fcb364489
+size 15764984
model.py
ADDED
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+
+class TransformerBlock(nn.Module):
+    def __init__(self, embed_size, heads, ff_hidden_dim, dropout):
+        super().__init__()
+        self.attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=heads, batch_first=True)
+        self.norm1 = nn.LayerNorm(embed_size)
+        self.norm2 = nn.LayerNorm(embed_size)
+        self.ff = nn.Sequential(
+            nn.Linear(embed_size, ff_hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(ff_hidden_dim, embed_size)
+        )
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        attn_output, _ = self.attention(x, x, x)
+        x = self.norm1(x + self.dropout(attn_output))
+        ff_output = self.ff(x)
+        x = self.norm2(x + self.dropout(ff_output))
+        return x
+
+class TransformerModel(nn.Module):
+    def __init__(self, vocab_size, embed_size=512, num_heads=8, hidden_dim=2048, num_layers=6, max_len=512, dropout=0.1):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_size)
+        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, embed_size))
+        self.transformer_blocks = nn.Sequential(
+            *[TransformerBlock(embed_size, num_heads, hidden_dim, dropout) for _ in range(num_layers)]
+        )
+        self.norm = nn.LayerNorm(embed_size)
+        self.output = nn.Linear(embed_size, vocab_size)
+
+    def forward(self, x):
+        seq_len = x.size(1)
+        positions = self.pos_embedding[:, :seq_len, :]
+        x = self.embedding(x) + positions
+        x = self.transformer_blocks(x)
+        x = self.norm(x)
+        return self.output(x)
multi_head_attention.py
ADDED
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+import math
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model, n_heads):
+        super(MultiHeadAttention, self).__init__()
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+
+        assert d_model % self.n_heads == 0
+        self.head_dim = d_model // n_heads
+
+        self.query = nn.Linear(d_model, d_model)
+        self.key = nn.Linear(d_model, d_model)
+        self.value = nn.Linear(d_model, d_model)
+
+        self.fc_out = nn.Linear(d_model, d_model)
+
+    def forward(self, query, key, value, mask=None):
+        N = query.shape[0]
+        Q = self.query(query)
+        K = self.key(key)
+        V = self.value(value)
+
+        Q = Q.view(N, -1, self.n_heads, self.head_dim).transpose(1, 2)
+        K = K.view(N, -1, self.n_heads, self.head_dim).transpose(1, 2)
+        V = V.view(N, -1, self.n_heads, self.head_dim).transpose(1, 2)
+
+        energy = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+        if mask is not None:
+            energy = energy.masked_fill(mask == 0, float('-1e20'))
+
+        attention = torch.softmax(energy, dim=-1)
+        out = torch.matmul(attention, V)
+
+        out = out.transpose(1, 2).contiguous().view(N, -1, self.n_heads * self.head_dim)
+        out = self.fc_out(out)
+
+        return out
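The forward method above accepts an optional mask, but nothing in this upload ever builds one, so attention is bidirectional. A minimal causal-mask sketch compatible with the masked_fill(mask == 0, ...) convention used here; the [1, 1, seq, seq] shape (broadcast over batch and heads) is an assumption:

import torch

def causal_mask(seq_len, device="cpu"):
    # 1s on and below the diagonal: position i may only attend to positions <= i.
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

# e.g. out = attn(x, x, x, mask=causal_mask(x.size(1), device=x.device))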
positional_encoding.py
ADDED
@@ -0,0 +1,19 @@
+import torch
+import math
+import torch.nn as nn
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+
+        self.register_buffer('pe', pe.unsqueeze(0))
+
+    def forward(self, x):
+        return x + self.pe[:, :x.size(1)]
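PositionalEncoding is defined here but not used by any model in this upload (model.py, train.py, and transformer_model.py all use a learned position table instead). A minimal usage sketch, assuming a 64-dimensional embedding as in chat.py:

import torch
from positional_encoding import PositionalEncoding

pe = PositionalEncoding(d_model=64, max_len=512)
x = torch.zeros(2, 32, 64)   # [batch, seq_len, d_model] token embeddings
x = pe(x)                    # adds sinusoidal encodings for the first 32 positions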
tiny_llama_model.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5761ad26bb663d61e9af274845b0122825705cecfcd9ac9aeb1140b100fca102
+size 28985127
tokenizer.py
ADDED
@@ -0,0 +1,29 @@
+import torch
+
+class SimpleTokenizer:
+    def __init__(self, vocab_path):
+        self.char_to_idx = torch.load(vocab_path)
+
+        # Add <unk> if not in vocab
+        if '<unk>' not in self.char_to_idx:
+            self.char_to_idx['<unk>'] = max(self.char_to_idx.values()) + 1
+
+        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}
+
+
+    def encode(self, text):
+        return [self.char_to_idx.get(c, self.char_to_idx.get('<unk>', 0)) for c in text]
+
+    def decode(self, indices):
+        return ''.join([self.idx_to_char.get(i, '') for i in indices])
+
+# Example usage
+vocab_path = 'vocab.pth'  # Replace with the actual path to your vocab file
+tokenizer = SimpleTokenizer(vocab_path)
+
+text = "Hello, world!"
+tokens = tokenizer.encode(text)  # Use the encode method here
+print(tokens)
+
+decoded_text = tokenizer.decode(tokens)
+print(decoded_text)
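vocab.pth is loaded by the tokenizer but never created anywhere in this upload. A minimal sketch of how a character-level vocab could be built from data.txt, assuming (as SimpleTokenizer implies) that the file stores a plain char-to-index dict:

import torch

# Build a char -> index mapping from the training text and save it for SimpleTokenizer.
with open('data.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))
char_to_idx = {c: i for i, c in enumerate(chars)}
char_to_idx.setdefault('<unk>', len(char_to_idx))  # reserve an id for unknown characters

torch.save(char_to_idx, 'vocab.pth')
print(f"Saved vocab with {len(char_to_idx)} entries")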
train.log
ADDED
@@ -0,0 +1 @@
+-sh: nohup: command not found
train.py
ADDED
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import psutil
+from tqdm import tqdm
+import time
+
+# Load your data
+def load_data(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        return f.read()
+
+# Tokenizer
+class SimpleTokenizer:
+    def __init__(self, vocab_path):
+        self.char_to_idx = torch.load(vocab_path)
+        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}
+
+    def encode(self, text):
+        return [self.char_to_idx.get(c, self.char_to_idx.get('<unk>', 0)) for c in text]
+
+    def decode(self, indices):
+        return ''.join([self.idx_to_char.get(i, '') for i in indices])
+
+# Model
+class TransformerModel(nn.Module):
+    def __init__(self, vocab_size, emb_size=256, num_heads=4, num_layers=4, ff_hid_dim=1024):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, emb_size)
+        self.pos_embedding = nn.Parameter(torch.zeros(1, 512, emb_size))
+        self.transformer_blocks = nn.ModuleList([
+            nn.TransformerEncoderLayer(d_model=emb_size, nhead=num_heads, dim_feedforward=ff_hid_dim, batch_first=True)  # batch_first matches the [batch, seq, emb] input below
+            for _ in range(num_layers)
+        ])
+        self.output = nn.Linear(emb_size, vocab_size)
+
+    def forward(self, x):
+        x = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]
+        for block in self.transformer_blocks:
+            x = block(x)
+        return self.output(x)
+
+# Batching
+def get_batches(data, batch_size, seq_length):
+    inputs, targets = [], []
+    for i in range(0, len(data) - seq_length - 1, seq_length):
+        x = data[i:i + seq_length]
+        y = data[i + 1:i + 1 + seq_length]
+        if len(x) == seq_length and len(y) == seq_length:
+            inputs.append(x)
+            targets.append(y)
+        if len(inputs) == batch_size:
+            yield (
+                torch.tensor(inputs, dtype=torch.long),
+                torch.tensor(targets, dtype=torch.long)
+            )
+            inputs, targets = [], []
+
+# Memory
+def show_memory():
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    return f"{mem_info.rss / 1024**2:.2f} MB"
+
+# Training
+def train():
+    vocab_size = 30000  # hard-coded; must be at least len(tokenizer.char_to_idx)
+    batch_size = 64
+    seq_length = 64
+    num_epochs = 3
+    lr = 0.001
+    vocab_path = 'vocab.pth'
+    data_path = 'data.txt'
+
+    text = load_data(data_path)
+    tokenizer = SimpleTokenizer(vocab_path)
+    tokens = tokenizer.encode(text)
+    model = TransformerModel(vocab_size)
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+    criterion = nn.CrossEntropyLoss()
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model.to(device)
+
+    for epoch in range(num_epochs):
+        batches = list(get_batches(tokens, batch_size, seq_length))
+        total = len(batches)
+        total_loss = 0
+        print(f"\n🧠 Epoch {epoch+1}/{num_epochs} — {total} batches")
+
+        with tqdm(total=total, desc="Training", bar_format="{l_bar}{bar} [ time left: {remaining} ]") as pbar:
+            for step, (x, y) in enumerate(batches):
+                x, y = x.to(device), y.to(device)
+                optimizer.zero_grad()
+                output = model(x)
+                loss = criterion(output.view(-1, vocab_size), y.view(-1))
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+                avg_loss = total_loss / (step + 1)
+
+                if step % 10 == 0:
+                    pbar.set_description(f"Loss: {loss.item():.4f} | RAM: {show_memory()}")
+                pbar.update(1)
+
+        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")
+        print(f"💾 Model saved: model_epoch_{epoch+1}.pth")
+
+if __name__ == "__main__":
+    train()
transformer_block.py
ADDED
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+from multi_head_attention import MultiHeadAttention  # Add this import
+from feedforward import FeedForward
+
+class TransformerBlock(nn.Module):
+    def __init__(self, d_model, n_heads, ff_dim):
+        super(TransformerBlock, self).__init__()
+
+        self.attention = MultiHeadAttention(d_model, n_heads)
+        self.ffn = FeedForward(d_model, ff_dim)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(0.1)
+        self.dropout2 = nn.Dropout(0.1)
+
+    def forward(self, x, mask=None):
+        # Multi-head attention
+        attn_out = self.attention(x, x, x, mask)
+        x = self.norm1(x + self.dropout1(attn_out))
+
+        # Feedforward network
+        ff_out = self.ffn(x)
+        x = self.norm2(x + self.dropout2(ff_out))
+
+        return x
transformer_model.py
ADDED
@@ -0,0 +1,37 @@
+import torch.nn as nn
+import torch
+
+class TransformerBlock(nn.Module):
+    def __init__(self, d_model, n_heads, ff_dim):
+        super().__init__()
+        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
+        self.ff = nn.Sequential(
+            nn.Linear(d_model, ff_dim),
+            nn.ReLU(),
+            nn.Linear(ff_dim, d_model),
+        )
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+
+    def forward(self, x):
+        attn_output, _ = self.attention(x, x, x)
+        x = self.norm1(x + attn_output)
+        x = self.norm2(x + self.ff(x))
+        return x
+
+class TransformerModel(nn.Module):
+    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, d_model)
+        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))
+        self.transformer_blocks = nn.ModuleList([
+            TransformerBlock(d_model, n_heads, ff_dim=4*d_model)
+            for _ in range(n_layers)
+        ])
+        self.output = nn.Linear(d_model, vocab_size)
+
+    def forward(self, x):
+        x = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]
+        for block in self.transformer_blocks:
+            x = block(x)
+        return self.output(x)
vocab.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17ec8f094e0274e43a931d387fbfb59caa5df051cbbe95d1f0b30584b1082d6a
+size 696472