Upload 9 files
- src/convert_to_hf.py +184 -0
- src/data_pre_to_raw.py +37 -0
- src/data_processing.py +86 -0
- src/hf_inference.py +101 -0
- src/inference.py +285 -0
- src/model.py +182 -0
- src/tokenization.py +138 -0
- src/train.py +404 -0
- src/utils.py +15 -0
src/convert_to_hf.py
ADDED
@@ -0,0 +1,184 @@

#!/usr/bin/env python
import os
import argparse
import yaml
import json
import torch
import shutil
import tiktoken
from model import create_model  # your model creation function from model.py

def load_config(config_path):
    """Load the training configuration from a YAML file."""
    with open(config_path, "r") as f:
        return yaml.safe_load(f)

def load_tokenizer_encoding(tokenizer_dir):
    """Reads the encoding name from encoding_config.txt in your tokenizer directory."""
    encoding_config_path = os.path.join(tokenizer_dir, "encoding_config.txt")
    if not os.path.exists(encoding_config_path):
        raise FileNotFoundError(f"Encoding config not found at {encoding_config_path}")

    with open(encoding_config_path, "r") as f:
        content = f.read().strip()
    # Expect a line like: "encoding_name: cl100k_base"
    if ":" not in content:
        raise ValueError(f"Invalid encoding config format: {content}")

    _, encoding_name = content.split(":", 1)
    return encoding_name.strip()

def get_tokenizer(encoding_name):
    """Initialize tiktoken encoding."""
    tokenizer = tiktoken.get_encoding(encoding_name)
    return tokenizer

def load_state_dict(checkpoint_dir):
    """
    Loads the model state dict from a DeepSpeed checkpoint.
    First tries to load a consolidated checkpoint, then attempts to convert from ZeRO format.
    """
    # First try loading from converted_model directory
    converted_path = os.path.join(checkpoint_dir, "converted_model", "pytorch_model.bin")
    if os.path.exists(converted_path):
        print(f"Loading converted checkpoint from {converted_path}")
        state_dict = torch.load(converted_path, map_location="cpu")

        # Remove "_orig_mod." prefix from keys if present
        if all(k.startswith("_orig_mod.") for k in state_dict.keys()):
            print("Removing '_orig_mod.' prefix from state dict keys")
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

        return state_dict

    # Try loading consolidated checkpoint from main directory
    consolidated_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
    if os.path.exists(consolidated_path):
        print(f"Loading consolidated checkpoint from {consolidated_path}")
        state_dict = torch.load(consolidated_path, map_location="cpu")

        # Remove "_orig_mod." prefix from keys if present
        if all(k.startswith("_orig_mod.") for k in state_dict.keys()):
            print("Removing '_orig_mod.' prefix from state dict keys")
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

        return state_dict

    # If no consolidated checkpoint exists, try converting from ZeRO format
    print("No consolidated checkpoint found. Converting from ZeRO format...")

    # Import the zero_to_fp32 module from the checkpoint directory
    import sys
    sys.path.append(checkpoint_dir)
    from zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    try:
        # Convert ZeRO checkpoint to consolidated checkpoint
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, exclude_frozen_parameters=False)

        if state_dict is None:
            raise ValueError("Failed to convert ZeRO checkpoint")

        # Remove "_orig_mod." prefix from keys if present
        if all(k.startswith("_orig_mod.") for k in state_dict.keys()):
            print("Removing '_orig_mod.' prefix from state dict keys")
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

        print("Successfully converted ZeRO checkpoint to consolidated format")
        return state_dict

    except Exception as e:
        print(f"Error converting ZeRO checkpoint: {str(e)}")
        raise

def convert_to_hf(checkpoint_dir, tokenizer_dir, config_path, output_dir):
    # Load configurations
    config = load_config(config_path)

    # Set up tokenizer
    encoding_name = load_tokenizer_encoding(tokenizer_dir)
    tokenizer = get_tokenizer(encoding_name)
    vocab_size = tokenizer.n_vocab
    print(f"Using tokenizer encoding: {encoding_name} (vocab size: {vocab_size})")

    # Update config with correct vocab size
    config["model"]["vocab_size"] = vocab_size

    # Create model and load weights
    model = create_model(config)
    state_dict = load_state_dict(checkpoint_dir)
    model.load_state_dict(state_dict)
    model.eval()

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # 1. Save model weights
    model_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(model.state_dict(), model_path)
    print(f"Saved model weights to {model_path}")

    # 2. Save model config
    model_config = {
        "architectures": ["CustomLanguageModel"],
        "model_type": "custom-gpt",
        "vocab_size": vocab_size,
        "n_positions": config["model"]["n_positions"],
        "n_embd": config["model"]["n_embd"],
        "n_layer": config["model"]["n_layer"],
        "n_head": config["model"]["n_head"],
        "bos_token_id": None,
        "eos_token_id": tokenizer.eot_token,
        "tie_word_embeddings": True,
        "gradient_checkpointing": config["model"].get("gradient_checkpointing", False)
    }

    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(model_config, f, indent=2)
    print(f"Saved model config to {config_path}")

    # 3. Save tokenizer config
    tokenizer_config = {
        "model_type": "tiktoken",
        "encoding_name": encoding_name,
        "vocab_size": vocab_size,
        "max_length": config["dataset"]["max_length"],
        "padding_side": "right",
        "truncation_side": "right",
        "bos_token": "<|endoftext|>",
        "eos_token": "<|endoftext|>",
        "unk_token": "<|endoftext|>",
        "pad_token": "<|endoftext|>"
    }

    tokenizer_config_path = os.path.join(output_dir, "tokenizer_config.json")
    with open(tokenizer_config_path, "w") as f:
        json.dump(tokenizer_config, f, indent=2)
    print(f"Saved tokenizer config to {tokenizer_config_path}")

    # 4. Copy tokenizer files
    src_encoding_config = os.path.join(tokenizer_dir, "encoding_config.txt")
    if os.path.exists(src_encoding_config):
        dst_encoding_config = os.path.join(output_dir, "encoding_config.txt")
        shutil.copy2(src_encoding_config, dst_encoding_config)
        print(f"Copied encoding config to {dst_encoding_config}")

    print(f"\nConversion complete! HuggingFace model saved to: {output_dir}")

def main():
    parser = argparse.ArgumentParser(description="Convert DeepSpeed checkpoint to HuggingFace format")
    parser.add_argument("--checkpoint_dir", required=True,
                        help="Path to the checkpoint directory")
    parser.add_argument("--tokenizer_dir", required=True,
                        help="Path to the tokenizer directory")
    parser.add_argument("--config", default="config/config.yaml",
                        help="Path to the training config.yaml file")
    parser.add_argument("--output_dir", required=True,
                        help="Output directory for HuggingFace model")

    args = parser.parse_args()
    convert_to_hf(args.checkpoint_dir, args.tokenizer_dir, args.config, args.output_dir)

if __name__ == "__main__":
    main()
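For reference, a minimal usage sketch of the converter above. The paths below are placeholders, not repository paths: they assume a DeepSpeed checkpoint directory, a tokenizer directory containing encoding_config.txt, and the training config.yaml. The equivalent CLI call is `python src/convert_to_hf.py --checkpoint_dir ... --tokenizer_dir ... --output_dir ...`.

# Hypothetical usage sketch for src/convert_to_hf.py; all paths are placeholders.
from convert_to_hf import convert_to_hf

convert_to_hf(
    checkpoint_dir="outputs/checkpoints/step_10000",  # DeepSpeed checkpoint dir (assumed)
    tokenizer_dir="outputs/tokenizer",                # must contain encoding_config.txt
    config_path="config/config.yaml",                 # same config used for training
    output_dir="outputs/hf_model",                    # where the HF-style files are written
)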
src/data_pre_to_raw.py
ADDED
@@ -0,0 +1,37 @@

import os
from tqdm import tqdm
from pathlib import Path

def convert_to_raw_text():
    # Setup paths
    processed_dir = Path("data/processed")
    raw_dir = Path("data/raw")

    # Create raw directory if it doesn't exist
    raw_dir.mkdir(parents=True, exist_ok=True)

    # Output file for combined raw text
    output_file = raw_dir / "combined_raw_text.txt"

    # Process all txt files in processed directory
    processed_files = list(processed_dir.glob("*.txt"))

    print(f"Found {len(processed_files)} files to process")

    with open(output_file, 'w', encoding='utf-8') as outfile:
        for proc_file in tqdm(processed_files, desc="Converting files"):
            with open(proc_file, 'r', encoding='utf-8') as infile:
                for line in infile:
                    # Skip metadata lines (starting with #)
                    if not line.startswith('#'):
                        # Only write non-empty lines
                        line = line.strip()
                        if line:
                            outfile.write(line + '\n')

if __name__ == "__main__":
    try:
        convert_to_raw_text()
        print("Successfully converted processed data to raw text")
    except Exception as e:
        print(f"Error during conversion: {e}")
src/data_processing.py
ADDED
@@ -0,0 +1,86 @@

from datasets import load_dataset
from tqdm import tqdm
import os
from utils import load_config, setup_logging
import psutil  # For monitoring memory usage


def download_and_process_data(config):
    """Downloads, preprocesses, and saves the dataset."""
    setup_logging()

    dataset_name = config["dataset"]["name"]
    streaming = config["dataset"]["streaming"]
    text_column = config["dataset"]["text_column"]
    target_size_gb = config["dataset"]["target_size_gb"]
    max_length = config["dataset"]["max_length"]
    subset = config["dataset"]["subset"]

    # Download dataset (streaming is essential for large datasets)
    try:
        dataset = load_dataset(dataset_name, subset, streaming=streaming)
        if not streaming:
            raise ValueError("Streaming must be True for large datasets like fineweb")
    except Exception as e:
        raise Exception(f"Failed to download dataset: {e}. Check the dataset name, internet connection, and HF login.")

    # Filter data - the subset filter is removed since it's specific to CC-MAIN
    dataset = dataset["train"]  # Take only the train split

    # Add basic quality filters
    def quality_filter(example):
        return (
            example['text'] is not None and
            len(example['text'].strip()) > 0 and
            example['language'] == 'en' and  # Filter for English content
            example['language_score'] >= 0.8  # High confidence in language detection
        )

    dataset = dataset.filter(quality_filter)

    # Create output directory if it doesn't exist
    output_dir = os.path.join("data", "processed")
    os.makedirs(output_dir, exist_ok=True)

    # Process and save in chunks, monitoring data size
    def process_and_save_chunk(chunk, chunk_num, total_bytes):
        output_file = os.path.join(output_dir, f"processed_data_{chunk_num}.txt")

        with open(output_file, "w", encoding="utf-8") as f:
            for example in tqdm(chunk, desc=f"Processing chunk {chunk_num}"):
                text = example[text_column].strip()
                if text:
                    # Add metadata as a comment before each text
                    metadata = f"# ID: {example['id']} | URL: {example['url']} | Date: {example['date']}\n"
                    f.write(metadata)
                    f.write(text + "\n\n")  # Add extra newline for separation
                    total_bytes += len(text.encode("utf-8")) + len(metadata.encode("utf-8"))
        return total_bytes

    chunk_num = 0
    chunk = []
    total_bytes_processed = 0
    target_bytes = target_size_gb * (1024**3)  # Convert GB to bytes

    for example in tqdm(dataset, desc="Processing and saving data"):
        chunk.append(example)
        if len(chunk) >= 10000:  # Adjust chunk size as needed
            total_bytes_processed = process_and_save_chunk(chunk, chunk_num, total_bytes_processed)
            chunk = []
            chunk_num += 1
            print(f"Processed: {total_bytes_processed / (1024**3):.2f} GB")

            if total_bytes_processed >= target_bytes:
                print("Target data size reached.")
                break  # Stop processing

    if chunk:
        total_bytes_processed = process_and_save_chunk(chunk, chunk_num, total_bytes_processed)  # Process the remaining data

    print(f"Data download and processing complete. Total processed size: {total_bytes_processed / (1024**3):.2f} GB")

if __name__ == "__main__":
    config = load_config()
    download_and_process_data(config)
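The script only reads the dataset section of the YAML config returned by utils.load_config. A minimal sketch of the dictionary shape it expects is shown below; the keys come from the code above, while every value (dataset id, subset, sizes) is an illustrative assumption rather than the repository's actual config.

# Illustrative shape of the config consumed by download_and_process_data.
# Only the keys are taken from the code above; all values are assumptions.
example_config = {
    "dataset": {
        "name": "HuggingFaceFW/fineweb",  # assumed dataset id (the error message hints at fineweb)
        "subset": "sample-10BT",          # assumed subset/config name
        "streaming": True,                # must be True; the function enforces it
        "text_column": "text",
        "target_size_gb": 10,             # stop once this much text has been written
        "max_length": 512,                # read here, used later by tokenization/training
    }
}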
src/hf_inference.py
ADDED
@@ -0,0 +1,101 @@

import torch
from transformers import PreTrainedModel, PretrainedConfig
from utils import load_config
from tokenization import get_tokenizer

class CustomConfig(PretrainedConfig):
    """Configuration class for the custom language model."""
    model_type = "custom_llm"

    def __init__(
        self,
        vocab_size: int = 50000,
        n_embd: int = 640,
        n_head: int = 10,
        n_layer: int = 12,
        n_positions: int = 512,
        tie_word_embeddings: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.n_positions = n_positions
        self.tie_word_embeddings = tie_word_embeddings
        super().__init__(**kwargs)

def generate_text(
    prompt: str,
    model_path: str = "outputs/hf_model",
    max_length: int = 200,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3
):
    """Generate text using the model."""
    # Load config and tokenizer
    config = load_config()
    tokenizer = get_tokenizer(config)

    # Load model
    from inference import CustomModelForCausalLM  # Import here to avoid circular imports
    model = CustomModelForCausalLM.from_pretrained(model_path)

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Encode prompt
    encoded = tokenizer.batch_encode(
        [prompt],
        return_tensors="pt"
    )
    input_ids = encoded["input_ids"].to(device)

    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size
        )

    # Decode and return
    generated_text = tokenizer.decode(output_ids[0].tolist())
    return generated_text

if __name__ == "__main__":
    # Example prompts to test
    prompts = [
        "Once upon a time",
        "The meaning of life is",
        "In the distant future",
        "The best way to learn programming is",
        "Today I learned that"
    ]

    print("\nGenerating text from multiple prompts:")
    print("=" * 50)

    for prompt in prompts:
        generated_text = generate_text(
            prompt=prompt,
            max_length=200,
            temperature=0.8,        # Adjust for creativity (higher = more creative)
            top_k=50,               # Limit to top 50 tokens
            top_p=0.9,              # Nucleus sampling threshold
            repetition_penalty=1.2, # Penalize repetition
            no_repeat_ngram_size=3  # Prevent 3-gram repetition
        )

        print(f"\nPrompt: {prompt}")
        print(f"Generated: {generated_text}")
        print("-" * 50)
src/inference.py
ADDED
@@ -0,0 +1,285 @@

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, PretrainedConfig
from typing import Optional, Tuple, Union, List
import os
import json
from model import CustomLanguageModel
from utils import load_config
from tokenization import get_tokenizer
import torch.nn as nn

class CustomConfig(PretrainedConfig):
    """Configuration class for the custom language model."""
    model_type = "custom_llm"

    def __init__(
        self,
        vocab_size: int = 50000,
        n_embd: int = 768,
        n_head: int = 12,
        n_layer: int = 12,
        n_positions: int = 2048,
        tie_word_embeddings: bool = False,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.n_positions = n_positions
        self.tie_word_embeddings = tie_word_embeddings
        super().__init__(**kwargs)

class CustomModelForCausalLM(PreTrainedModel):
    """Wrapper class to make the model compatible with Hugging Face's interface."""
    config_class = CustomConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        # Convert config to dictionary format expected by CustomLanguageModel
        model_config = {
            "model": {
                "vocab_size": config.vocab_size,
                "n_embd": config.n_embd,
                "n_head": config.n_head,
                "n_layer": config.n_layer,
                "n_positions": config.n_positions,
            }
        }
        self.transformer = CustomLanguageModel(model_config)

        # Explicitly create separate weights for lm_head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        outputs = self.transformer(input_ids=input_ids, labels=labels)
        return outputs

    def get_input_embeddings(self):
        return self.transformer.token_embedding

    def set_input_embeddings(self, value):
        self.transformer.token_embedding = value

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        no_repeat_ngram_size: int = 3,
        **kwargs
    ):
        """Enhanced generation method with better controls for repetition."""
        self.eval()
        current_ids = input_ids.clone()
        batch_size = current_ids.shape[0]

        # Get EOS token ID from tokenizer
        eos_token_id = self.transformer.eos_token_id if hasattr(self.transformer, 'eos_token_id') else None

        # Track generated tokens for repetition penalty
        generated_tokens = current_ids.clone()

        with torch.no_grad():
            for _ in range(max_length - input_ids.size(1)):
                # Forward pass
                outputs = self.transformer(current_ids)
                logits = outputs["logits"][:, -1, :] / temperature

                # Apply repetition penalty
                if repetition_penalty != 1.0:
                    for i in range(batch_size):
                        for token in set(generated_tokens[i].tolist()):
                            logits[i, token] /= repetition_penalty

                # Apply n-gram blocking
                if no_repeat_ngram_size > 0:
                    # Get the last n-gram from the input
                    for i in range(batch_size):
                        ngram_size = min(no_repeat_ngram_size, len(generated_tokens[i]))
                        if ngram_size > 0:
                            ngrams = [tuple(generated_tokens[i, -j:].tolist()) for j in range(1, ngram_size + 1)]
                            for ngram in ngrams:
                                for token_idx in range(len(generated_tokens[i]) - len(ngram) + 1):
                                    if tuple(generated_tokens[i, token_idx:token_idx + len(ngram)].tolist()) == ngram:
                                        if token_idx + len(ngram) < len(generated_tokens[i]):
                                            next_token = generated_tokens[i, token_idx + len(ngram)]
                                            logits[i, next_token] = float('-inf')

                # Apply top-k filtering
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = float('-inf')

                # Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float('-inf')

                # Sample from the filtered distribution
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Early stopping if EOS token is generated
                if eos_token_id is not None and (next_token == eos_token_id).any():
                    break

                # Update generated sequence
                current_ids = torch.cat([current_ids, next_token], dim=1)
                generated_tokens = torch.cat([generated_tokens, next_token], dim=1)

        return current_ids

def convert_to_hf_model(checkpoint_path: str, output_dir: str):
    """Convert the custom model checkpoint to Hugging Face format using safetensors."""
    # Load the original config and checkpoint
    config = load_config()

    # Get tokenizer and its vocab size
    tokenizer = get_tokenizer(config)
    vocab_size = tokenizer.get_vocab_size()

    # Create HF config with the correct vocab size
    hf_config = CustomConfig(
        vocab_size=vocab_size,
        n_embd=config["model"]["n_embd"],
        n_head=config["model"]["n_head"],
        n_layer=config["model"]["n_layer"],
        n_positions=config["model"]["n_positions"],
        tie_word_embeddings=False  # Explicitly disable weight tying
    )

    # Create HF model
    model = CustomModelForCausalLM(hf_config)

    # Load checkpoint
    checkpoint = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    # Process state dict
    new_state_dict = {}
    for key, value in checkpoint.items():
        if key.startswith("_orig_mod."):
            key = key[len("_orig_mod."):]

        if "token_embedding.weight" in key:
            new_state_dict[f"transformer.{key}"] = value
            # Copy embedding weights to lm_head
            new_state_dict["lm_head.weight"] = value.clone()
        else:
            new_state_dict[f"transformer.{key}"] = value

    # Load the modified state dict
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    print(f"Missing keys: {missing_keys}")
    print(f"Unexpected keys: {unexpected_keys}")

    # Save in Hugging Face format with safetensors
    os.makedirs(output_dir, exist_ok=True)

    # Save the model in safetensors format
    model.save_pretrained(
        output_dir,
        safe_serialization=True
    )
    print(f"Model successfully saved in safetensors format to {output_dir}")

    # Save config
    hf_config.save_pretrained(output_dir)

    # Copy tokenizer files
    tokenizer_files = ["vocab.json", "merges.txt", "tokenizer_config.json"]
    for file in tokenizer_files:
        src_path = os.path.join(config["tokenizer"]["model_path"], file)
        dst_path = os.path.join(output_dir, file)
        if os.path.exists(src_path):
            import shutil
            shutil.copy2(src_path, dst_path)

    return model, tokenizer

def generate_text(
    prompt: str,
    model_path: str,
    max_length: int = 100,
    temperature: float = 2,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3
):
    """Generate text using the converted model."""
    # Load model and tokenizer
    config = load_config()
    model = CustomModelForCausalLM.from_pretrained(model_path)
    tokenizer = get_tokenizer(config)

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Encode prompt
    encoded = tokenizer.batch_encode(
        [prompt],
        return_tensors="pt"
    )
    input_ids = encoded["input_ids"].to(device)

    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size
        )

    # Decode and return
    generated_text = tokenizer.decode(output_ids[0].tolist())
    return generated_text

if __name__ == "__main__":
    # Example usage
    checkpoint_path = r"my_model/"      # Path to your trained model
    hf_output_dir = "outputs/hf_model"  # Where to save the converted model

    # Convert model
    model, tokenizer = convert_to_hf_model(checkpoint_path, hf_output_dir)

    # Generate text with better parameters
    prompt = "Hello I am Clera "
    generated_text = generate_text(
        prompt=prompt,
        model_path=hf_output_dir,
        max_length=20,
        temperature=2.5,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=1
    )

    print(f"\nPrompt: {prompt}")
    print(f"Generated text: {generated_text}")
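As a standalone illustration of the sampling controls used inside CustomModelForCausalLM.generate, the sketch below applies the same top-k and top-p (nucleus) filtering to a dummy batch of next-token logits; the tensor values, vocabulary size, and thresholds are made up for the example.

import torch

# Dummy next-token logits (batch_size=1, vocab_size=10); values are arbitrary.
logits = torch.tensor([[2.0, 1.5, 0.3, -1.0, 0.9, 0.1, -0.5, 1.1, 0.0, -2.0]])
top_k, top_p = 4, 0.9

# Top-k: mask everything below the k-th largest logit.
kth_value = torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(logits < kth_value, float("-inf"))

# Top-p: mask the tail of the sorted distribution whose cumulative probability exceeds p.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_remove = cumulative_probs > top_p
sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()  # keep the first token that crosses the threshold
sorted_remove[..., 0] = False
remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
logits = logits.masked_fill(remove, float("-inf"))

# Sample one token from the filtered distribution.
next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
print(next_token)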
src/model.py
ADDED
@@ -0,0 +1,182 @@

import torch
import torch.nn as nn
from transformers import AutoTokenizer
from utils import load_config
from tokenizers import Tokenizer
import os
import json

class TransformerBlock(nn.Module):
    """Single transformer block with self-attention and feed-forward layers"""
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd)
        )
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Ensure mask is same dtype as input
        if mask is not None:
            mask = mask.to(dtype=x.dtype)
        # Self-attention with residual connection
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = x + self.dropout(attn_out)
        x = self.ln1(x)
        # Feed-forward with residual connection
        ff_out = self.feed_forward(x)
        x = x + self.dropout(ff_out)
        x = self.ln2(x)
        return x

class CustomLanguageModel(nn.Module):
    """Custom transformer-based language model"""
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config["model"]["vocab_size"]
        self.n_embd = config["model"]["n_embd"]
        self.n_head = config["model"]["n_head"]
        self.n_layer = config["model"]["n_layer"]
        self.n_positions = config["model"]["n_positions"]

        # Token and position embeddings
        self.token_embedding = nn.Embedding(self.vocab_size, self.n_embd)
        self.position_embedding = nn.Embedding(self.n_positions, self.n_embd)

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(self.n_embd, self.n_head)
            for _ in range(self.n_layer)
        ])

        # Output layer
        self.ln_f = nn.LayerNorm(self.n_embd)
        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=False)

        # Tie weights between token embedding and output layer
        self.token_embedding.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)

        # Set gradient checkpointing flag based on config
        self.gradient_checkpointing_enable = config["model"].get("gradient_checkpointing", False)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, labels=None):
        batch_size, seq_length = input_ids.shape

        # Create position indices
        positions = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings and sum token & position embeddings
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(positions)
        x = token_embeddings + position_embeddings

        # Create causal mask and convert to same dtype as embeddings
        mask = torch.triu(torch.ones((seq_length, seq_length), device=input_ids.device) * float('-inf'), diagonal=1)
        mask = mask.to(dtype=x.dtype)

        # Process through transformer blocks (use gradient checkpointing only if enabled)
        if self.training and self.gradient_checkpointing_enable:
            for block in self.transformer_blocks:
                x = torch.utils.checkpoint.checkpoint(block, x, mask, use_reentrant=False)
        else:
            for block in self.transformer_blocks:
                x = block(x, mask=mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.vocab_size), labels.view(-1))
            return {"loss": loss, "logits": logits}

        return {"logits": logits}

    def num_parameters(self):
        """Returns the number of trainable parameters in the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

def create_model(config):
    """Creates a custom language model from scratch based on the configuration."""
    model = CustomLanguageModel(config)
    return model

def get_tokenizer(config):
    """Loads a trained ByteLevelBPE tokenizer."""
    from tokenizers import ByteLevelBPETokenizer

    model_path = config["tokenizer"]["model_path"]
    if not os.path.exists(os.path.join(model_path, "vocab.json")):
        raise ValueError(f"No tokenizer found at {model_path}. Please train the tokenizer first.")

    tokenizer = ByteLevelBPETokenizer(
        os.path.join(model_path, "vocab.json"),
        os.path.join(model_path, "merges.txt")
    )

    # Add special tokens if they don't exist
    special_tokens = {
        "eos_token": "<|endoftext|>",
        "pad_token": "<|pad|>",
        "unk_token": "<|unk|>",
        "mask_token": "<|mask|>"
    }
    tokenizer.add_special_tokens(list(special_tokens.values()))

    # Add methods to match expected interface
    tokenizer.get_vocab_size = lambda: len(tokenizer.get_vocab())

    def batch_encode(texts, padding=True, truncation=True, max_length=None, return_tensors=None):
        encodings = tokenizer.encode_batch(texts)
        # Extract token ids from encodings
        token_ids = [enc.ids for enc in encodings]

        if max_length and truncation:
            token_ids = [ids[:max_length] for ids in token_ids]

        if padding:
            max_len = max(len(ids) for ids in token_ids)
            pad_token_id = tokenizer.token_to_id("<|pad|>")
            padded = []
            for ids in token_ids:
                pad_length = max_len - len(ids)
                padded.append(ids + [pad_token_id] * pad_length)
            token_ids = padded

        if return_tensors == "pt":
            return {
                "input_ids": torch.tensor(token_ids),
                "attention_mask": torch.ones_like(torch.tensor(token_ids))
            }
        return {"input_ids": token_ids}

    tokenizer.batch_encode = batch_encode

    print(f"ByteLevelBPE tokenizer loaded successfully. Vocab size: {tokenizer.get_vocab_size()}")
    return tokenizer

if __name__ == "__main__":
    config = load_config()
    tokenizer = get_tokenizer(config)
    config["model"]["vocab_size"] = tokenizer.get_vocab_size()
    model = create_model(config)
    print(f"Model created with {model.num_parameters():,} parameters.")
    print(model)
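A quick smoke-test sketch for the model class above, using a deliberately tiny configuration; every number below is arbitrary and chosen only so the forward pass runs quickly on CPU, real values come from config/config.yaml.

import torch
from model import CustomLanguageModel

# Tiny, made-up configuration for a smoke test; not the project's real settings.
tiny_config = {
    "model": {
        "vocab_size": 1000,
        "n_embd": 64,
        "n_head": 4,
        "n_layer": 2,
        "n_positions": 128,
    }
}

model = CustomLanguageModel(tiny_config)
input_ids = torch.randint(0, 1000, (2, 16))  # batch of 2 sequences, 16 tokens each
out = model(input_ids, labels=input_ids)     # labels reuse the inputs, matching how train.py builds them
print(out["logits"].shape)                   # expected: torch.Size([2, 16, 1000])
print(out["loss"].item())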
src/tokenization.py
ADDED
@@ -0,0 +1,138 @@

from tokenizers import Tokenizer, models, pre_tokenizers, trainers, processors
from tokenizers.implementations import ByteLevelBPETokenizer
import os
from utils import load_config, setup_logging
from glob import glob
from tqdm import tqdm
import json
import torch

class CustomTokenizer:
    """Wrapper around ByteLevelBPETokenizer with additional functionality."""
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self._vocab_size = len(tokenizer.get_vocab())
        self.pad_token_id = tokenizer.token_to_id("<|pad|>")
        self.eos_token_id = tokenizer.token_to_id("<|endoftext|>")

    def get_vocab_size(self):
        return self._vocab_size

    def batch_encode(self, texts, padding=True, truncation=True, max_length=None, return_tensors=None):
        encodings = self.tokenizer.encode_batch(texts)
        # Work with plain token-id lists so truncation and padding compose correctly
        token_ids = [enc.ids for enc in encodings]
        if max_length and truncation:
            token_ids = [ids[:max_length] for ids in token_ids]
        if padding:
            max_len = max(len(ids) for ids in token_ids)
            padded = []
            for ids in token_ids:
                pad_length = max_len - len(ids)
                padded.append(ids + [self.pad_token_id] * pad_length)
            token_ids = padded
        if return_tensors == "pt":
            input_ids = torch.tensor(token_ids)
            return {
                "input_ids": input_ids,
                "attention_mask": torch.ones_like(input_ids)
            }
        return {"input_ids": token_ids}

    def decode(self, token_ids):
        """Decode a list of token IDs back to a string."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        # Filter out padding tokens
        token_ids = [t for t in token_ids if t != self.pad_token_id]

        # Use the underlying tokenizer's decode method
        return self.tokenizer.decode(token_ids)

def train_tokenizer(config):
    """Trains a custom BPE tokenizer using the tokenizers library."""
    setup_logging()

    model_path = config["tokenizer"]["model_path"]
    vocab_size = config["tokenizer"].get("vocab_size", 50000)
    min_frequency = config["tokenizer"].get("min_frequency", 2)

    # Create output directory if it doesn't exist
    os.makedirs(model_path, exist_ok=True)

    # Initialize a new tokenizer
    tokenizer = ByteLevelBPETokenizer()

    # Get all text files from the data directory
    data_files = glob(os.path.join("data/raw", "*.txt"))
    if not data_files:
        raise ValueError("No text files found in data/raw directory")

    print(f"Training tokenizer on {len(data_files)} files...")
    print(f"Target vocab size: {vocab_size}")
    print(f"Min frequency: {min_frequency}")

    # Train the tokenizer
    tokenizer.train(
        files=data_files,
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=[
            "<|endoftext|>",  # End of text token
            "<|pad|>",        # Padding token
            "<|unk|>",        # Unknown token
            "<|mask|>"        # Mask token
        ]
    )

    # Save the tokenizer files
    tokenizer.save_model(model_path)

    # Save the tokenizer configuration
    tokenizer_config = {
        "vocab_size": vocab_size,
        "min_frequency": min_frequency,
        "model_type": "byte_level_bpe",
        "special_tokens": {
            "eos_token": "<|endoftext|>",
            "pad_token": "<|pad|>",
            "unk_token": "<|unk|>",
            "mask_token": "<|mask|>"
        }
    }

    with open(os.path.join(model_path, "tokenizer_config.json"), "w") as f:
        json.dump(tokenizer_config, f, indent=2)

    print(f"Tokenizer trained and saved to {model_path}")
    return tokenizer

def get_tokenizer(config):
    """Loads a trained tokenizer."""
    model_path = config["tokenizer"]["model_path"]

    if not os.path.exists(os.path.join(model_path, "vocab.json")):
        raise ValueError(f"No tokenizer found at {model_path}. Please train the tokenizer first.")

    base_tokenizer = ByteLevelBPETokenizer(
        os.path.join(model_path, "vocab.json"),
        os.path.join(model_path, "merges.txt")
    )

    # Add special tokens if they don't exist
    special_tokens = {
        "eos_token": "<|endoftext|>",
        "pad_token": "<|pad|>",
        "unk_token": "<|unk|>",
        "mask_token": "<|mask|>"
    }
    base_tokenizer.add_special_tokens(list(special_tokens.values()))

    # Create wrapped tokenizer
    tokenizer = CustomTokenizer(base_tokenizer)

    print(f"ByteLevelBPE tokenizer loaded successfully. Vocab size: {tokenizer.get_vocab_size()}")
    return tokenizer

if __name__ == "__main__":
    config = load_config()
    train_tokenizer(config)
    print("Tokenizer training complete.")
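A short round-trip sketch for the wrapper above. It assumes a tokenizer has already been trained and saved at config["tokenizer"]["model_path"] (otherwise get_tokenizer raises); the sample sentences and max_length are arbitrary.

from utils import load_config
from tokenization import get_tokenizer

# Assumes train_tokenizer(config) has already been run, so vocab.json/merges.txt exist.
config = load_config()
tokenizer = get_tokenizer(config)

batch = tokenizer.batch_encode(
    ["hello world", "a slightly longer example sentence"],
    max_length=16,
    return_tensors="pt",
)
print(batch["input_ids"].shape)                 # (2, padded_length)
print(tokenizer.decode(batch["input_ids"][0]))  # padding tokens are stripped on decode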
src/train.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.optim import AdamW
|
| 8 |
+
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from accelerate import Accelerator, DeepSpeedPlugin
|
| 11 |
+
from accelerate.logging import get_logger
|
| 12 |
+
import deepspeed
|
| 13 |
+
import wandb
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
from transformers import get_scheduler
|
| 16 |
+
from model import create_model, get_tokenizer
|
| 17 |
+
from utils import load_config, setup_logging
|
| 18 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 19 |
+
|
| 20 |
+
logger = get_logger(__name__)
|
| 21 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 22 |
+
# Enable TF32 for faster matrix multiplications (if supported)
|
| 23 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 24 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 25 |
+
|
| 26 |
+
def load_text_files(data_dir, chunk_size=2000000):
|
| 27 |
+
"""Load text files from directory in chunks."""
|
| 28 |
+
if not os.path.exists(data_dir):
|
| 29 |
+
raise ValueError(f"Data directory {data_dir} does not exist")
|
| 30 |
+
|
| 31 |
+
all_files = [f for f in os.listdir(data_dir) if f.endswith('.txt')]
|
| 32 |
+
print(f"Found {len(all_files)} text files in {data_dir}")
|
| 33 |
+
|
| 34 |
+
total_size = sum(os.path.getsize(os.path.join(data_dir, f)) for f in all_files)
|
| 35 |
+
estimated_chunks = math.ceil(total_size / chunk_size)
|
| 36 |
+
total_characters = 0
|
| 37 |
+
current_chunk_num = 0
|
| 38 |
+
|
| 39 |
+
for file_name in all_files:
|
| 40 |
+
file_path = os.path.join(data_dir, file_name)
|
| 41 |
+
try:
|
| 42 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 43 |
+
file_size = os.path.getsize(file_path)
|
| 44 |
+
print(f"Processing file: {file_name} (Size: {file_size/1024/1024:.2f}MB)")
|
| 45 |
+
print(f"Estimated total chunks: {estimated_chunks}")
|
| 46 |
+
|
| 47 |
+
current_chunk = []
|
| 48 |
+
current_size = 0
|
| 49 |
+
chunk_start_char = total_characters
|
| 50 |
+
|
| 51 |
+
for line in f:
|
| 52 |
+
line = line.strip()
|
| 53 |
+
if line:
|
| 54 |
+
current_chunk.append(line)
|
| 55 |
+
current_size += len(line)
|
| 56 |
+
total_characters += len(line)
|
| 57 |
+
|
| 58 |
+
if current_size >= chunk_size:
|
| 59 |
+
current_chunk_num += 1
|
| 60 |
+
print(f"Yielding chunk {current_chunk_num}/{estimated_chunks} "
|
| 61 |
+
f"({len(current_chunk)} texts, {current_size:,} characters, "
|
| 62 |
+
f"Range: {chunk_start_char:,} - {total_characters:,})")
|
| 63 |
+
yield current_chunk
|
| 64 |
+
current_chunk = []
|
| 65 |
+
current_size = 0
|
| 66 |
+
chunk_start_char = total_characters
|
| 67 |
+
|
| 68 |
+
if current_chunk:
|
| 69 |
+
current_chunk_num += 1
|
| 70 |
+
print(f"Yielding final chunk {current_chunk_num}/{estimated_chunks} "
|
| 71 |
+
f"({len(current_chunk)} texts, {current_size:,} characters, "
|
| 72 |
+
f"Range: {chunk_start_char:,} - {total_characters:,})")
|
| 73 |
+
yield current_chunk
|
| 74 |
+
except Exception as e:
|
| 75 |
+
print(f"Error reading file {file_path}: {e}")
|
| 76 |
+
continue
|
| 77 |
+
|
| 78 |
+
class TextDataset(Dataset):
|
| 79 |
+
def __init__(self, tokenized_texts):
|
| 80 |
+
self.input_ids = tokenized_texts["input_ids"]
|
| 81 |
+
self.labels = tokenized_texts["labels"]
|
| 82 |
+
|
| 83 |
+
def __len__(self):
|
| 84 |
+
return len(self.input_ids)
|
| 85 |
+
|
| 86 |
+
def __getitem__(self, idx):
|
| 87 |
+
return {"input_ids": self.input_ids[idx], "labels": self.labels[idx]}

class StreamingTextDataset(IterableDataset):
    def __init__(self, data_dir, tokenizer, max_length):
        super().__init__()
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.files = [f for f in os.listdir(data_dir) if f.endswith('.txt')]

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        files_per_worker = len(self.files)
        if worker_info is not None:
            files_per_worker = len(self.files) // worker_info.num_workers
            start_idx = worker_info.id * files_per_worker
            end_idx = start_idx + files_per_worker if worker_info.id < worker_info.num_workers - 1 else len(self.files)
            files = self.files[start_idx:end_idx]
        else:
            files = self.files

        for file_name in files:
            file_path = os.path.join(self.data_dir, file_name)
            with open(file_path, 'r', encoding='utf-8') as f:
                text_buffer = []
                current_length = 0

                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    text_buffer.append(line)
                    current_length += len(line)

                    if current_length >= self.max_length:
                        # Encode and yield the batch
                        text = " ".join(text_buffer)
                        encodings = self.tokenizer.batch_encode(
                            [text],
                            max_length=self.max_length,
                            truncation=True,
                            padding=False,  # Don't pad here, we'll pad in collate_fn
                            return_tensors="pt"
                        )

                        # Return individual tensors
                        yield {
                            "input_ids": encodings["input_ids"][0],
                            "labels": encodings["input_ids"][0].clone()
                        }
                        text_buffer = []
                        current_length = 0
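
# Note: current_length above counts characters, not tokens; the max_length threshold
# is only a rough buffering heuristic, and batch_encode then truncates the joined
# text to max_length tokens.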

def collate_batch(batch):
    """Custom collate function to handle variable length sequences."""
    # Separate input_ids and labels
    input_ids = [item["input_ids"] for item in batch]
    labels = [item["labels"] for item in batch]

    # Pad sequences
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)  # -100 is PyTorch's default ignore index

    # Create attention masks
    attention_mask = (input_ids != 0).long()

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask
    }
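
# Illustrative example (not part of the training flow): collating two sequences of
# lengths 3 and 5 pads both to length 5, so
#   input_ids      -> [[t1, t2, t3, 0, 0],       [u1, u2, u3, u4, u5]]
#   labels         -> [[t1, t2, t3, -100, -100], [u1, u2, u3, u4, u5]]
#   attention_mask -> [[1, 1, 1, 0, 0],          [1, 1, 1, 1, 1]]
# Padded label positions are set to -100 so they are ignored by the loss.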

def train_model(config):
    """Trains the model using DeepSpeed and Accelerate for memory efficiency."""
    # Create output directory
    output_dir = config["training"]["output_dir"]
    os.makedirs(output_dir, exist_ok=True)
    print(f"Model will be saved to: {output_dir}")

    # Initialize DeepSpeed plugin and accelerator
    deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=config["training"]["deepspeed"])
    accelerator = Accelerator(
        gradient_accumulation_steps=config["training"]["gradient_accumulation_steps"],
        mixed_precision="fp16",
        deepspeed_plugin=deepspeed_plugin,
        log_with=config["training"]["report_to"]
    )
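
    # Assumption (config.yaml is not part of this upload): config["training"]["deepspeed"]
    # is expected to be either a dict of DeepSpeed settings or a path to a DeepSpeed JSON
    # config, both of which Accelerate's DeepSpeedPlugin accepts via hf_ds_config.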

    # Initialize tracking
    if accelerator.is_main_process:
        accelerator.init_trackers(
            project_name=config["training"]["wandb"]["project"],
            config=config,
            init_kwargs={
                "wandb": {
                    "entity": config["training"]["wandb"]["entity"],
                    "name": config["training"]["wandb"]["name"],
                }
            }
        )
        print(f"Tracking initialized with {config['training']['report_to']}")

    device = accelerator.device
    print(f"Using device: {device}")

    # Load tokenizer and model
    tokenizer = get_tokenizer(config)
    config["model"]["vocab_size"] = tokenizer.get_vocab_size()
    model = create_model(config)

    try:
        model = torch.compile(model)
        print("torch.compile enabled for faster training.")
    except Exception as e:
        print(f"torch.compile not available or failed ({e}), continuing without it.")

    optimizer = AdamW(
        model.parameters(),
        lr=config["training"]["learning_rate"],
        weight_decay=config["training"]["weight_decay"]
    )

    # Create streaming dataset with custom collate function
    dataset = StreamingTextDataset(
        data_dir="data/raw",
        tokenizer=tokenizer,
        max_length=config["dataset"]["max_length"]
    )

    train_loader = DataLoader(
        dataset,
        batch_size=config["training"]["per_device_train_batch_size"],
        num_workers=config["training"]["dataloader_num_workers"],
        pin_memory=True,
        collate_fn=collate_batch  # Add custom collate function
    )
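
    # Because StreamingTextDataset is an IterableDataset, no sampler/shuffle is used here;
    # each DataLoader worker reads its own slice of the .txt files (see __iter__ above).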

    # Prepare for distributed training
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    # Calculate approximate steps per epoch based on target dataset size
    avg_seq_length = config["dataset"]["max_length"] // 2  # Average sequence length
    batch_size = config["training"]["per_device_train_batch_size"]
    target_size_gb = config["dataset"].get("target_size_gb", 2.5)
    chars_per_token = 4
    total_tokens = (target_size_gb * 1024 * 1024 * 1024) // chars_per_token
    steps_per_epoch = int(total_tokens // (avg_seq_length * batch_size))  # Convert to int
    total_epochs = config["training"]["num_train_epochs"]
    total_steps = int(steps_per_epoch * total_epochs)  # Convert to int

    print(f"\nTraining Statistics (Estimated):")
    print(f"Total epochs: {total_epochs}")
    print(f"Estimated steps per epoch: {steps_per_epoch:,}")
    print(f"Estimated total steps: {total_steps:,}")

    # Track gradients for logging
    def grad_norm(model):
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5
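
    # grad_norm returns the global L2 norm over all parameter gradients (square root of
    # the sum of squared per-parameter norms); it is used purely for logging, while the
    # clipping itself is done via accelerator.clip_grad_norm_ below.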

    # Initialize GPU monitoring
    if torch.cuda.is_available():
        gpu_id = torch.cuda.current_device()

    training_stats = {
        'train/loss': 0.0,
        'train/learning_rate': 0.0,
        'train/epoch': 0.0,
        'train/global_step': 0,
        'train/samples_per_second': 0.0,
        'train/grad_norm': 0.0,
        'performance/gpu_memory': 0.0,
        'performance/gpu_utilization': 0.0,
        'performance/batch_time': 0.0,
    }

    for epoch in range(total_epochs):
        epoch_start_time = time.time()
        model.train()
        running_loss = 0
        num_batches = 0
        samples_processed = 0

        progress_bar = tqdm(
            total=steps_per_epoch,
            desc=f"Epoch {epoch+1}/{total_epochs}",
            disable=not accelerator.is_local_main_process
        )

        for batch in train_loader:
            batch_start_time = time.time()

            with accelerator.accumulate(model):
                outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
                loss = outputs["loss"]
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    training_stats['train/grad_norm'] = grad_norm(model)
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                optimizer.zero_grad()
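
                # Note: under accelerator.accumulate(), gradients are only synchronized and
                # the optimizer update applied once every gradient_accumulation_steps
                # micro-batches; Accelerate takes care of skipping the intermediate steps.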

            # Update statistics
            loss_value = loss.item()
            running_loss += loss_value
            num_batches += 1
            samples_processed += batch["input_ids"].size(0)
            batch_time = time.time() - batch_start_time

            # Update training stats
            training_stats.update({
                'train/loss': loss_value,
                'train/learning_rate': optimizer.param_groups[0]['lr'],
                'train/epoch': epoch + 1,
                'train/global_step': num_batches + (epoch * steps_per_epoch),
                'train/samples_per_second': batch["input_ids"].size(0) / batch_time,
                'performance/batch_time': batch_time,
            })

            # GPU stats (if available)
            if torch.cuda.is_available():
                training_stats.update({
                    'performance/gpu_memory': torch.cuda.memory_allocated(gpu_id) / 1024**3,  # GB
                    'performance/gpu_utilization': torch.cuda.utilization(gpu_id),
                })

            # Update progress bar
            avg_speed = num_batches / (time.time() - epoch_start_time)
            eta_epoch = (steps_per_epoch - num_batches) / avg_speed / 60  # minutes
            eta_total = (total_steps - (epoch * steps_per_epoch + num_batches)) / avg_speed / 60  # minutes

            progress_bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'avg_loss': f'{running_loss/num_batches:.4f}',
                'lr': f'{optimizer.param_groups[0]["lr"]:.2e}',
                'samples/s': f'{training_stats["train/samples_per_second"]:.2f}',
                'epoch_eta': f'{eta_epoch:.1f}min',
                'total_eta': f'{eta_total:.1f}min'
            })
            progress_bar.update(1)

            # Log metrics based on logging_steps
            if num_batches % config["training"]["logging_steps"] == 0:
                if accelerator.is_main_process:
                    current_step = int(num_batches + (epoch * steps_per_epoch))  # Convert to int
                    accelerator.log(training_stats, step=current_step)

            # Save checkpoint based on save_steps
            if num_batches % config["training"]["save_steps"] == 0:
                if accelerator.is_local_main_process:
                    checkpoint_dir = os.path.join(output_dir, f"checkpoint-epoch{epoch+1}-step{num_batches}")
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    print(f"\nSaving checkpoint at step {num_batches} to {checkpoint_dir}")
                    accelerator.save_state(checkpoint_dir)
                    with open(os.path.join(checkpoint_dir, "config.json"), "w") as f:
                        json.dump(config, f, indent=2)

            # Break if we've reached the estimated steps for this epoch
            if num_batches >= steps_per_epoch:
                break

        progress_bar.close()

        # End of epoch logging
        epoch_time = time.time() - epoch_start_time
        epoch_avg_loss = running_loss / num_batches
        epoch_perplexity = torch.exp(torch.tensor(epoch_avg_loss))

        if accelerator.is_main_process:
            print(f"\nEpoch {epoch+1}/{total_epochs} Summary:")
            print(f"Time: {epoch_time/60:.2f} minutes")
            print(f"Average Loss: {epoch_avg_loss:.4f}")
            print(f"Perplexity: {epoch_perplexity:.2f}")
            print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
            print(f"Samples Processed: {samples_processed:,}")
            print(f"Average Speed: {samples_processed/epoch_time:.1f} samples/s")

            # Estimate remaining time
            epochs_remaining = total_epochs - (epoch + 1)
            estimated_remaining_time = epochs_remaining * epoch_time / 60
            print(f"Estimated time for remaining {epochs_remaining} epochs: {estimated_remaining_time:.1f} minutes")

            # Log epoch summary to wandb with correct step
            current_step = int((epoch + 1) * steps_per_epoch)  # Convert to int
            accelerator.log({
                'epoch/average_loss': epoch_avg_loss,
                'epoch/perplexity': epoch_perplexity.item(),
                'epoch/time': epoch_time,
                'epoch/samples_processed': samples_processed,
            }, step=current_step)

    # Save final model
    if accelerator.is_local_main_process:
        final_model_dir = os.path.join(output_dir, "final_model")
        os.makedirs(final_model_dir, exist_ok=True)
        print(f"\nSaving final model to {final_model_dir}")

        # Save with DeepSpeed
        accelerator.save_state(final_model_dir)

        # Save configuration
        with open(os.path.join(final_model_dir, "config.json"), "w") as f:
            json.dump(config, f, indent=2)

        print("Final model saved successfully")
    accelerator.end_training()

if __name__ == "__main__":
    config = load_config()
    train_model(config)
    print("Training complete.")
src/utils.py
ADDED
@@ -0,0 +1,15 @@
import yaml
from pathlib import Path

def load_config(config_path="config/config.yaml"):
    """Loads the configuration from a YAML file."""
    with open(config_path, "r") as f:
        return yaml.safe_load(f)

def setup_logging(log_level="INFO"):
    """Sets up basic logging."""
    import logging
    logging.basicConfig(
        level=getattr(logging, log_level.upper(), logging.INFO),
        format="%(asctime)s - %(levelname)s - %(message)s"
    )
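
# Example usage (illustrative):
#   from utils import load_config, setup_logging
#   setup_logging("DEBUG")
#   config = load_config()  # reads config/config.yaml by default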