diabolic6045 committed
Commit eb05668 · verified · 1 Parent(s): 4b0778b

Upload 9 files
src/convert_to_hf.py ADDED
@@ -0,0 +1,184 @@
+ #!/usr/bin/env python
+ import os
+ import argparse
+ import yaml
+ import json
+ import torch
+ import shutil
+ import tiktoken
+ from model import create_model  # your model creation function from model.py
+
+ def load_config(config_path):
+     """Load the training configuration from a YAML file."""
+     with open(config_path, "r") as f:
+         return yaml.safe_load(f)
+
+ def load_tokenizer_encoding(tokenizer_dir):
+     """Reads the encoding name from encoding_config.txt in your tokenizer directory."""
+     encoding_config_path = os.path.join(tokenizer_dir, "encoding_config.txt")
+     if not os.path.exists(encoding_config_path):
+         raise FileNotFoundError(f"Encoding config not found at {encoding_config_path}")
+
+     with open(encoding_config_path, "r") as f:
+         content = f.read().strip()
+     # Expect a line like: "encoding_name: cl100k_base"
+     if ":" not in content:
+         raise ValueError(f"Invalid encoding config format: {content}")
+
+     _, encoding_name = content.split(":", 1)
+     return encoding_name.strip()
+
+ def get_tokenizer(encoding_name):
+     """Initialize tiktoken encoding."""
+     tokenizer = tiktoken.get_encoding(encoding_name)
+     return tokenizer
+
+ def load_state_dict(checkpoint_dir):
+     """
+     Loads the model state dict from a DeepSpeed checkpoint.
+     First tries to load a consolidated checkpoint, then attempts to convert from ZeRO format.
+     """
+     # First try loading from converted_model directory
+     converted_path = os.path.join(checkpoint_dir, "converted_model", "pytorch_model.bin")
+     if os.path.exists(converted_path):
+         print(f"Loading converted checkpoint from {converted_path}")
+         state_dict = torch.load(converted_path, map_location="cpu")
+
+         # Remove "_orig_mod." prefix from keys if present
+         if all(k.startswith("_orig_mod.") for k in state_dict.keys()):
+             print("Removing '_orig_mod.' prefix from state dict keys")
+             state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+
+         return state_dict
+
+     # Try loading consolidated checkpoint from main directory
+     consolidated_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
+     if os.path.exists(consolidated_path):
+         print(f"Loading consolidated checkpoint from {consolidated_path}")
+         state_dict = torch.load(consolidated_path, map_location="cpu")
+
+         # Remove "_orig_mod." prefix from keys if present
+         if all(k.startswith("_orig_mod.") for k in state_dict.keys()):
+             print("Removing '_orig_mod.' prefix from state dict keys")
+             state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+
+         return state_dict
+
+     # If no consolidated checkpoint exists, try converting from ZeRO format
+     print("No consolidated checkpoint found. Converting from ZeRO format...")
+
+     # Import the zero_to_fp32 module from the checkpoint directory
+     import sys
+     sys.path.append(checkpoint_dir)
+     from zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+
+     try:
+         # Convert ZeRO checkpoint to consolidated checkpoint
+         state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, exclude_frozen_parameters=False)
+
+         if state_dict is None:
+             raise ValueError("Failed to convert ZeRO checkpoint")
+
+         # Remove "_orig_mod." prefix from keys if present
+         if all(k.startswith("_orig_mod.") for k in state_dict.keys()):
+             print("Removing '_orig_mod.' prefix from state dict keys")
+             state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+
+         print("Successfully converted ZeRO checkpoint to consolidated format")
+         return state_dict
+
+     except Exception as e:
+         print(f"Error converting ZeRO checkpoint: {str(e)}")
+         raise
+
+ def convert_to_hf(checkpoint_dir, tokenizer_dir, config_path, output_dir):
+     # Load configurations
+     config = load_config(config_path)
+
+     # Set up tokenizer
+     encoding_name = load_tokenizer_encoding(tokenizer_dir)
+     tokenizer = get_tokenizer(encoding_name)
+     vocab_size = tokenizer.n_vocab
+     print(f"Using tokenizer encoding: {encoding_name} (vocab size: {vocab_size})")
+
+     # Update config with correct vocab size
+     config["model"]["vocab_size"] = vocab_size
+
+     # Create model and load weights
+     model = create_model(config)
+     state_dict = load_state_dict(checkpoint_dir)
+     model.load_state_dict(state_dict)
+     model.eval()
+
+     # Create output directory
+     os.makedirs(output_dir, exist_ok=True)
+
+     # 1. Save model weights
+     model_path = os.path.join(output_dir, "pytorch_model.bin")
+     torch.save(model.state_dict(), model_path)
+     print(f"Saved model weights to {model_path}")
+
+     # 2. Save model config
+     model_config = {
+         "architectures": ["CustomLanguageModel"],
+         "model_type": "custom-gpt",
+         "vocab_size": vocab_size,
+         "n_positions": config["model"]["n_positions"],
+         "n_embd": config["model"]["n_embd"],
+         "n_layer": config["model"]["n_layer"],
+         "n_head": config["model"]["n_head"],
+         "bos_token_id": None,
+         "eos_token_id": tokenizer.eot_token,
+         "tie_word_embeddings": True,
+         "gradient_checkpointing": config["model"].get("gradient_checkpointing", False)
+     }
+
+     hf_config_path = os.path.join(output_dir, "config.json")  # separate name to avoid shadowing the input config_path
+     with open(hf_config_path, "w") as f:
+         json.dump(model_config, f, indent=2)
+     print(f"Saved model config to {hf_config_path}")
+
+     # 3. Save tokenizer config
+     tokenizer_config = {
+         "model_type": "tiktoken",
+         "encoding_name": encoding_name,
+         "vocab_size": vocab_size,
+         "max_length": config["dataset"]["max_length"],
+         "padding_side": "right",
+         "truncation_side": "right",
+         "bos_token": "<|endoftext|>",
+         "eos_token": "<|endoftext|>",
+         "unk_token": "<|endoftext|>",
+         "pad_token": "<|endoftext|>"
+     }
+
+     tokenizer_config_path = os.path.join(output_dir, "tokenizer_config.json")
+     with open(tokenizer_config_path, "w") as f:
+         json.dump(tokenizer_config, f, indent=2)
+     print(f"Saved tokenizer config to {tokenizer_config_path}")
+
+     # 4. Copy tokenizer files
+     src_encoding_config = os.path.join(tokenizer_dir, "encoding_config.txt")
+     if os.path.exists(src_encoding_config):
+         dst_encoding_config = os.path.join(output_dir, "encoding_config.txt")
+         shutil.copy2(src_encoding_config, dst_encoding_config)
+         print(f"Copied encoding config to {dst_encoding_config}")
+
+     print(f"\nConversion complete! HuggingFace model saved to: {output_dir}")
+
+ def main():
+     parser = argparse.ArgumentParser(description="Convert DeepSpeed checkpoint to HuggingFace format")
+     parser.add_argument("--checkpoint_dir", required=True,
+                         help="Path to the checkpoint directory")
+     parser.add_argument("--tokenizer_dir", required=True,
+                         help="Path to the tokenizer directory")
+     parser.add_argument("--config", default="config/config.yaml",
+                         help="Path to the training config.yaml file")
+     parser.add_argument("--output_dir", required=True,
+                         help="Output directory for HuggingFace model")
+
+     args = parser.parse_args()
+     convert_to_hf(args.checkpoint_dir, args.tokenizer_dir, args.config, args.output_dir)
+
+ if __name__ == "__main__":
+     main()
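A minimal usage sketch for the converter above (not part of this commit); all paths are placeholders pointing at your own checkpoint, tiktoken tokenizer directory, and output location:

from convert_to_hf import convert_to_hf

# Illustrative paths only; the tokenizer_dir must contain encoding_config.txt.
convert_to_hf(
    checkpoint_dir="outputs/checkpoint-epoch1-step5000",
    tokenizer_dir="tokenizer/",
    config_path="config/config.yaml",
    output_dir="outputs/hf_model",
)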
src/data_pre_to_raw.py ADDED
@@ -0,0 +1,37 @@
+ import os
+ from tqdm import tqdm
+ from pathlib import Path
+
+ def convert_to_raw_text():
+     # Setup paths
+     processed_dir = Path("data/processed")
+     raw_dir = Path("data/raw")
+
+     # Create raw directory if it doesn't exist
+     raw_dir.mkdir(parents=True, exist_ok=True)
+
+     # Output file for combined raw text
+     output_file = raw_dir / "combined_raw_text.txt"
+
+     # Process all txt files in processed directory
+     processed_files = list(processed_dir.glob("*.txt"))
+
+     print(f"Found {len(processed_files)} files to process")
+
+     with open(output_file, 'w', encoding='utf-8') as outfile:
+         for proc_file in tqdm(processed_files, desc="Converting files"):
+             with open(proc_file, 'r', encoding='utf-8') as infile:
+                 for line in infile:
+                     # Skip metadata lines (starting with #)
+                     if not line.startswith('#'):
+                         # Only write non-empty lines
+                         line = line.strip()
+                         if line:
+                             outfile.write(line + '\n')
+
+ if __name__ == "__main__":
+     try:
+         convert_to_raw_text()
+         print("Successfully converted processed data to raw text")
+     except Exception as e:
+         print(f"Error during conversion: {e}")
src/data_processing.py ADDED
@@ -0,0 +1,86 @@
+ from datasets import load_dataset
+ from tqdm import tqdm
+ import os
+ from utils import load_config, setup_logging
+ import psutil  # For monitoring memory usage
+
+
+ def download_and_process_data(config):
+     """Downloads, preprocesses, and saves the dataset."""
+     setup_logging()
+
+     dataset_name = config["dataset"]["name"]
+     streaming = config["dataset"]["streaming"]
+     text_column = config["dataset"]["text_column"]
+     target_size_gb = config["dataset"]["target_size_gb"]
+     max_length = config["dataset"]["max_length"]
+     subset = config["dataset"]["subset"]
+
+     # Download dataset (streaming is essential for large datasets)
+     try:
+         dataset = load_dataset(dataset_name, subset, streaming=streaming)
+         if not streaming:
+             raise ValueError("Streaming must be True for large datasets like fineweb")
+     except Exception as e:
+         raise RuntimeError(f"Failed to download dataset: {e}. Check the dataset name, your internet connection, and HF login.") from e
+
+     # Filter data - removing the subset filter since it's specific to CC-MAIN
+     dataset = dataset["train"]  # Taking only train split
+
+     # Add basic quality filters
+     def quality_filter(example):
+         return (
+             example['text'] is not None and
+             len(example['text'].strip()) > 0 and
+             example['language'] == 'en' and  # Filter for English content
+             example['language_score'] >= 0.8  # High confidence in language detection
+         )
+
+     dataset = dataset.filter(quality_filter)
+
+     # Create output directory if it doesn't exist
+     output_dir = os.path.join("data", "processed")
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Process and save in chunks, monitoring data size
+     def process_and_save_chunk(chunk, chunk_num, total_bytes):
+         output_file = os.path.join(output_dir, f"processed_data_{chunk_num}.txt")
+
+         with open(output_file, "w", encoding="utf-8") as f:
+             for example in tqdm(chunk, desc=f"Processing chunk {chunk_num}"):
+                 text = example[text_column].strip()
+                 if text:
+                     # Add metadata as a comment before each text
+                     metadata = f"# ID: {example['id']} | URL: {example['url']} | Date: {example['date']}\n"
+                     f.write(metadata)
+                     f.write(text + "\n\n")  # Add extra newline for separation
+                     total_bytes += len(text.encode("utf-8")) + len(metadata.encode("utf-8"))
+         return total_bytes
+
+     chunk_num = 0
+     chunk = []
+     total_bytes_processed = 0
+     target_bytes = target_size_gb * (1024**3)  # Convert GB to bytes
+
+     for example in tqdm(dataset, desc="Processing and saving data"):
+         chunk.append(example)
+         if len(chunk) >= 10000:  # Adjust chunk size as needed
+             total_bytes_processed = process_and_save_chunk(chunk, chunk_num, total_bytes_processed)
+             chunk = []
+             chunk_num += 1
+             print(f"Processed: {total_bytes_processed / (1024**3):.2f} GB")
+
+             if total_bytes_processed >= target_bytes:
+                 print("Target data size reached.")
+                 break  # Stop processing
+
+     if chunk:
+         total_bytes_processed = process_and_save_chunk(chunk, chunk_num, total_bytes_processed)  # for remaining data
+
+     print(f"Data download and processing complete. Total processed size: {total_bytes_processed / (1024**3):.2f} GB")
+
+ if __name__ == "__main__":
+     config = load_config()
+     download_and_process_data(config)
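For reference, download_and_process_data reads these keys from the `dataset` section of config.yaml. The actual config file is not part of this commit, so the sketch below shows the dict load_config would return, with illustrative values only (the dataset id and subset are assumptions based on the fineweb mention in the script):

from data_processing import download_and_process_data

# Placeholder values; adjust to your own run.
config = {
    "dataset": {
        "name": "HuggingFaceFW/fineweb",  # assumed dataset id
        "subset": "sample-10BT",          # illustrative subset name
        "streaming": True,                # must be True per the check in the script
        "text_column": "text",
        "target_size_gb": 2.5,
        "max_length": 512,
    }
}
download_and_process_data(config)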
src/hf_inference.py ADDED
@@ -0,0 +1,101 @@
+ import torch
+ from transformers import PreTrainedModel, PretrainedConfig
+ from utils import load_config
+ from tokenization import get_tokenizer
+
+ class CustomConfig(PretrainedConfig):
+     """Configuration class for the custom language model."""
+     model_type = "custom_llm"
+
+     def __init__(
+         self,
+         vocab_size: int = 50000,
+         n_embd: int = 640,
+         n_head: int = 10,
+         n_layer: int = 12,
+         n_positions: int = 512,
+         tie_word_embeddings: bool = True,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.n_embd = n_embd
+         self.n_head = n_head
+         self.n_layer = n_layer
+         self.n_positions = n_positions
+         self.tie_word_embeddings = tie_word_embeddings
+         super().__init__(**kwargs)
+
+ def generate_text(
+     prompt: str,
+     model_path: str = "outputs/hf_model",
+     max_length: int = 200,
+     temperature: float = 0.8,
+     top_k: int = 50,
+     top_p: float = 0.9,
+     repetition_penalty: float = 1.2,
+     no_repeat_ngram_size: int = 3
+ ):
+     """Generate text using the model."""
+     # Load config and tokenizer
+     config = load_config()
+     tokenizer = get_tokenizer(config)
+
+     # Load model
+     from inference import CustomModelForCausalLM  # Import here to avoid circular imports
+     model = CustomModelForCausalLM.from_pretrained(model_path)
+
+     # Move model to GPU if available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)
+     model.eval()
+
+     # Encode prompt
+     encoded = tokenizer.batch_encode(
+         [prompt],
+         return_tensors="pt"
+     )
+     input_ids = encoded["input_ids"].to(device)
+
+     # Generate
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids=input_ids,
+             max_length=max_length,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             no_repeat_ngram_size=no_repeat_ngram_size
+         )
+
+     # Decode and return
+     generated_text = tokenizer.decode(output_ids[0].tolist())
+     return generated_text
+
+ if __name__ == "__main__":
+     # Example prompts to test
+     prompts = [
+         "Once upon a time",
+         "The meaning of life is",
+         "In the distant future",
+         "The best way to learn programming is",
+         "Today I learned that"
+     ]
+
+     print("\nGenerating text from multiple prompts:")
+     print("=" * 50)
+
+     for prompt in prompts:
+         generated_text = generate_text(
+             prompt=prompt,
+             max_length=200,
+             temperature=0.8,  # Adjust for creativity (higher = more creative)
+             top_k=50,  # Limit to top 50 tokens
+             top_p=0.9,  # Nucleus sampling threshold
+             repetition_penalty=1.2,  # Penalize repetition
+             no_repeat_ngram_size=3  # Prevent 3-gram repetition
+         )
+
+         print(f"\nPrompt: {prompt}")
+         print(f"Generated: {generated_text}")
+         print("-" * 50)
src/inference.py ADDED
@@ -0,0 +1,285 @@
+ import torch
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PretrainedConfig
+ from typing import Optional, Tuple, Union, List
+ import os
+ import json
+ from model import CustomLanguageModel
+ from utils import load_config
+ from tokenization import get_tokenizer
+ import torch.nn as nn
+
+ class CustomConfig(PretrainedConfig):
+     """Configuration class for the custom language model."""
+     model_type = "custom_llm"
+
+     def __init__(
+         self,
+         vocab_size: int = 50000,
+         n_embd: int = 768,
+         n_head: int = 12,
+         n_layer: int = 12,
+         n_positions: int = 2048,
+         tie_word_embeddings: bool = False,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.n_embd = n_embd
+         self.n_head = n_head
+         self.n_layer = n_layer
+         self.n_positions = n_positions
+         self.tie_word_embeddings = tie_word_embeddings
+         super().__init__(**kwargs)
+
+ class CustomModelForCausalLM(PreTrainedModel):
+     """Wrapper class to make the model compatible with Hugging Face's interface."""
+     config_class = CustomConfig
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         # Convert config to dictionary format expected by CustomLanguageModel
+         model_config = {
+             "model": {
+                 "vocab_size": config.vocab_size,
+                 "n_embd": config.n_embd,
+                 "n_head": config.n_head,
+                 "n_layer": config.n_layer,
+                 "n_positions": config.n_positions,
+             }
+         }
+         self.transformer = CustomLanguageModel(model_config)
+
+         # Explicitly create separate weights for lm_head
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # Initialize weights
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         **kwargs
+     ):
+         outputs = self.transformer(input_ids=input_ids, labels=labels)
+         return outputs
+
+     def get_input_embeddings(self):
+         return self.transformer.token_embedding
+
+     def set_input_embeddings(self, value):
+         self.transformer.token_embedding = value
+
+     def generate(
+         self,
+         input_ids: torch.LongTensor,
+         max_length: int = 100,
+         temperature: float = 1.0,
+         top_k: int = 50,
+         top_p: float = 0.9,
+         repetition_penalty: float = 1.2,
+         no_repeat_ngram_size: int = 3,
+         **kwargs
+     ):
+         """Enhanced generation method with better controls for repetition."""
+         self.eval()
+         current_ids = input_ids.clone()
+         batch_size = current_ids.shape[0]
+
+         # Get EOS token ID from tokenizer
+         eos_token_id = self.transformer.eos_token_id if hasattr(self.transformer, 'eos_token_id') else None
+
+         # Track generated tokens for repetition penalty
+         generated_tokens = current_ids.clone()
+
+         with torch.no_grad():
+             for _ in range(max_length - input_ids.size(1)):
+                 # Forward pass
+                 outputs = self.transformer(current_ids)
+                 logits = outputs["logits"][:, -1, :] / temperature
+
+                 # Apply repetition penalty
+                 if repetition_penalty != 1.0:
+                     for i in range(batch_size):
+                         for token in set(generated_tokens[i].tolist()):
+                             logits[i, token] /= repetition_penalty
+
+                 # Apply n-gram blocking
+                 if no_repeat_ngram_size > 0:
+                     # Get the last n-gram from the input
+                     for i in range(batch_size):
+                         ngram_size = min(no_repeat_ngram_size, len(generated_tokens[i]))
+                         if ngram_size > 0:
+                             ngrams = [tuple(generated_tokens[i, -j:].tolist()) for j in range(1, ngram_size + 1)]
+                             for ngram in ngrams:
+                                 for token_idx in range(len(generated_tokens[i]) - len(ngram) + 1):
+                                     if tuple(generated_tokens[i, token_idx:token_idx + len(ngram)].tolist()) == ngram:
+                                         if token_idx + len(ngram) < len(generated_tokens[i]):
+                                             next_token = generated_tokens[i, token_idx + len(ngram)]
+                                             logits[i, next_token] = float('-inf')
+
+                 # Apply top-k filtering
+                 if top_k > 0:
+                     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                     logits[indices_to_remove] = float('-inf')
+
+                 # Apply top-p (nucleus) filtering
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                     cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                     sorted_indices_to_remove[..., 0] = 0
+                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                     logits[indices_to_remove] = float('-inf')
+
+                 # Sample from the filtered distribution
+                 probs = torch.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+
+                 # Early stopping if EOS token is generated
+                 if eos_token_id is not None and (next_token == eos_token_id).any():
+                     break
+
+                 # Update generated sequence
+                 current_ids = torch.cat([current_ids, next_token], dim=1)
+                 generated_tokens = torch.cat([generated_tokens, next_token], dim=1)
+
+         return current_ids
+
+ def convert_to_hf_model(checkpoint_path: str, output_dir: str):
+     """Convert the custom model checkpoint to Hugging Face format using safetensors."""
+     # Load the original config and checkpoint
+     config = load_config()
+
+     # Get tokenizer and its vocab size
+     tokenizer = get_tokenizer(config)
+     vocab_size = tokenizer.get_vocab_size()
+
+     # Create HF config with the correct vocab size
+     hf_config = CustomConfig(
+         vocab_size=vocab_size,
+         n_embd=config["model"]["n_embd"],
+         n_head=config["model"]["n_head"],
+         n_layer=config["model"]["n_layer"],
+         n_positions=config["model"]["n_positions"],
+         tie_word_embeddings=False  # Explicitly disable weight tying
+     )
+
+     # Create HF model
+     model = CustomModelForCausalLM(hf_config)
+
+     # Load checkpoint
+     checkpoint = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"), map_location="cpu")
+
+     # Process state dict
+     new_state_dict = {}
+     for key, value in checkpoint.items():
+         if key.startswith("_orig_mod."):
+             key = key[len("_orig_mod."):]
+
+         if "token_embedding.weight" in key:
+             new_state_dict[f"transformer.{key}"] = value
+             # Copy embedding weights to lm_head
+             new_state_dict["lm_head.weight"] = value.clone()
+         else:
+             new_state_dict[f"transformer.{key}"] = value
+
+     # Load the modified state dict
+     missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
+     print(f"Missing keys: {missing_keys}")
+     print(f"Unexpected keys: {unexpected_keys}")
+
+     # Save in Hugging Face format with safetensors
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Save the model in safetensors format
+     model.save_pretrained(
+         output_dir,
+         safe_serialization=True
+     )
+     print(f"Model successfully saved in safetensors format to {output_dir}")
+
+     # Save config
+     hf_config.save_pretrained(output_dir)
+
+     # Copy tokenizer files
+     tokenizer_files = ["vocab.json", "merges.txt", "tokenizer_config.json"]
+     for file in tokenizer_files:
+         src_path = os.path.join(config["tokenizer"]["model_path"], file)
+         dst_path = os.path.join(output_dir, file)
+         if os.path.exists(src_path):
+             import shutil
+             shutil.copy2(src_path, dst_path)
+
+     return model, tokenizer
+
+ def generate_text(
+     prompt: str,
+     model_path: str,
+     max_length: int = 100,
+     temperature: float = 2,
+     top_k: int = 50,
+     top_p: float = 0.9,
+     repetition_penalty: float = 1.2,
+     no_repeat_ngram_size: int = 3
+ ):
+     """Generate text using the converted model."""
+     # Load model and tokenizer
+     config = load_config()
+     model = CustomModelForCausalLM.from_pretrained(model_path)
+     tokenizer = get_tokenizer(config)
+
+     # Move model to GPU if available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)
+     model.eval()
+
+     # Encode prompt
+     encoded = tokenizer.batch_encode(
+         [prompt],
+         return_tensors="pt"
+     )
+     input_ids = encoded["input_ids"].to(device)
+
+     # Generate
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids=input_ids,
+             max_length=max_length,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             no_repeat_ngram_size=no_repeat_ngram_size
+         )
+
+     # Decode and return
+     generated_text = tokenizer.decode(output_ids[0].tolist())
+     return generated_text
+
+ if __name__ == "__main__":
+     # Example usage
+     checkpoint_path = r"my_model/"  # Path to your trained model
+     hf_output_dir = "outputs/hf_model"  # Where to save the converted model
+
+     # Convert model
+     model, tokenizer = convert_to_hf_model(checkpoint_path, hf_output_dir)
+
+     # Generate text with better parameters
+     prompt = "Hello I am Clera "
+     generated_text = generate_text(
+         prompt=prompt,
+         model_path=hf_output_dir,
+         max_length=20,
+         temperature=2.5,
+         top_k=50,
+         top_p=0.9,
+         repetition_penalty=1.2,
+         no_repeat_ngram_size=1
+     )
+
+     print(f"\nPrompt: {prompt}")
+     print(f"Generated text: {generated_text}")
src/model.py ADDED
@@ -0,0 +1,182 @@
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint  # explicit import so torch.utils.checkpoint.checkpoint is always available
+ from transformers import AutoTokenizer
+ from utils import load_config
+ from tokenizers import Tokenizer
+ import os
+ import json
+
+ class TransformerBlock(nn.Module):
+     """Single transformer block with self-attention and feed-forward layers"""
+     def __init__(self, n_embd, n_head, dropout=0.1):
+         super().__init__()
+         self.attention = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
+         self.feed_forward = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.GELU(),
+             nn.Linear(4 * n_embd, n_embd)
+         )
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, mask=None):
+         # Ensure mask is same dtype as input
+         if mask is not None:
+             mask = mask.to(dtype=x.dtype)
+         # Self-attention with residual connection
+         attn_out, _ = self.attention(x, x, x, attn_mask=mask)
+         x = x + self.dropout(attn_out)
+         x = self.ln1(x)
+         # Feed-forward with residual connection
+         ff_out = self.feed_forward(x)
+         x = x + self.dropout(ff_out)
+         x = self.ln2(x)
+         return x
+
+ class CustomLanguageModel(nn.Module):
+     """Custom transformer-based language model"""
+     def __init__(self, config):
+         super().__init__()
+         self.vocab_size = config["model"]["vocab_size"]
+         self.n_embd = config["model"]["n_embd"]
+         self.n_head = config["model"]["n_head"]
+         self.n_layer = config["model"]["n_layer"]
+         self.n_positions = config["model"]["n_positions"]
+
+         # Token and position embeddings
+         self.token_embedding = nn.Embedding(self.vocab_size, self.n_embd)
+         self.position_embedding = nn.Embedding(self.n_positions, self.n_embd)
+
+         # Transformer blocks
+         self.transformer_blocks = nn.ModuleList([
+             TransformerBlock(self.n_embd, self.n_head)
+             for _ in range(self.n_layer)
+         ])
+
+         # Output layer
+         self.ln_f = nn.LayerNorm(self.n_embd)
+         self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=False)
+
+         # Tie weights between token embedding and output layer
+         self.token_embedding.weight = self.lm_head.weight
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+         # Set gradient checkpointing flag based on config
+         self.gradient_checkpointing_enable = config["model"].get("gradient_checkpointing", False)
+
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if isinstance(module, nn.Linear) and module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             torch.nn.init.zeros_(module.bias)
+             torch.nn.init.ones_(module.weight)
+
+     def forward(self, input_ids, labels=None):
+         batch_size, seq_length = input_ids.shape
+
+         # Create position indices
+         positions = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
+         positions = positions.unsqueeze(0).expand(batch_size, -1)
+
+         # Get embeddings and sum token & position embeddings
+         token_embeddings = self.token_embedding(input_ids)
+         position_embeddings = self.position_embedding(positions)
+         x = token_embeddings + position_embeddings
+
+         # Create causal mask and convert to same dtype as embeddings
+         mask = torch.triu(torch.ones((seq_length, seq_length), device=input_ids.device) * float('-inf'), diagonal=1)
+         mask = mask.to(dtype=x.dtype)
+
+         # Process through transformer blocks (use gradient checkpointing only if enabled)
+         if self.training and self.gradient_checkpointing_enable:
+             for block in self.transformer_blocks:
+                 x = torch.utils.checkpoint.checkpoint(block, x, mask, use_reentrant=False)
+         else:
+             for block in self.transformer_blocks:
+                 x = block(x, mask=mask)
+
+         x = self.ln_f(x)
+         logits = self.lm_head(x)
+
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.vocab_size), labels.view(-1))
+             return {"loss": loss, "logits": logits}
+
+         return {"logits": logits}
+
+     def num_parameters(self):
+         """Returns the number of trainable parameters in the model."""
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+ def create_model(config):
+     """Creates a custom language model from scratch based on the configuration."""
+     model = CustomLanguageModel(config)
+     return model
+
+ def get_tokenizer(config):
+     """Loads a trained ByteLevelBPE tokenizer."""
+     from tokenizers import ByteLevelBPETokenizer
+
+     model_path = config["tokenizer"]["model_path"]
+     if not os.path.exists(os.path.join(model_path, "vocab.json")):
+         raise ValueError(f"No tokenizer found at {model_path}. Please train the tokenizer first.")
+
+     tokenizer = ByteLevelBPETokenizer(
+         os.path.join(model_path, "vocab.json"),
+         os.path.join(model_path, "merges.txt")
+     )
+
+     # Add special tokens if they don't exist
+     special_tokens = {
+         "eos_token": "<|endoftext|>",
+         "pad_token": "<|pad|>",
+         "unk_token": "<|unk|>",
+         "mask_token": "<|mask|>"
+     }
+     tokenizer.add_special_tokens(list(special_tokens.values()))
+
+     # Add methods to match expected interface
+     tokenizer.get_vocab_size = lambda: len(tokenizer.get_vocab())
+
+     def batch_encode(texts, padding=True, truncation=True, max_length=None, return_tensors=None):
+         encodings = tokenizer.encode_batch(texts)
+         # Extract token ids from encodings
+         token_ids = [enc.ids for enc in encodings]
+
+         if max_length and truncation:
+             token_ids = [ids[:max_length] for ids in token_ids]
+
+         if padding:
+             max_len = max(len(ids) for ids in token_ids)
+             pad_token_id = tokenizer.token_to_id("<|pad|>")
+             padded = []
+             for ids in token_ids:
+                 pad_length = max_len - len(ids)
+                 padded.append(ids + [pad_token_id] * pad_length)
+             token_ids = padded
+
+         if return_tensors == "pt":
+             return {
+                 "input_ids": torch.tensor(token_ids),
+                 "attention_mask": torch.ones_like(torch.tensor(token_ids))
+             }
+         return {"input_ids": token_ids}
+
+     tokenizer.batch_encode = batch_encode
+
+     print(f"ByteLevelBPE tokenizer loaded successfully. Vocab size: {tokenizer.get_vocab_size()}")
+     return tokenizer
+
+ if __name__ == "__main__":
+     config = load_config()
+     tokenizer = get_tokenizer(config)
+     config["model"]["vocab_size"] = tokenizer.get_vocab_size()
+     model = create_model(config)
+     print(f"Model created with {model.num_parameters():,} parameters.")
+     print(model)
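A quick shape check for CustomLanguageModel, as a sketch: the hyperparameters below are illustrative stand-ins for the real config.yaml values.

import torch
from model import create_model

# Illustrative hyperparameters; the real values come from config/config.yaml.
config = {"model": {"vocab_size": 50000, "n_embd": 640, "n_head": 10,
                    "n_layer": 12, "n_positions": 512}}
model = create_model(config)

input_ids = torch.randint(0, config["model"]["vocab_size"], (2, 16))  # (batch, seq_len)
out = model(input_ids, labels=input_ids)
print(out["logits"].shape)  # torch.Size([2, 16, 50000])
print(out["loss"].item())   # scalar cross-entropy loss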
src/tokenization.py ADDED
@@ -0,0 +1,138 @@
+ from tokenizers import Tokenizer, models, pre_tokenizers, trainers, processors
+ from tokenizers.implementations import ByteLevelBPETokenizer
+ import os
+ from utils import load_config, setup_logging
+ from glob import glob
+ from tqdm import tqdm
+ import json
+ import torch
+
+ class CustomTokenizer:
+     """Wrapper around ByteLevelBPETokenizer with additional functionality."""
+     def __init__(self, tokenizer):
+         self.tokenizer = tokenizer
+         self._vocab_size = len(tokenizer.get_vocab())
+         self.pad_token_id = tokenizer.token_to_id("<|pad|>")
+         self.eos_token_id = tokenizer.token_to_id("<|endoftext|>")
+
+     def get_vocab_size(self):
+         return self._vocab_size
+
+     def batch_encode(self, texts, padding=True, truncation=True, max_length=None, return_tensors=None):
+         # Work on plain id lists so truncation and padding compose correctly
+         token_ids = [enc.ids for enc in self.tokenizer.encode_batch(texts)]
+         if max_length and truncation:
+             token_ids = [ids[:max_length] for ids in token_ids]
+         if padding:
+             max_len = max(len(ids) for ids in token_ids)
+             padded = []
+             for ids in token_ids:
+                 pad_length = max_len - len(ids)
+                 padded.append(ids + [self.pad_token_id] * pad_length)
+             token_ids = padded
+         if return_tensors == "pt":
+             input_ids = torch.tensor(token_ids)
+             return {
+                 "input_ids": input_ids,
+                 "attention_mask": torch.ones_like(input_ids)
+             }
+         return {"input_ids": token_ids}
+
+     def decode(self, token_ids):
+         """Decode a list of token IDs back to a string."""
+         if isinstance(token_ids, torch.Tensor):
+             token_ids = token_ids.tolist()
+
+         # Filter out padding tokens
+         token_ids = [t for t in token_ids if t != self.pad_token_id]
+
+         # Use the underlying tokenizer's decode method
+         return self.tokenizer.decode(token_ids)
+
+ def train_tokenizer(config):
+     """Trains a custom BPE tokenizer using the tokenizers library."""
+     setup_logging()
+
+     model_path = config["tokenizer"]["model_path"]
+     vocab_size = config["tokenizer"].get("vocab_size", 50000)
+     min_frequency = config["tokenizer"].get("min_frequency", 2)
+
+     # Create output directory if it doesn't exist
+     os.makedirs(model_path, exist_ok=True)
+
+     # Initialize a new tokenizer
+     tokenizer = ByteLevelBPETokenizer()
+
+     # Get all text files from the data directory
+     data_files = glob(os.path.join("data/raw", "*.txt"))
+     if not data_files:
+         raise ValueError("No text files found in data/raw directory")
+
+     print(f"Training tokenizer on {len(data_files)} files...")
+     print(f"Target vocab size: {vocab_size}")
+     print(f"Min frequency: {min_frequency}")
+
+     # Train the tokenizer
+     tokenizer.train(
+         files=data_files,
+         vocab_size=vocab_size,
+         min_frequency=min_frequency,
+         special_tokens=[
+             "<|endoftext|>",  # End of text token
+             "<|pad|>",  # Padding token
+             "<|unk|>",  # Unknown token
+             "<|mask|>"  # Mask token
+         ]
+     )
+
+     # Save the tokenizer files
+     tokenizer.save_model(model_path)
+
+     # Save the tokenizer configuration
+     tokenizer_config = {
+         "vocab_size": vocab_size,
+         "min_frequency": min_frequency,
+         "model_type": "byte_level_bpe",
+         "special_tokens": {
+             "eos_token": "<|endoftext|>",
+             "pad_token": "<|pad|>",
+             "unk_token": "<|unk|>",
+             "mask_token": "<|mask|>"
+         }
+     }
+
+     with open(os.path.join(model_path, "tokenizer_config.json"), "w") as f:
+         json.dump(tokenizer_config, f, indent=2)
+
+     print(f"Tokenizer trained and saved to {model_path}")
+     return tokenizer
+
+ def get_tokenizer(config):
+     """Loads a trained tokenizer."""
+     model_path = config["tokenizer"]["model_path"]
+
+     if not os.path.exists(os.path.join(model_path, "vocab.json")):
+         raise ValueError(f"No tokenizer found at {model_path}. Please train the tokenizer first.")
+
+     base_tokenizer = ByteLevelBPETokenizer(
+         os.path.join(model_path, "vocab.json"),
+         os.path.join(model_path, "merges.txt")
+     )
+
+     # Add special tokens if they don't exist
+     special_tokens = {
+         "eos_token": "<|endoftext|>",
+         "pad_token": "<|pad|>",
+         "unk_token": "<|unk|>",
+         "mask_token": "<|mask|>"
+     }
+     base_tokenizer.add_special_tokens(list(special_tokens.values()))
+
+     # Create wrapped tokenizer
+     tokenizer = CustomTokenizer(base_tokenizer)
+
+     print(f"ByteLevelBPE tokenizer loaded successfully. Vocab size: {tokenizer.get_vocab_size()}")
+     return tokenizer
+
+ if __name__ == "__main__":
+     config = load_config()
+     train_tokenizer(config)
+     print("Tokenizer training complete.")
src/train.py ADDED
@@ -0,0 +1,404 @@
+ import os
+ import math
+ import time
+ import json
+ import torch
+ import torch.nn as nn
+ from torch.optim import AdamW
+ from torch.utils.data import DataLoader, Dataset, IterableDataset
+ from tqdm import tqdm
+ from accelerate import Accelerator, DeepSpeedPlugin
+ from accelerate.logging import get_logger
+ import deepspeed
+ import wandb
+ from datetime import datetime
+ from transformers import get_scheduler
+ from model import create_model, get_tokenizer
+ from utils import load_config, setup_logging
+ from torch.nn.utils.rnn import pad_sequence
+
+ logger = get_logger(__name__)
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ # Enable TF32 for faster matrix multiplications (if supported)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ def load_text_files(data_dir, chunk_size=2000000):
+     """Load text files from directory in chunks."""
+     if not os.path.exists(data_dir):
+         raise ValueError(f"Data directory {data_dir} does not exist")
+
+     all_files = [f for f in os.listdir(data_dir) if f.endswith('.txt')]
+     print(f"Found {len(all_files)} text files in {data_dir}")
+
+     total_size = sum(os.path.getsize(os.path.join(data_dir, f)) for f in all_files)
+     estimated_chunks = math.ceil(total_size / chunk_size)
+     total_characters = 0
+     current_chunk_num = 0
+
+     for file_name in all_files:
+         file_path = os.path.join(data_dir, file_name)
+         try:
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 file_size = os.path.getsize(file_path)
+                 print(f"Processing file: {file_name} (Size: {file_size/1024/1024:.2f}MB)")
+                 print(f"Estimated total chunks: {estimated_chunks}")
+
+                 current_chunk = []
+                 current_size = 0
+                 chunk_start_char = total_characters
+
+                 for line in f:
+                     line = line.strip()
+                     if line:
+                         current_chunk.append(line)
+                         current_size += len(line)
+                         total_characters += len(line)
+
+                         if current_size >= chunk_size:
+                             current_chunk_num += 1
+                             print(f"Yielding chunk {current_chunk_num}/{estimated_chunks} "
+                                   f"({len(current_chunk)} texts, {current_size:,} characters, "
+                                   f"Range: {chunk_start_char:,} - {total_characters:,})")
+                             yield current_chunk
+                             current_chunk = []
+                             current_size = 0
+                             chunk_start_char = total_characters
+
+                 if current_chunk:
+                     current_chunk_num += 1
+                     print(f"Yielding final chunk {current_chunk_num}/{estimated_chunks} "
+                           f"({len(current_chunk)} texts, {current_size:,} characters, "
+                           f"Range: {chunk_start_char:,} - {total_characters:,})")
+                     yield current_chunk
+         except Exception as e:
+             print(f"Error reading file {file_path}: {e}")
+             continue
+
+ class TextDataset(Dataset):
+     def __init__(self, tokenized_texts):
+         self.input_ids = tokenized_texts["input_ids"]
+         self.labels = tokenized_texts["labels"]
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, idx):
+         return {"input_ids": self.input_ids[idx], "labels": self.labels[idx]}
+
+ class StreamingTextDataset(IterableDataset):
+     def __init__(self, data_dir, tokenizer, max_length):
+         super().__init__()
+         self.data_dir = data_dir
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.files = [f for f in os.listdir(data_dir) if f.endswith('.txt')]
+
+     def __iter__(self):
+         worker_info = torch.utils.data.get_worker_info()
+         files_per_worker = len(self.files)
+         if worker_info is not None:
+             files_per_worker = len(self.files) // worker_info.num_workers
+             start_idx = worker_info.id * files_per_worker
+             end_idx = start_idx + files_per_worker if worker_info.id < worker_info.num_workers - 1 else len(self.files)
+             files = self.files[start_idx:end_idx]
+         else:
+             files = self.files
+
+         for file_name in files:
+             file_path = os.path.join(self.data_dir, file_name)
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 text_buffer = []
+                 current_length = 0
+
+                 for line in f:
+                     line = line.strip()
+                     if not line:
+                         continue
+
+                     text_buffer.append(line)
+                     current_length += len(line)
+
+                     if current_length >= self.max_length:
+                         # Encode and yield the batch
+                         text = " ".join(text_buffer)
+                         encodings = self.tokenizer.batch_encode(
+                             [text],
+                             max_length=self.max_length,
+                             truncation=True,
+                             padding=False,  # Don't pad here, we'll pad in collate_fn
+                             return_tensors="pt"
+                         )
+
+                         # Return individual tensors
+                         yield {
+                             "input_ids": encodings["input_ids"][0],
+                             "labels": encodings["input_ids"][0].clone()
+                         }
+                         text_buffer = []
+                         current_length = 0
+
+ def collate_batch(batch):
+     """Custom collate function to handle variable length sequences."""
+     # Separate input_ids and labels
+     input_ids = [item["input_ids"] for item in batch]
+     labels = [item["labels"] for item in batch]
+
+     # Pad sequences
+     input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
+     labels = pad_sequence(labels, batch_first=True, padding_value=-100)  # -100 is PyTorch's default ignore index
+
+     # Create attention masks
+     attention_mask = (input_ids != 0).long()
+
+     return {
+         "input_ids": input_ids,
+         "labels": labels,
+         "attention_mask": attention_mask
+     }
+
+ def train_model(config):
+     """Trains the model using DeepSpeed and Accelerate for memory efficiency."""
+     # Create output directory
+     output_dir = config["training"]["output_dir"]
+     os.makedirs(output_dir, exist_ok=True)
+     print(f"Model will be saved to: {output_dir}")
+
+     # Initialize DeepSpeed plugin and accelerator
+     deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=config["training"]["deepspeed"])
+     accelerator = Accelerator(
+         gradient_accumulation_steps=config["training"]["gradient_accumulation_steps"],
+         mixed_precision="fp16",
+         deepspeed_plugin=deepspeed_plugin,
+         log_with=config["training"]["report_to"]
+     )
+
+     # Initialize tracking
+     if accelerator.is_main_process:
+         accelerator.init_trackers(
+             project_name=config["training"]["wandb"]["project"],
+             config=config,
+             init_kwargs={
+                 "wandb": {
+                     "entity": config["training"]["wandb"]["entity"],
+                     "name": config["training"]["wandb"]["name"],
+                 }
+             }
+         )
+         print(f"Tracking initialized with {config['training']['report_to']}")
+
+     device = accelerator.device
+     print(f"Using device: {device}")
+
+     # Load tokenizer and model
+     tokenizer = get_tokenizer(config)
+     config["model"]["vocab_size"] = tokenizer.get_vocab_size()
+     model = create_model(config)
+
+     try:
+         model = torch.compile(model)
+         print("torch.compile enabled for faster training.")
+     except Exception as e:
+         print("torch.compile not available or failed, continuing without it.")
+
+     optimizer = AdamW(
+         model.parameters(),
+         lr=config["training"]["learning_rate"],
+         weight_decay=config["training"]["weight_decay"]
+     )
+
+     # Create streaming dataset with custom collate function
+     dataset = StreamingTextDataset(
+         data_dir="data/raw",
+         tokenizer=tokenizer,
+         max_length=config["dataset"]["max_length"]
+     )
+
+     train_loader = DataLoader(
+         dataset,
+         batch_size=config["training"]["per_device_train_batch_size"],
+         num_workers=config["training"]["dataloader_num_workers"],
+         pin_memory=True,
+         collate_fn=collate_batch  # Add custom collate function
+     )
+
+     # Prepare for distributed training
+     model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
+
+     # Calculate approximate steps per epoch based on target dataset size
+     avg_seq_length = config["dataset"]["max_length"] // 2  # Average sequence length
+     batch_size = config["training"]["per_device_train_batch_size"]
+     target_size_gb = config["dataset"].get("target_size_gb", 2.5)
+     chars_per_token = 4
+     total_tokens = (target_size_gb * 1024 * 1024 * 1024) // chars_per_token
+     steps_per_epoch = int(total_tokens // (avg_seq_length * batch_size))  # Convert to int
+     total_epochs = config["training"]["num_train_epochs"]
+     total_steps = int(steps_per_epoch * total_epochs)  # Convert to int
+
+     print(f"\nTraining Statistics (Estimated):")
+     print(f"Total epochs: {total_epochs}")
+     print(f"Estimated steps per epoch: {steps_per_epoch:,}")
+     print(f"Estimated total steps: {total_steps:,}")
+
+     # Track gradients for logging
+     def grad_norm(model):
+         total_norm = 0.0
+         for p in model.parameters():
+             if p.grad is not None:
+                 param_norm = p.grad.detach().data.norm(2)
+                 total_norm += param_norm.item() ** 2
+         return total_norm ** 0.5
+
+     # Initialize GPU monitoring
+     if torch.cuda.is_available():
+         gpu_id = torch.cuda.current_device()
+
+     training_stats = {
+         'train/loss': 0.0,
+         'train/learning_rate': 0.0,
+         'train/epoch': 0.0,
+         'train/global_step': 0,
+         'train/samples_per_second': 0.0,
+         'train/grad_norm': 0.0,
+         'performance/gpu_memory': 0.0,
+         'performance/gpu_utilization': 0.0,
+         'performance/batch_time': 0.0,
+     }
+
+     for epoch in range(total_epochs):
+         epoch_start_time = time.time()
+         model.train()
+         running_loss = 0
+         num_batches = 0
+         samples_processed = 0
+
+         progress_bar = tqdm(
+             total=steps_per_epoch,
+             desc=f"Epoch {epoch+1}/{total_epochs}",
+             disable=not accelerator.is_local_main_process
+         )
+
+         for batch in train_loader:
+             batch_start_time = time.time()
+
+             with accelerator.accumulate(model):
+                 outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
+                 loss = outputs["loss"]
+                 accelerator.backward(loss)
+
+                 if accelerator.sync_gradients:
+                     training_stats['train/grad_norm'] = grad_norm(model)
+                     accelerator.clip_grad_norm_(model.parameters(), 1.0)
+
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+             # Update statistics
+             loss_value = loss.item()
+             running_loss += loss_value
+             num_batches += 1
+             samples_processed += batch["input_ids"].size(0)
+             batch_time = time.time() - batch_start_time
+
+             # Update training stats
+             training_stats.update({
+                 'train/loss': loss_value,
+                 'train/learning_rate': optimizer.param_groups[0]['lr'],
+                 'train/epoch': epoch + 1,
+                 'train/global_step': num_batches + (epoch * steps_per_epoch),
+                 'train/samples_per_second': batch["input_ids"].size(0) / batch_time,
+                 'performance/batch_time': batch_time,
+             })
+
+             # GPU stats (if available)
+             if torch.cuda.is_available():
+                 training_stats.update({
+                     'performance/gpu_memory': torch.cuda.memory_allocated(gpu_id) / 1024**3,  # GB
+                     'performance/gpu_utilization': torch.cuda.utilization(gpu_id),
+                 })
+
+             # Update progress bar
+             avg_speed = num_batches / (time.time() - epoch_start_time)
+             eta_epoch = (steps_per_epoch - num_batches) / avg_speed / 60  # minutes
+             eta_total = (total_steps - (epoch * steps_per_epoch + num_batches)) / avg_speed / 60  # minutes
+
+             progress_bar.set_postfix({
+                 'loss': f'{loss_value:.4f}',
+                 'avg_loss': f'{running_loss/num_batches:.4f}',
+                 'lr': f'{optimizer.param_groups[0]["lr"]:.2e}',
+                 'samples/s': f'{training_stats["train/samples_per_second"]:.2f}',
+                 'epoch_eta': f'{eta_epoch:.1f}min',
+                 'total_eta': f'{eta_total:.1f}min'
+             })
+             progress_bar.update(1)
+
+             # Log metrics based on logging_steps
+             if num_batches % config["training"]["logging_steps"] == 0:
+                 if accelerator.is_main_process:
+                     current_step = int(num_batches + (epoch * steps_per_epoch))  # Convert to int
+                     accelerator.log(training_stats, step=current_step)
+
+             # Save checkpoint based on save_steps
+             if num_batches % config["training"]["save_steps"] == 0:
+                 if accelerator.is_local_main_process:
+                     checkpoint_dir = os.path.join(output_dir, f"checkpoint-epoch{epoch+1}-step{num_batches}")
+                     os.makedirs(checkpoint_dir, exist_ok=True)
+                     print(f"\nSaving checkpoint at step {num_batches} to {checkpoint_dir}")
+                     accelerator.save_state(checkpoint_dir)
+                     with open(os.path.join(checkpoint_dir, "config.json"), "w") as f:
+                         json.dump(config, f, indent=2)
+
+             # Break if we've reached the estimated steps for this epoch
+             if num_batches >= steps_per_epoch:
+                 break
+
+         progress_bar.close()
+
+         # End of epoch logging
+         epoch_time = time.time() - epoch_start_time
+         epoch_avg_loss = running_loss / num_batches
+         epoch_perplexity = torch.exp(torch.tensor(epoch_avg_loss))
+
+         if accelerator.is_main_process:
+             print(f"\nEpoch {epoch+1}/{total_epochs} Summary:")
+             print(f"Time: {epoch_time/60:.2f} minutes")
+             print(f"Average Loss: {epoch_avg_loss:.4f}")
+             print(f"Perplexity: {epoch_perplexity:.2f}")
+             print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
+             print(f"Samples Processed: {samples_processed:,}")
+             print(f"Average Speed: {samples_processed/epoch_time:.1f} samples/s")
+
+             # Estimate remaining time
+             epochs_remaining = total_epochs - (epoch + 1)
+             estimated_remaining_time = epochs_remaining * epoch_time / 60
+             print(f"Estimated time for remaining {epochs_remaining} epochs: {estimated_remaining_time:.1f} minutes")
+
+             # Log epoch summary to wandb with correct step
+             current_step = int((epoch + 1) * steps_per_epoch)  # Convert to int
+             accelerator.log({
+                 'epoch/average_loss': epoch_avg_loss,
+                 'epoch/perplexity': epoch_perplexity.item(),
+                 'epoch/time': epoch_time,
+                 'epoch/samples_processed': samples_processed,
+             }, step=current_step)
+
+     # Save final model
+     if accelerator.is_local_main_process:
+         final_model_dir = os.path.join(output_dir, "final_model")
+         os.makedirs(final_model_dir, exist_ok=True)
+         print(f"\nSaving final model to {final_model_dir}")
+
+         # Save with DeepSpeed
+         accelerator.save_state(final_model_dir)
+
+         # Save configuration
+         with open(os.path.join(final_model_dir, "config.json"), "w") as f:
+             json.dump(config, f, indent=2)
+
+         print("Final model saved successfully")
+     accelerator.end_training()
+
+ if __name__ == "__main__":
+     config = load_config()
+     train_model(config)
+     print("Training complete.")
src/utils.py ADDED
@@ -0,0 +1,15 @@
+ import yaml
+ from pathlib import Path
+
+ def load_config(config_path="config/config.yaml"):
+     """Loads the configuration from a YAML file."""
+     with open(config_path, "r") as f:
+         return yaml.safe_load(f)
+
+ def setup_logging(log_level="INFO"):
+     """Sets up basic logging."""
+     import logging
+     logging.basicConfig(
+         level=getattr(logging, log_level.upper(), logging.INFO),
+         format="%(asctime)s - %(levelname)s - %(message)s"
+     )