burtenshaw committed
Commit · e6a04c6
1 Parent(s): 90196d3
update blog post to use fancy mc fancienson style
Browse files
- README.md +0 -470
- app/src/components/Hero.astro +2 -1
- app/src/content/article.mdx +551 -28
- app/src/content/assets/image/nanochat-banner.png +3 -0
- app/src/content/assets/image/tweet.png +3 -0
- app/src/content/chapters/grpo.mdx +406 -0
- app/src/content/chapters/inference.mdx +97 -0
- app/src/content/chapters/sft.mdx +400 -0
- grpo.ipynb +654 -0
- sft.ipynb +591 -0
README.md
CHANGED
|
@@ -28,473 +28,3 @@ thumbnail: https://HuggingFaceTB-smol-training-playbook.hf.space/thumb.png
|
|
| 28 |
**[Try the live demo & documentation →](https://huggingface.co/spaces/tfrere/research-article-template)**
|
| 29 |
|
| 30 |
</div>
|
| 31 |
-
|
| 32 |
-
# Porting nanochat to Transformers: an AI modeling history lesson
|
| 33 |
-
|
| 34 |
-
**tldr:** There is a lot to learn about ML from nanochat, and even more to learn about the history of the transformer architecture.
|
| 35 |
-
|
| 36 |
-
Recently I was working on helping students of the [nanochat](https://huggingface.co/nanochat-students) project to share their models and discuss their learning on Hugging Face. In the process, I thought it would be useful if the model was integrated into the `transformers` library. This would allow others to use their nanochat models in loads of downstream libraries like vLLM for inference or TRL for post-training.
|
| 37 |
-
|
| 38 |
-
You can now use nanochat models in transformers and tap into all those educational gains across the ecosystem. But along the way, I uncovered a further treasure trove of education about how canonical models relate to each other, and the components they share.
|
| 39 |
-
|
| 40 |
-
I received the lesson from the simple teacher of class inheritance and the transformers modular philosophy. If you want to learn more about that, check out this [guide here](https://huggingface.co/docs/transformers/v4.48.0/modular_transformers).
|
| 41 |
-
|
| 42 |
-
Here, let’s tuck into this deep dive on how NanoChat relates to the lineage of transformer architectures.
|
| 43 |
-
|
| 44 |
-
## What is `nanochat`?
|
| 45 |
-
|
| 46 |
-
On October 13th 2025, Andrej Karpathy unceremoniously [dropped](https://x.com/karpathy/status/1977755427569111362) the nanochat [repo](https://github.com/karpathy/nanochat) into the unsuspecting AI world. To hype seekers, this was just a small and pretty average LLM. To ML devotees, this was nirvana. A raw unadulterated chance to tinker, fiddle, and play with a transformer model defined in pure pytorch. Nothing was hidden away in fancy `torch` methods or inherited from complex class structures. It was all there in a simple file.
|
| 47 |
-
|
| 48 |
-
![][image1]
|
| 49 |
-
|
| 50 |
-
Karpathy had painstakingly implemented an end-to-end build of an LLM system without the use of most major libraries, even though in real-world situations most of us rely on transformers, tokenizers, datasets, trl, and the like. This back-to-basics approach gives us the chance to genuinely learn and understand something from the ground up.
|
| 51 |
-
|
| 52 |
-
Personally, I found the process to be one of the most educational I can remember.
|
| 53 |
-
|
| 54 |
-
## What is `transformers` and how is it educational?
|
| 55 |
-
|
| 56 |
-
Most of us know the `transformers` library as the backbone of modern machine learning, but if we dig a little deeper, it’s a powerful piece of education.
|
| 57 |
-
|
| 58 |
-
If you don’t know: transformers is the de facto implementation of the modern AI models that bear the same name, ‘transformer’ models like the GPT, DeepSeek, and Claude series. `transformers` is a special project because it contains the implementations of all major open model architectures, and those architectures are modularized to reuse functionality from each other. If you want to explore the philosophy and lineage behind transformers’ modularity, check out this [guide here](https://huggingface.co/docs/transformers/v4.48.0/modular_transformers).
|
| 59 |
-
|
| 60 |
-
In general, scientists at AI research labs design, implement, and train their models in their framework of choice, be that torch, JAX, etc. When they come to share their open model with the community, they will open a PR on transformers and refactor their code to use relevant modules.
|
| 61 |
-
|
| 62 |
-
Because `transformers` contains most major model implementations, researchers inherit model architecture attributes from other canonical models. This is in every sense a ‘single source of truth’.
|
| 63 |
-
|
| 64 |
-
This practical feature of the library has an amazingly educational quality to it. We can read a model implementation as a series of references to other usages of those architectural features. For example, when one model uses a certain type of [RMSNorm](https://github.com/huggingface/transformers/blob/9f5b2d1b8995daa539b757e28c337e36408055e6/src/transformers/models/nanochat/modular_nanochat.py#L44), we can plainly see that it is the same implementation as another model because it inherits that class entirely. For example, check out nanochat’s RMSNorm:
|
| 65 |
-
|
| 66 |
-
```py
|
| 67 |
-
class NanoChatRMSNorm(Llama4TextL2Norm):
|
| 68 |
-
pass
|
| 69 |
-
```
|
| 70 |
-
|
| 71 |
-
The `transformers` library then converts the `modular_*` implementation into a `modeling_*` implementation, which contains the complete `torch` native implementation:
|
| 72 |
-
|
| 73 |
-
```py
|
| 74 |
-
class NanoChatRMSNorm(torch.nn.Module):
|
| 75 |
-
def __init__(self, eps: float = 1e-6):
|
| 76 |
-
super().__init__()
|
| 77 |
-
self.eps = eps
|
| 78 |
-
|
| 79 |
-
def _norm(self, x):
|
| 80 |
-
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 81 |
-
|
| 82 |
-
def forward(self, x):
|
| 83 |
-
return self._norm(x.float()).type_as(x)
|
| 84 |
-
|
| 85 |
-
def extra_repr(self):
|
| 86 |
-
return f"eps={self.eps}"
|
| 87 |
-
```
|
| 88 |
-
|
| 89 |
-
If we review a model in `transformers`, we can review both sides and learn from the math and literature of the model’s implementation. Due to the educational nature of nanochat, I thought that it was a perfect opportunity to explore this aspect of transformers and share what I learnt with students.
|
| 90 |
-
|
| 91 |
-
## Why do we need nanochat in `transformers`?
|
| 92 |
-
|
| 93 |
-
It might seem counterintuitive to support an educational model like nanochat in a production grade library like `transformers`. After all, we can see from nanochat’s benchmark scores that it does not rival state of the art models like Qwen3, SmolLM3, Gemma3, or Olmo3.
|
| 94 |
-
|
| 95 |
-
Nanochat was never really intended as a production grade model. It was meant as an educational tool, and that’s the same reason why we need it in transformers. There are four main reasons:
|
| 96 |
-
|
| 97 |
-
- `transformers` as a single source of truth teaches us about `nanochat`’s lineage.
|
| 98 |
-
- use the `nanochat` model in other libraries.
|
| 99 |
-
- save money by reusing nanochat checkpoints for fine-tuning.
|
| 100 |
-
- compare nanochat fine-tuning with other open model checkpoints.
|
| 101 |
-
|
| 102 |
-
Firstly, as mentioned above, `transformers` teaches us about the modeling conventions that Karpathy uses from other canonical implementations.
|
| 103 |
-
|
| 104 |
-
Secondly, because transformers is a standard within the ecosystem, it unlocks more downstream learning in post training libraries, quantisation tools, inference libraries, and device integrations. In practical terms, here are some examples nanochat students could learn on top of `transformers`:
|
| 105 |
-
|
| 106 |
-
- Quantize models in llama.cpp ($0)
|
| 107 |
-
- Integrate models into the browser and WebGPU ($0)
|
| 108 |
-
- SFT training in TRL/torch on Google Colab ($0)
|
| 109 |
-
- RL training in TRL/torch on Google Colab ($0 - $9)
|
| 110 |
-
- Agentic RL in TRL on Google Colab ($0 - $9)
|
| 111 |
-
|
| 112 |
-
Finally, training AI models is expensive. Running the `nanochat` [`speedrun.sh`](https://github.com/karpathy/nanochat/blob/master/speedrun.sh) costs between $200 and $2k depending on the model size we use, which is little compared to the millions of dollars invested by frontier labs, but it is still a significant sum for students, who always learn best by taking a few chances to fail and build experience.
|
| 113 |
-
|
| 114 |
-
In short, let’s unlock more opportunities for education!
|
| 115 |
-
|
| 116 |
-
## The nanochat architecture
|
| 117 |
-
|
| 118 |
-
As described by Karpathy, nanochat uses an archetypal architecture that is common across the field, which makes it an excellent choice for an educational resource because folks get to learn from what works.
|
| 119 |
-
|
| 120 |
-
The core model implementation ([`nanochat/gpt.py`](https://github.com/karpathy/nanochat/blob/master/nanochat/gpt.py), 291 lines) demonstrates modern transformer architecture, with every design decision documented and justified.
|
| 121 |
-
|
| 122 |
-
The configuration uses a single complexity slider: depth. Set `--depth=20` and everything else automatically adjusts. Model dimension equals depth × 64 (20 layers → 1,280 dimensions). Number of attention heads equals depth ÷ 2 (10 heads). Head dimension is fixed at 128. This "aspect ratio philosophy" simplifies scaling: if you want a more capable model or have a bigger budget, just increase depth to 26 ($300 budget) or 30 ($1,000 budget).
|
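To make the arithmetic concrete, here is a minimal sketch of that aspect-ratio rule (the helper below is illustrative and not part of the nanochat codebase):

```py
# Minimal sketch of the "aspect ratio" rule described above
# (illustrative helper, not part of nanochat itself).
def nanochat_dims(depth: int) -> dict:
    n_embd = depth * 64   # model dimension scales with depth
    n_head = depth // 2   # number of attention heads
    head_dim = 128        # fixed head dimension
    return {"n_layer": depth, "n_embd": n_embd, "n_head": n_head, "head_dim": head_dim}

print(nanochat_dims(20))  # {'n_layer': 20, 'n_embd': 1280, 'n_head': 10, 'head_dim': 128}
```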
| 123 |
-
|
| 124 |
-
The architecture incorporates five key improvements over vanilla transformers. Let’s work through the components of this architecture and compare them across implementations:
|
| 125 |
-
|
| 126 |
-
#### Forward pass based on the Llama Architecture
|
| 127 |
-
|
| 128 |
-
The forward pass in nanochat handles both training and generation. We can simply read that the input `x` is embedded, updated by each layer, and then passed to the head. During training, a loss is calculated and returned instead of the logits themselves.
|
| 129 |
-
|
| 130 |
-
```py
|
| 131 |
-
def forward(self, x, targets=None, loss_reduction='mean'):
|
| 132 |
-
x = self.token_emb(x)
|
| 133 |
-
for layer in self.layers:
|
| 134 |
-
x = layer(x)
|
| 135 |
-
x = self.ln_f(x)
|
| 136 |
-
logits = self.lm_head(x)
|
| 137 |
-
|
| 138 |
-
if targets is not None:
|
| 139 |
-
loss = F.cross_entropy(
|
| 140 |
-
logits.view(-1, self.vocab_size),
|
| 141 |
-
targets.view(-1),
|
| 142 |
-
ignore_index=-1,
|
| 143 |
-
reduction=loss_reduction
|
| 144 |
-
)
|
| 145 |
-
return loss
|
| 146 |
-
return logits
|
| 147 |
-
```
|
| 148 |
-
|
| 149 |
-
By returning loss directly when targets are provided, the training loop becomes trivial. No separate loss computation, no manual masking logic—just `loss = model(inputs, targets)` followed by `loss.backward()`.
|
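To see why this keeps the loop trivial, here is a runnable toy version of the same pattern (the model below is a stand-in for illustration, not nanochat itself):

```py
import torch
import torch.nn.functional as F

# Toy model whose forward returns the loss when targets are given,
# mirroring the nanochat pattern above (names and sizes are illustrative).
class ToyLM(torch.nn.Module):
    def __init__(self, vocab_size=16, dim=8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size, bias=False)
        self.vocab_size = vocab_size

    def forward(self, x, targets=None):
        logits = self.head(self.emb(x))
        if targets is not None:
            return F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1), ignore_index=-1)
        return logits

model = ToyLM()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inputs = torch.randint(0, 16, (2, 5))
targets = torch.randint(0, 16, (2, 5))

loss = model(inputs, targets)  # forward returns the loss directly
loss.backward()
optimizer.step()
optimizer.zero_grad()
```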
| 150 |
-
|
| 151 |
-
`transformers` has to make things a bit more complex to facilitate the downstream ecosystem that uses logits in a broad spectrum of ways. Therefore, loss calculation is dealt with in training-specific code, and the `forward` function returns `BaseModelOutputWithPast`.
|
| 152 |
-
|
| 153 |
-
```py
|
| 154 |
-
class NanoChatModel(LlamaModel):
|
| 155 |
-
def __init__(self, config: NanoChatConfig):
|
| 156 |
-
super().__init__(config)
|
| 157 |
-
|
| 158 |
-
self.initial_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
|
| 159 |
-
self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
|
| 160 |
-
|
| 161 |
-
def forward(
|
| 162 |
-
self,
|
| 163 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 164 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 165 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 166 |
-
past_key_values: Optional[Cache] = None,
|
| 167 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 168 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 169 |
-
use_cache: Optional[bool] = None,
|
| 170 |
-
**kwargs: Unpack[TransformersKwargs],
|
| 171 |
-
) -> BaseModelOutputWithPast:
|
| 172 |
-
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 173 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 174 |
-
|
| 175 |
-
if inputs_embeds is None:
|
| 176 |
-
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
| 177 |
-
|
| 178 |
-
if use_cache and past_key_values is None:
|
| 179 |
-
past_key_values = DynamicCache(config=self.config)
|
| 180 |
-
|
| 181 |
-
if cache_position is None:
|
| 182 |
-
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 183 |
-
cache_position: torch.Tensor = torch.arange(
|
| 184 |
-
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
if position_ids is None:
|
| 188 |
-
position_ids = cache_position.unsqueeze(0)
|
| 189 |
-
|
| 190 |
-
causal_mask = create_causal_mask(
|
| 191 |
-
config=self.config,
|
| 192 |
-
input_embeds=inputs_embeds,
|
| 193 |
-
attention_mask=attention_mask,
|
| 194 |
-
cache_position=cache_position,
|
| 195 |
-
past_key_values=past_key_values,
|
| 196 |
-
position_ids=position_ids,
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
hidden_states = inputs_embeds
|
| 200 |
-
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
|
| 201 |
-
|
| 202 |
-
hidden_states = self.initial_norm(hidden_states) # Additional norm before the layers
|
| 203 |
-
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 204 |
-
hidden_states = decoder_layer(
|
| 205 |
-
hidden_states,
|
| 206 |
-
attention_mask=causal_mask,
|
| 207 |
-
position_embeddings=position_embeddings,
|
| 208 |
-
position_ids=position_ids,
|
| 209 |
-
past_key_values=past_key_values,
|
| 210 |
-
cache_position=cache_position,
|
| 211 |
-
**kwargs,
|
| 212 |
-
)
|
| 213 |
-
|
| 214 |
-
hidden_states = self.norm(hidden_states)
|
| 215 |
-
return BaseModelOutputWithPast(
|
| 216 |
-
last_hidden_state=hidden_states,
|
| 217 |
-
past_key_values=past_key_values,
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
```
|
| 221 |
-
|
| 222 |
-
#### Rotary Position Embeddings (RoPE)
|
| 223 |
-
|
| 224 |
-
Rotary Position Embeddings (RoPE) replace learned positional encodings by rotating query and key vectors using precomputed sin/cos frequencies:
|
| 225 |
-
|
| 226 |
-
```py
|
| 227 |
-
def apply_rope(x, cos, sin):
|
| 228 |
-
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 229 |
-
y1 = x1 * cos - x2 * sin
|
| 230 |
-
y2 = x1 * sin + x2 * cos
|
| 231 |
-
return torch.stack([y1, y2], dim=-1).flatten(-2)
|
| 232 |
-
```
|
| 233 |
-
|
| 234 |
-
In transformers, the rotary embeddings are implemented like so:
|
| 235 |
-
|
| 236 |
-
```py
|
| 237 |
-
from ..llama.modeling_llama import (
|
| 238 |
-
LlamaDecoderLayer,
|
| 239 |
-
LlamaModel,
|
| 240 |
-
LlamaPreTrainedModel,
|
| 241 |
-
LlamaRotaryEmbedding,
|
| 242 |
-
apply_rotary_pos_emb,
|
| 243 |
-
eager_attention_forward,
|
| 244 |
-
)
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
|
| 248 |
-
pass
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
def rotate_half(x):
|
| 252 |
-
"""Rotates half the hidden dims of the input with flipped signs for NanoChat."""
|
| 253 |
-
x1 = x[..., : x.shape[-1] // 2]
|
| 254 |
-
x2 = x[..., x.shape[-1] // 2 :]
|
| 255 |
-
return torch.cat((x2, -x1), dim=-1)
|
| 256 |
-
```
|
| 257 |
-
|
| 258 |
-
`NanoChatRotaryEmbedding` almost entirely inherits from the original Llama series, except for a sign inversion in `rotate_half`.
|
| 259 |
-
|
| 260 |
-
### **QK Normalization**
|
| 261 |
-
|
| 262 |
-
NanoChat applies RMSNorm to queries and keys before computing attention to stabilize training.
|
| 263 |
-
|
| 264 |
-
In the original gpt.py, this is achieved via a functional norm helper applied directly inside the attention forward pass:
|
| 265 |
-
|
| 266 |
-
```py
|
| 267 |
-
def norm(x):
|
| 268 |
-
# Purely functional rmsnorm with no learnable params
|
| 269 |
-
return F.rms_norm(x, (x.size(-1),))
|
| 270 |
-
|
| 271 |
-
class CausalSelfAttention(nn.Module):
|
| 272 |
-
...
|
| 273 |
-
def forward(self, x, cos_sin, kv_cache):
|
| 274 |
-
B, T, C = x.size()
|
| 275 |
-
|
| 276 |
-
# Project the input to get queries, keys, and values
|
| 277 |
-
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
| 278 |
-
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 279 |
-
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 280 |
-
|
| 281 |
-
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
| 282 |
-
cos, sin = cos_sin
|
| 283 |
-
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
| 284 |
-
q, k = norm(q), norm(k) # QK norm
|
| 285 |
-
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
| 286 |
-
...
|
| 287 |
-
```
|
| 288 |
-
|
| 289 |
-
In the modular transformers implementation, we see a fascinating mix of lineages. The `NanoChatRMSNorm` inherits directly from `Llama4TextL2Norm`, while the attention mechanism inherits from `Qwen3Attention`. We simply inject the QK normalization into the Qwen3 logic:
|
| 290 |
-
|
| 291 |
-
```py
|
| 292 |
-
|
| 293 |
-
class NanoChatRMSNorm(Llama4TextL2Norm):
|
| 294 |
-
pass
|
| 295 |
-
|
| 296 |
-
class NanoChatAttention(Qwen3Attention):
|
| 297 |
-
def __init__(self, config: NanoChatConfig, layer_idx: int):
|
| 298 |
-
super().__init__(config, layer_idx)
|
| 299 |
-
del self.sliding_window
|
| 300 |
-
del self.layer_type
|
| 301 |
-
|
| 302 |
-
self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
|
| 303 |
-
self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
|
| 304 |
-
|
| 305 |
-
def forward(
|
| 306 |
-
self,
|
| 307 |
-
hidden_states: torch.Tensor,
|
| 308 |
-
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 309 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 310 |
-
past_key_values: Optional[Cache] = None,
|
| 311 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 312 |
-
**kwargs: Unpack[TransformersKwargs],
|
| 313 |
-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 314 |
-
input_shape = hidden_states.shape[:-1]
|
| 315 |
-
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 316 |
-
|
| 317 |
-
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 318 |
-
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 319 |
-
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 320 |
-
|
| 321 |
-
cos, sin = position_embeddings
|
| 322 |
-
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 323 |
-
|
| 324 |
-
# RoPE -> Norm (instead of usual Norm -> RoPE)
|
| 325 |
-
query_states = self.q_norm(query_states)
|
| 326 |
-
key_states = self.k_norm(key_states)
|
| 327 |
-
|
| 328 |
-
if past_key_values is not None:
|
| 329 |
-
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 330 |
-
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 331 |
-
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 332 |
-
|
| 333 |
-
attention_interface: Callable = eager_attention_forward
|
| 334 |
-
if self.config._attn_implementation != "eager":
|
| 335 |
-
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 336 |
-
|
| 337 |
-
attn_output, attn_weights = attention_interface(
|
| 338 |
-
self,
|
| 339 |
-
query_states,
|
| 340 |
-
key_states,
|
| 341 |
-
value_states,
|
| 342 |
-
attention_mask,
|
| 343 |
-
dropout=0.0 if not self.training else self.attention_dropout,
|
| 344 |
-
scaling=self.scaling,
|
| 345 |
-
**kwargs,
|
| 346 |
-
)
|
| 347 |
-
|
| 348 |
-
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 349 |
-
attn_output = self.o_proj(attn_output)
|
| 350 |
-
return attn_output, attn_weights
|
| 351 |
-
```
|
| 352 |
-
|
| 353 |
-
### **Untied Weights**
|
| 354 |
-
|
| 355 |
-
Karpathy's implementation deliberately unties the weights between the token embedding and the language model head to provide the model with more flexibility. In gpt.py, these are initialized as two completely separate modules:
|
| 356 |
-
|
| 357 |
-
```py
|
| 358 |
-
class GPT(nn.Module):
|
| 359 |
-
def __init__(self, config):
|
| 360 |
-
super().__init__()
|
| 361 |
-
self.config = config
|
| 362 |
-
self.transformer = nn.ModuleDict({
|
| 363 |
-
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
| 364 |
-
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
| 365 |
-
})
|
| 366 |
-
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 367 |
-
# ... (rest of init)
|
| 368 |
-
```
|
| 369 |
-
|
| 370 |
-
In the modular implementation, we inherit from `Gemma2ForCausalLM`. This is a powerful simplification—Gemma 2 also supports untied weights and advanced output structures. By simply inheriting the class, we pull in all the necessary machinery for causal generation, while the configuration object (defined elsewhere) ensures the weights remain untied:
|
| 371 |
-
|
| 372 |
-
```py
|
| 373 |
-
class NanoChatForCausalLM(Gemma2ForCausalLM):
|
| 374 |
-
def forward(self, **super_kwargs) -> CausalLMOutputWithPast:
|
| 375 |
-
super().forward(**super_kwargs)
|
| 376 |
-
```
|
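To make "untied" concrete, here is a small sketch of the difference in terms of parameter storage; in `transformers`, this behaviour is governed by the standard `tie_word_embeddings` config flag:

```py
import torch.nn as nn

# Sketch: untied vs. tied embedding / LM-head weights (sizes are illustrative).
vocab_size, n_embd = 50304, 1280
wte = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)

print(wte.weight.data_ptr() == lm_head.weight.data_ptr())  # False: two independent matrices (untied)

lm_head.weight = wte.weight  # tying shares a single tensor between both modules
print(wte.weight.data_ptr() == lm_head.weight.data_ptr())  # True: tied
```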
| 377 |
-
|
| 378 |
-
###
|
| 379 |
-
|
| 380 |
-
### **ReLU² Activation**
|
| 381 |
-
|
| 382 |
-
The original implementation replaces the standard GELU activation with ReLU², which is simply ReLU squared. This provides a faster alternative without performance loss. In gpt.py, this is hardcoded into the MLP block:
|
| 383 |
-
|
| 384 |
-
```py
|
| 385 |
-
class MLP(nn.Module):
|
| 386 |
-
def __init__(self, config):
|
| 387 |
-
super().__init__()
|
| 388 |
-
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
| 389 |
-
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
| 390 |
-
def forward(self, x):
|
| 391 |
-
x = self.c_fc(x)
|
| 392 |
-
x = F.relu(x).square()
|
| 393 |
-
x = self.c_proj(x)
|
| 394 |
-
return x
|
| 395 |
-
```
|
| 396 |
-
|
| 397 |
-
In the modular file, we see another surprising inheritance: `CLIPMLP`. The CLIP architecture uses a structure that fits our needs perfectly, so we inherit the structural definition from CLIP and let the configuration drive the specific activation function (ReLU2):
|
| 398 |
-
|
| 399 |
-
```py
|
| 400 |
-
class NanoChatMLP(CLIPMLP):
|
| 401 |
-
def __init__(self, config):
|
| 402 |
-
super().__init__(config)
|
| 403 |
-
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
| 404 |
-
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
| 405 |
-
```
|
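As a quick numerical sanity check on what ReLU² means (independent of how the activation is wired through the config):

```py
import torch
import torch.nn.functional as F

# ReLU² is simply ReLU followed by squaring.
x = torch.randn(8)
relu_squared = F.relu(x) ** 2
assert torch.allclose(relu_squared, torch.clamp(x, min=0).pow(2))
print(relu_squared)
```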
| 406 |
-
|
| 407 |
-
### **Multi-Query Attention (MQA)**
|
| 408 |
-
|
| 409 |
-
NanoChat uses Multi-Query Attention (MQA) to reduce the memory footprint of the KV cache, using 10 query heads but only 4 key/value heads (in the default config).
|
| 410 |
-
|
| 411 |
-
In gpt.py, this logic is handled by passing distinct head counts and relying on PyTorch's functional attention to handle the broadcasting (or explicitly handling it during inference):
|
| 412 |
-
|
| 413 |
-
```py
|
| 414 |
-
class CausalSelfAttention(nn.Module):
|
| 415 |
-
# ...
|
| 416 |
-
def forward(self, x, cos_sin, kv_cache):
|
| 417 |
-
# ...
|
| 418 |
-
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
| 419 |
-
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
|
| 420 |
-
if kv_cache is None or Tq == Tk:
|
| 421 |
-
# During training (no KV cache), attend as usual with causal attention
|
| 422 |
-
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
| 423 |
-
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
| 424 |
-
elif Tq == 1:
|
| 425 |
-
# During inference but with a single query in this forward pass:
|
| 426 |
-
# The query has to attend to all the keys/values in the cache
|
| 427 |
-
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
| 428 |
-
else:
|
| 429 |
-
# During inference AND we have a chunk of queries in this forward pass:
|
| 430 |
-
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
| 431 |
-
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
| 432 |
-
prefix_len = Tk - Tq
|
| 433 |
-
if prefix_len > 0: # can't be negative but could be zero
|
| 434 |
-
attn_mask[:, :prefix_len] = True
|
| 435 |
-
# Then, causal attention within this chunk
|
| 436 |
-
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
| 437 |
-
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
| 438 |
-
# ...
|
| 439 |
-
```
|
| 440 |
-
|
| 441 |
-
###
|
| 442 |
-
|
| 443 |
-
In `modular_nanochat.py`, we don't need to write this logic at all. As seen in the QK Normalization section above, `NanoChatAttention` inherits from `Qwen3Attention`. The Qwen3 implementation is robust and fully supports GQA/MQA out of the box. By using this parent class, we get production-grade attention implementation "for free," allowing us to focus solely on the unique normalizations required by NanoChat.
|
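As a small illustration of what the parent class handles for us, here is the shape pattern of grouped/multi-query attention using the same `enable_gqa` flag as the gpt.py snippet above (head counts here are illustrative, and PyTorch requires the number of query heads to be a multiple of the key/value heads):

```py
import torch
import torch.nn.functional as F

# Fewer key/value heads than query heads; SDPA broadcasts the KV heads across query groups.
B, T, D = 1, 8, 128
q = torch.randn(B, 8, T, D)  # 8 query heads
k = torch.randn(B, 4, T, D)  # 4 key/value heads
v = torch.randn(B, 4, T, D)

y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(y.shape)  # torch.Size([1, 8, 8, 128])
```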
| 444 |
-
|
| 445 |
-
## Conclusion
|
| 446 |
-
|
| 447 |
-
It’s very clear that Andrej Karpathy’s implementation offers ten times more to learn from than the transformers version, which inherits almost entirely from existing models and features. That said, we can still take more away from the inherited modular modeling implementation. Models like Llama, Llama4, Gemma2, Qwen3, and CLIP are all reused to create a genuinely canonical implementation of a transformer.
|
| 448 |
-
|
| 449 |
-
## Use Nanochat in Transformers
|
| 450 |
-
|
| 451 |
-
If you’d like to try out your own nanochat models in `transformers`:
|
| 452 |
-
|
| 453 |
-
1. Download the nanochat-d34 checkpoint
|
| 454 |
-
|
| 455 |
-
```
|
| 456 |
-
hf download karpathy/nanochat-d34 --local-dir nanochat-d34
|
| 457 |
-
```
|
| 458 |
-
|
| 459 |
-
2. Convert the checkpoint to transformers format
|
| 460 |
-
|
| 461 |
-
```
|
| 462 |
-
uv run \
|
| 463 |
-
--with "transformers @ git+https://github.com/huggingface/transformers.git@nanochat-implementation" \
|
| 464 |
-
--with "tiktoken>=0.12.0" \
|
| 465 |
-
https://raw.githubusercontent.com/huggingface/transformers/nanochat-implementation/src/transformers/models/nanochat/convert_nanochat_checkpoints.py \
|
| 466 |
-
--input_dir ./nanochat-d34 \
|
| 467 |
-
--output_dir ./nanochat-d3-hf
|
| 468 |
-
```
|
| 469 |
-
|
| 470 |
-
3. (optional) Upload the checkpoint to the Hugging Face Hub
|
| 471 |
-
|
| 472 |
-
```
|
| 473 |
-
hf upload <username>/nanochat-d34 nanochat-d34
|
| 474 |
-
```
|
| 475 |
-
|
| 476 |
-
4. Test the model
|
| 477 |
-
|
| 478 |
-
```py
|
| 479 |
-
import torch
|
| 480 |
-
from transformers import AutoTokenizer, NanoChatForCausalLM
|
| 481 |
-
|
| 482 |
-
tokenizer = AutoTokenizer.from_pretrained("./nanochat-d3-hf")
|
| 483 |
-
model = NanoChatForCausalLM.from_pretrained("./nanochat-d3-hf")
|
| 484 |
-
|
| 485 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 486 |
-
model = model.to(device)
|
| 487 |
-
|
| 488 |
-
prompt = "Hello, how are you?"
|
| 489 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 490 |
-
inputs.pop("token_type_ids", None)
|
| 491 |
-
outputs = model.generate(**inputs, max_new_tokens=100)
|
| 492 |
-
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 493 |
-
```
|
| 494 |
-
|
| 495 |
-
## Notebooks
|
| 496 |
-
|
| 497 |
-
If you want to train with these models, you can use these colab notebooks:
|
| 498 |
-
|
| 499 |
-
- [SFT](https://colab.research.google.com/#fileId=https%3A//huggingface.co/datasets/nanochat-students/notebooks/blob/main/sft.ipynb)
|
| 500 |
-
- [GRPO](https://colab.research.google.com/#fileId=https%3A//huggingface.co/datasets/nanochat-students/notebooks/blob/main/grpo.ipynb)
|
|
|
|
| 28 |
**[Try the live demo & documentation →](https://huggingface.co/spaces/tfrere/research-article-template)**
|
| 29 |
|
| 30 |
</div>
|
app/src/components/Hero.astro
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
---
|
| 2 |
import HtmlEmbed from "./HtmlEmbed.astro";
|
|
|
|
| 3 |
|
| 4 |
interface Props {
|
| 5 |
title: string; // may contain HTML (e.g., <br/>)
|
|
@@ -98,7 +99,7 @@ const pdfFilename = `${slugify(pdfBase)}.pdf`;
|
|
| 98 |
<section class="hero">
|
| 99 |
<h1 class="hero-title" set:html={title} />
|
| 100 |
<div class="hero-banner">
|
| 101 |
-
<
|
| 102 |
{description && <p class="hero-desc">{description}</p>}
|
| 103 |
</div>
|
| 104 |
</section>
|
|
|
|
| 1 |
---
|
| 2 |
import HtmlEmbed from "./HtmlEmbed.astro";
|
| 3 |
+
import bannerImage from "../content/assets/image/nanochat-banner.png";
|
| 4 |
|
| 5 |
interface Props {
|
| 6 |
title: string; // may contain HTML (e.g., <br/>)
|
|
|
|
| 99 |
<section class="hero">
|
| 100 |
<h1 class="hero-title" set:html={title} />
|
| 101 |
<div class="hero-banner">
|
| 102 |
+
<img src={bannerImage.src} alt="Banner" style="width: 100%; max-width: 980px;" />
|
| 103 |
{description && <p class="hero-desc">{description}</p>}
|
| 104 |
</div>
|
| 105 |
</section>
|
app/src/content/article.mdx
CHANGED
|
@@ -1,19 +1,25 @@
|
|
| 1 |
---
|
| 2 |
-
title: "
|
| 3 |
-
subtitle: "
|
| 4 |
-
description: "
|
| 5 |
authors:
|
| 6 |
-
- name: "
|
| 7 |
-
url: "https://huggingface.co/
|
|
|
|
|
| 8 |
affiliations: [1]
|
| 9 |
affiliations:
|
| 10 |
- name: "Hugging Face"
|
| 11 |
url: "https://huggingface.co"
|
| 12 |
-
published: "
|
| 13 |
doi: 10.1234/abcd.efgh
|
| 14 |
licence: >
|
| 15 |
Diagrams and text are licensed under <a href="https://creativecommons.org/licenses/by/4.0/" target="_blank" rel="noopener noreferrer">CC‑BY 4.0</a> with the source available on <a href="https://huggingface.co/spaces/tfrere/research-article-template" target="_blank" rel="noopener noreferrer">Hugging Face</a>, unless noted otherwise.
|
| 16 |
-
Figures reused from other sources are excluded and marked in their captions (
|
| 17 |
tags:
|
| 18 |
- research
|
| 19 |
- template
|
|
@@ -22,36 +28,553 @@ pdfProOnly: false
|
|
| 22 |
showPdf: true
|
| 23 |
---
|
| 24 |
|
| 25 |
-
import
|
| 26 |
-
|
| 27 |
-
import
|
| 28 |
-
import
|
| 29 |
-
import
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
|
| 39 |
-
|
| 40 |
|
| 41 |
-
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
|
| 47 |
-
|
| 48 |
|
| 49 |
-
|
| 50 |
|
| 51 |
-
|
| 52 |
|
| 53 |
-
|
| 54 |
|
| 55 |
-
<
|
| 56 |
|
|
|
|
| 57 |
|
|
|
|
| 1 |
---
|
| 2 |
+
title: "Porting nanochat to Transformers: an AI modeling history lesson"
|
| 3 |
+
subtitle: "There is a lot to learn about ML from nanochat, and even more to learn about the history of the transformer architecture."
|
| 4 |
+
description: "**tldr:** There is a lot to learn about ML from nanochat, and even more to learn about the history of the transformer architecture."
|
| 5 |
authors:
|
| 6 |
+
- name: "Ben Burtenshaw"
|
| 7 |
+
url: "https://huggingface.co/burtenshaw"
|
| 8 |
+
affiliations: [1]
|
| 9 |
+
- name: "Sergio Paniego"
|
| 10 |
+
url: "https://huggingface.co/sergiopaniego"
|
| 11 |
+
affiliations: [1]
|
| 12 |
+
- name: "Anton Vlasjuk"
|
| 13 |
+
url: "https://huggingface.co/AntonV"
|
| 14 |
affiliations: [1]
|
| 15 |
affiliations:
|
| 16 |
- name: "Hugging Face"
|
| 17 |
url: "https://huggingface.co"
|
| 18 |
+
published: "Dec. 01, 2025"
|
| 19 |
doi: 10.1234/abcd.efgh
|
| 20 |
licence: >
|
| 21 |
Diagrams and text are licensed under <a href="https://creativecommons.org/licenses/by/4.0/" target="_blank" rel="noopener noreferrer">CC‑BY 4.0</a> with the source available on <a href="https://huggingface.co/spaces/tfrere/research-article-template" target="_blank" rel="noopener noreferrer">Hugging Face</a>, unless noted otherwise.
|
| 22 |
+
Figures reused from other sources are excluded and marked in their captions ("Figure from …").
|
| 23 |
tags:
|
| 24 |
- research
|
| 25 |
- template
|
|
|
|
| 28 |
showPdf: true
|
| 29 |
---
|
| 30 |
|
| 31 |
+
import Sidenote from '../../components/Sidenote.astro'
|
| 32 |
+
|
| 33 |
+
import GRPO from "./chapters/grpo.mdx";
|
| 34 |
+
import SFT from "./chapters/sft.mdx";
|
| 35 |
+
import Inference from "./chapters/inference.mdx";
|
| 36 |
+
|
| 37 |
+
<Sidenote>
|
| 38 |
+
|
| 39 |
+
The [nanochat-students](https://huggingface.co/nanochat-students) organization on Hugging Face hosts community models and discussions.
|
| 40 |
+
|
| 41 |
+
</Sidenote>
|
| 42 |
+
|
| 43 |
+
Recently I was working on helping students of the nanochat project to share their models and discuss their learning on Hugging Face. In the process, I thought it would be useful if the model was integrated into the `transformers` library. This would allow others to use their nanochat models in loads of downstream libraries like vLLM for inference or TRL for post-training.
|
| 44 |
+
|
| 45 |
+
<Sidenote>
|
| 46 |
+
|
| 47 |
+
[vLLM](https://docs.vllm.ai/) provides high-throughput inference, while [TRL](https://huggingface.co/docs/trl/index) offers tools for reinforcement learning from human feedback (RLHF) and other post-training methods.
|
| 48 |
+
|
| 49 |
+
</Sidenote>
|
| 50 |
+
|
| 51 |
+
You can now use nanochat models in transformers and tap into all those educational gains across the ecosystem. But along the way, we uncovered a further treasure trove of education about how canonical models relate to each other, and the components they share. We received the lesson from the simple teacher of class inheritance and transformers modular philosophy.
|
| 52 |
+
|
| 53 |
+
<Sidenote>
|
| 54 |
+
|
| 55 |
+
Learn more about how transformers achieves modularity in the [modular transformers guide](https://huggingface.co/docs/transformers/v4.48.0/modular_transformers).
|
| 56 |
+
|
| 57 |
+
</Sidenote>
|
| 58 |
+
|
| 59 |
+
Now, let's tuck into this deep dive on how NanoChat relates to the lineage of transformer architectures.
|
| 60 |
+
|
| 61 |
+
## What is `nanochat`?
|
| 62 |
+
|
| 63 |
+
<Sidenote>
|
| 64 |
+
|
| 65 |
+
See Karpathy's [original announcement](https://x.com/karpathy/status/1977755427569111362) and the [nanochat repository](https://github.com/karpathy/nanochat) on GitHub.
|
| 66 |
+
|
| 67 |
+
</Sidenote>
|
| 68 |
+
|
| 69 |
+
On October 13th 2025, Andrej Karpathy unceremoniously dropped the nanochat repo into the unsuspecting AI world. To hype seekers, this was just a small and pretty average LLM. To ML devotees, this was nirvana. A raw unadulterated chance to tinker, fiddle, and play with a transformer model defined in pure pytorch. Nothing was hidden away in fancy `torch` methods or inherited from complex class structures. It was all there in a simple file.
|
| 70 |
+
|
| 71 |
+

|
| 72 |
+
|
| 73 |
+
<Sidenote>
|
| 74 |
+
|
| 75 |
+
The core libraries Karpathy avoided: [transformers](https://huggingface.co/docs/transformers/index), [tokenizers](https://huggingface.co/docs/tokenizers/index), [datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index), and many dependencies. All for the sake of our learning!
|
| 76 |
+
|
| 77 |
+
</Sidenote>
|
| 78 |
+
|
| 79 |
+
Karpathy had painstakingly implemented an end-to-end build of an LLM system without the use of most major libraries, even though in real-world situations most of us rely on transformers, tokenizers, datasets, trl, and the like. This back-to-basics approach gives us the chance to genuinely learn and understand something from the ground up.
|
| 80 |
+
|
| 81 |
+
Personally, I found the process to be one of the most educational I can remember.
|
| 82 |
+
|
| 83 |
+
## What is `transformers` and how is it educational?
|
| 84 |
+
|
| 85 |
+
<Sidenote>
|
| 86 |
+
|
| 87 |
+
The [transformers documentation](https://huggingface.co/docs/transformers/index) covers everything from quickstart guides to advanced model internals.
|
| 88 |
+
|
| 89 |
+
</Sidenote>
|
| 90 |
+
|
| 91 |
+
Most of us know the `transformers` library as the backbone of modern machine learning, but if we dig a little deeper, it's a powerful piece of education.
|
| 92 |
+
|
| 93 |
+
If you don't know: transformers is the de facto implementation of the modern AI models that bear the same name, 'transformer' models like the GPT, DeepSeek, and Claude series. `transformers` is a special project because it contains the implementations of all major open model architectures, and those architectures are modularized to reuse functionality from each other.
|
| 94 |
+
|
| 95 |
+
<Sidenote>
|
| 96 |
+
|
| 97 |
+
Explore the [model hub](https://huggingface.co/models) to see thousands of models built on these shared architectures.
|
| 98 |
+
|
| 99 |
+
</Sidenote>
|
| 100 |
+
|
| 101 |
+
In general, scientists at AI research labs design, implement, and train their models in their framework of choice, be that torch, JAX, etc. When they come to share their open model with the community, they will open a PR on transformers and refactor their code to use relevant modules.
|
| 102 |
+
|
| 103 |
+
Because `transformers` contains most major model implementations, researchers inherit model architecture attributes from other canonical models. This is in every sense a 'single source of truth'.
|
| 104 |
+
|
| 105 |
+
<Sidenote>
|
| 106 |
+
|
| 107 |
+
See nanochat's [RMSNorm implementation](https://github.com/huggingface/transformers/blob/9f5b2d1b8995daa539b757e28c337e36408055e6/src/transformers/models/nanochat/modular_nanochat.py#L44) in the transformers codebase.
|
| 108 |
+
|
| 109 |
+
</Sidenote>
|
| 110 |
+
|
| 111 |
+
This practical feature of the library has an amazingly educational quality to it. We can read a model implementation as a series of references to other usages of those architectural features. For example, when one model uses a certain type of RMSNorm, we can plainly see that it is the same implementation as another model because it inherits that class entirely. For example, check out nanochat's RMSNorm:
|
| 112 |
+
|
| 113 |
+
```py
|
| 114 |
+
class NanoChatRMSNorm(Llama4TextL2Norm):
|
| 115 |
+
pass
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
The `transformers` library then converts the `modular_*` implementation into a `modeling_*` implementation, which contains the complete `torch` native implementation:
|
| 119 |
+
|
| 120 |
+
```py
|
| 121 |
+
class NanoChatRMSNorm(torch.nn.Module):
|
| 122 |
+
def __init__(self, eps: float = 1e-6):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.eps = eps
|
| 125 |
+
|
| 126 |
+
def _norm(self, x):
|
| 127 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 128 |
+
|
| 129 |
+
def forward(self, x):
|
| 130 |
+
return self._norm(x.float()).type_as(x)
|
| 131 |
+
|
| 132 |
+
def extra_repr(self):
|
| 133 |
+
return f"eps={self.eps}"
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
If we review a model in `transformers`, we can review both sides and learn from the math and literature of the model's implementation. Due to the educational nature of nanochat, I thought that it was a perfect opportunity to explore this aspect of transformers and share what I learnt with students.
|
| 137 |
+
|
| 138 |
+
## Why do we need nanochat in `transformers`?
|
| 139 |
+
|
| 140 |
+
It might seem counterintuitive to support an educational model like nanochat in a production grade library like `transformers`. After all, we can see from nanochat's benchmark scores that it does not rival state of the art models like Qwen3, SmolLM3, Gemma3, or [Olmo3](https://huggingface.co/allenai/Olmo-3-32B-Think). In fact, that's the reason we think nanochat should be in `transformers`. Here's what the community gains from its inclusion:
|
| 141 |
+
|
| 142 |
+
- `transformers` as a single source of truth teaches us about `nanochat`'s lineage.
|
| 143 |
+
- we can use the `nanochat` model in other libraries.
|
| 144 |
+
- save money by reusing nanochat checkpoints for fine-tuning.
|
| 145 |
+
- compare nanochat fine-tuning implementation with other open model checkpoints.
|
| 146 |
+
|
| 147 |
+
Firstly, as mentioned above, `transformers` teaches us about the modeling conventions that Karpathy uses from other canonical implementations.
|
| 148 |
+
|
| 149 |
+
Secondly, because transformers is a standard within the ecosystem, it unlocks more downstream learning in post training libraries, quantisation tools, inference libraries, and device integrations. In practical terms, here are some examples nanochat students could learn on top of `transformers`:
|
| 150 |
+
|
| 151 |
+
<Sidenote>
|
| 152 |
+
|
| 153 |
+
Learn about [model quantization](https://huggingface.co/docs/transformers/en/quantization/overview) to reduce model size and memory usage.
|
| 154 |
+
|
| 155 |
+
</Sidenote>
|
| 156 |
+
|
| 157 |
+
- Quantize models in llama.cpp ($0)
|
| 158 |
+
- Integrate models into the browser and WebGPU ($0)
|
| 159 |
+
- SFT training in TRL/torch on Google Colab ($0)
|
| 160 |
+
- RL training TRL/torch on Google Colab ($0 \- $9)
|
| 161 |
+
- Agentic RL in TRL on Google Colab ($0 \- $9)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
Finally, training AI models is expensive. Running the nanochat `speedrun.sh` costs between $200 and $2k depending on the model size we use, which is little compared to the millions of dollars invested by frontier labs, but it is still a significant sum for students, who always learn best by taking a few chances to fail and build experience.
|
| 165 |
+
|
| 166 |
+
<Sidenote>
|
| 167 |
+
|
| 168 |
+
The [speedrun.sh](https://github.com/karpathy/nanochat/blob/master/speedrun.sh) script in nanochat benchmarks training costs across different configurations.
|
| 169 |
+
|
| 170 |
+
</Sidenote>
|
| 171 |
+
|
| 172 |
+
In short, let's unlock more opportunities for education\!
|
| 173 |
+
|
| 174 |
+
## The nanochat architecture
|
| 175 |
+
|
| 176 |
+
<Sidenote>
|
| 177 |
+
|
| 178 |
+
The original [gpt.py](https://github.com/karpathy/nanochat/blob/master/nanochat/gpt.py) implementation is just 291 lines of pure PyTorch.
|
| 179 |
+
|
| 180 |
+
</Sidenote>
|
| 181 |
+
|
| 182 |
+
As described by Karpathy, nanochat uses an archetypal architecture that is common across the field, which makes it an excellent choice for an educational resource because folks get to learn from what works. The core model implementation demonstrates modern transformer architecture, with every design decision documented and justified.
|
| 183 |
+
|
| 184 |
+
The configuration uses a single complexity slider: depth. Set `--depth=20` and everything else automatically adjusts. Model dimension equals depth × 64 (20 layers → 1,280 dimensions). Number of attention heads equals depth ÷ 2 (10 heads). Head dimension is fixed at 128. This "aspect ratio philosophy" simplifies scaling: if you want a more capable model or have a bigger budget, just increase depth to 26 ($300 budget) or 30 ($1,000 budget).
|
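To make the arithmetic concrete, here is a minimal sketch of that aspect-ratio rule (the helper below is illustrative and not part of the nanochat codebase):

```py
# Minimal sketch of the "aspect ratio" rule described above
# (illustrative helper, not part of nanochat itself).
def nanochat_dims(depth: int) -> dict:
    n_embd = depth * 64   # model dimension scales with depth
    n_head = depth // 2   # number of attention heads
    head_dim = 128        # fixed head dimension
    return {"n_layer": depth, "n_embd": n_embd, "n_head": n_head, "head_dim": head_dim}

print(nanochat_dims(20))  # {'n_layer': 20, 'n_embd': 1280, 'n_head': 10, 'head_dim': 128}
```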
| 185 |
+
|
| 186 |
+
The architecture incorporates five key improvements over vanilla transformers. Let's work through the components of this architecture and compare them across implementations:
|
| 187 |
+
|
| 188 |
+
#### Forward pass based on the Llama Architecture
|
| 189 |
+
|
| 190 |
+
<Sidenote>
|
| 191 |
+
|
| 192 |
+
See the [Llama model documentation](https://huggingface.co/docs/transformers/en/model_doc/llama) for the full architecture details.
|
| 193 |
+
|
| 194 |
+
</Sidenote>
|
| 195 |
+
|
| 196 |
+
The forward pass in nanochat handles both training and generation. We can simply read that the input `x` is embedded, updated by each layer, and then passed to the head. During training, a loss is calculated and returned instead of the logits themselves.
|
| 197 |
+
|
| 198 |
+
```py
|
| 199 |
+
def forward(self, x, targets=None, loss_reduction='mean'):
|
| 200 |
+
x = self.token_emb(x)
|
| 201 |
+
for layer in self.layers:
|
| 202 |
+
x = layer(x)
|
| 203 |
+
x = self.ln_f(x)
|
| 204 |
+
logits = self.lm_head(x)
|
| 205 |
+
|
| 206 |
+
if targets is not None:
|
| 207 |
+
loss = F.cross_entropy(
|
| 208 |
+
logits.view(-1, self.vocab_size),
|
| 209 |
+
targets.view(-1),
|
| 210 |
+
ignore_index=-1,
|
| 211 |
+
reduction=loss_reduction
|
| 212 |
+
)
|
| 213 |
+
return loss
|
| 214 |
+
return logits
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
By returning loss directly when targets are provided, the training loop becomes trivial. No separate loss computation, no manual masking logic—just `loss = model(inputs, targets)` followed by `loss.backward()`.
|
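To see why this keeps the loop trivial, here is a runnable toy version of the same pattern (the model below is a stand-in for illustration, not nanochat itself):

```py
import torch
import torch.nn.functional as F

# Toy model whose forward returns the loss when targets are given,
# mirroring the nanochat pattern above (names and sizes are illustrative).
class ToyLM(torch.nn.Module):
    def __init__(self, vocab_size=16, dim=8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size, bias=False)
        self.vocab_size = vocab_size

    def forward(self, x, targets=None):
        logits = self.head(self.emb(x))
        if targets is not None:
            return F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1), ignore_index=-1)
        return logits

model = ToyLM()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inputs = torch.randint(0, 16, (2, 5))
targets = torch.randint(0, 16, (2, 5))

loss = model(inputs, targets)  # forward returns the loss directly
loss.backward()
optimizer.step()
optimizer.zero_grad()
```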
| 218 |
+
|
| 219 |
+
<Sidenote>
|
| 220 |
+
|
| 221 |
+
The [BaseModelOutputWithPast](https://huggingface.co/docs/transformers/en/main_classes/output#transformers.modeling_outputs.BaseModelOutputWithPast) class standardizes model outputs across the ecosystem.
|
| 222 |
+
|
| 223 |
+
</Sidenote>
|
| 224 |
+
|
| 225 |
+
`transformers` has to make things a bit more complex to facilitate the downstream ecosystem that uses logits in a broad spectrum of ways. Therefore, loss calculation is dealt with in training-specific code, and the `forward` function returns `BaseModelOutputWithPast`.
|
| 226 |
+
|
| 227 |
+
```py
|
| 228 |
+
class NanoChatModel(LlamaModel):
|
| 229 |
+
def __init__(self, config: NanoChatConfig):
|
| 230 |
+
super().__init__(config)
|
| 231 |
+
|
| 232 |
+
self.initial_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
|
| 233 |
+
self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
|
| 234 |
+
|
| 235 |
+
def forward(
|
| 236 |
+
self,
|
| 237 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 238 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 239 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 240 |
+
past_key_values: Optional[Cache] = None,
|
| 241 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 242 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 243 |
+
use_cache: Optional[bool] = None,
|
| 244 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 245 |
+
) -> BaseModelOutputWithPast:
|
| 246 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 247 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 248 |
+
|
| 249 |
+
if inputs_embeds is None:
|
| 250 |
+
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
| 251 |
+
|
| 252 |
+
if use_cache and past_key_values is None:
|
| 253 |
+
past_key_values = DynamicCache(config=self.config)
|
| 254 |
+
|
| 255 |
+
if cache_position is None:
|
| 256 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 257 |
+
cache_position: torch.Tensor = torch.arange(
|
| 258 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
if position_ids is None:
|
| 262 |
+
position_ids = cache_position.unsqueeze(0)
|
| 263 |
+
|
| 264 |
+
causal_mask = create_causal_mask(
|
| 265 |
+
config=self.config,
|
| 266 |
+
input_embeds=inputs_embeds,
|
| 267 |
+
attention_mask=attention_mask,
|
| 268 |
+
cache_position=cache_position,
|
| 269 |
+
past_key_values=past_key_values,
|
| 270 |
+
position_ids=position_ids,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
hidden_states = inputs_embeds
|
| 274 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
|
| 275 |
+
|
| 276 |
+
hidden_states = self.initial_norm(hidden_states) # Additional norm before the layers
|
| 277 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 278 |
+
hidden_states = decoder_layer(
|
| 279 |
+
hidden_states,
|
| 280 |
+
attention_mask=causal_mask,
|
| 281 |
+
position_embeddings=position_embeddings,
|
| 282 |
+
position_ids=position_ids,
|
| 283 |
+
past_key_values=past_key_values,
|
| 284 |
+
cache_position=cache_position,
|
| 285 |
+
**kwargs,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
hidden_states = self.norm(hidden_states)
|
| 289 |
+
return BaseModelOutputWithPast(
|
| 290 |
+
last_hidden_state=hidden_states,
|
| 291 |
+
past_key_values=past_key_values,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
#### Rotary Position Embeddings (RoPE)
|
| 297 |
+
|
| 298 |
+
<Sidenote>
|
| 299 |
+
|
| 300 |
+
The [RoFormer paper](https://arxiv.org/abs/2104.09864) introduced RoPE, now used in Llama, Mistral, and many other modern LLMs.
|
| 301 |
+
|
| 302 |
+
</Sidenote>
|
| 303 |
+
|
| 304 |
+
Rotary Position Embeddings (RoPE) replace learned positional encodings by rotating query and key vectors using precomputed sin/cos frequencies:
|
| 305 |
+
|
| 306 |
+
```py
|
| 307 |
+
def apply_rope(x, cos, sin):
|
| 308 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 309 |
+
y1 = x1 * cos - x2 * sin
|
| 310 |
+
y2 = x1 * sin + x2 * cos
|
| 311 |
+
return torch.stack([y1, y2], dim=-1).flatten(-2)
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
In transformers, the rotary embeddings are implemented like so:
|
| 315 |
+
|
| 316 |
+
```py
|
| 317 |
+
from ..llama.modeling_llama import (
|
| 318 |
+
LlamaDecoderLayer,
|
| 319 |
+
LlamaModel,
|
| 320 |
+
LlamaPreTrainedModel,
|
| 321 |
+
LlamaRotaryEmbedding,
|
| 322 |
+
apply_rotary_pos_emb,
|
| 323 |
+
eager_attention_forward,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
|
| 328 |
+
pass
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def rotate_half(x):
|
| 332 |
+
"""Rotates half the hidden dims of the input with flipped signs for NanoChat."""
|
| 333 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 334 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 335 |
+
return torch.cat((x2, -x1), dim=-1)
|
| 336 |
+
```
|
| 337 |
+
|
| 338 |
+
`NanoChatRotaryEmbedding` almost entirely inherits from the original Llama series, except for a sign inversion in `rotate_half`.
|
| 339 |
+
|
| 340 |
+
### **QK Normalization**
|
| 341 |
+
|
| 342 |
+
<Sidenote>
|
| 343 |
+
|
| 344 |
+
QK normalization was popularized by [Llama 4](https://huggingface.co/docs/transformers/en/model_doc/llama4) and helps stabilize attention scores during training.
|
| 345 |
+
|
| 346 |
+
</Sidenote>
|
| 347 |
+
|
| 348 |
+
NanoChat applies RMSNorm to queries and keys before computing attention to stabilize training.
|
| 349 |
+
|
| 350 |
+
In the original gpt.py, this is achieved via a functional norm helper applied directly inside the attention forward pass:
|
| 351 |
+
|
| 352 |
+
```py
|
| 353 |
+
def norm(x):
|
| 354 |
+
# Purely functional rmsnorm with no learnable params
|
| 355 |
+
return F.rms_norm(x, (x.size(-1),))
|
| 356 |
+
|
| 357 |
+
class CausalSelfAttention(nn.Module):
|
| 358 |
+
...
|
| 359 |
+
def forward(self, x, cos_sin, kv_cache):
|
| 360 |
+
B, T, C = x.size()
|
| 361 |
+
|
| 362 |
+
# Project the input to get queries, keys, and values
|
| 363 |
+
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
| 364 |
+
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 365 |
+
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 366 |
+
|
| 367 |
+
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
| 368 |
+
cos, sin = cos_sin
|
| 369 |
+
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
| 370 |
+
q, k = norm(q), norm(k) # QK norm
|
| 371 |
+
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
| 372 |
+
...
|
| 373 |
+
```
|
| 374 |
+
|
| 375 |
+
<Sidenote>
|
| 376 |
+
|
| 377 |
+
[Qwen3](https://huggingface.co/docs/transformers/en/model_doc/qwen3) provides a robust attention implementation with QK normalization, which nanochat adapts to its own parameter-free variant.
|
| 378 |
+
|
| 379 |
+
</Sidenote>
|
| 380 |
+
|
| 381 |
+
In the modular transformers implementation, we see a fascinating mix of lineages. The `NanoChatRMSNorm` inherits directly from `Llama4TextL2Norm`, while the attention mechanism inherits from `Qwen3Attention`. We simply inject the QK normalization into the Qwen3 logic:
|
| 382 |
+
|
| 383 |
+
```py
|
| 384 |
+
|
| 385 |
+
class NanoChatRMSNorm(Llama4TextL2Norm):
|
| 386 |
+
pass
|
| 387 |
+
|
| 388 |
+
class NanoChatAttention(Qwen3Attention):
|
| 389 |
+
def __init__(self, config: NanoChatConfig, layer_idx: int):
|
| 390 |
+
super().__init__(config, layer_idx)
|
| 391 |
+
del self.sliding_window
|
| 392 |
+
del self.layer_type
|
| 393 |
+
|
| 394 |
+
self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
|
| 395 |
+
self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
|
| 396 |
+
|
| 397 |
+
def forward(
|
| 398 |
+
self,
|
| 399 |
+
hidden_states: torch.Tensor,
|
| 400 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 401 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 402 |
+
past_key_values: Optional[Cache] = None,
|
| 403 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 404 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 405 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 406 |
+
input_shape = hidden_states.shape[:-1]
|
| 407 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 408 |
+
|
| 409 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 410 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 411 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 412 |
+
|
| 413 |
+
cos, sin = position_embeddings
|
| 414 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 415 |
+
|
| 416 |
+
# RoPE -> Norm (instead of usual Norm -> RoPE)
|
| 417 |
+
query_states = self.q_norm(query_states)
|
| 418 |
+
key_states = self.k_norm(key_states)
|
| 419 |
+
|
| 420 |
+
if past_key_values is not None:
|
| 421 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 422 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 423 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 424 |
+
|
| 425 |
+
attention_interface: Callable = eager_attention_forward
|
| 426 |
+
if self.config._attn_implementation != "eager":
|
| 427 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 428 |
+
|
| 429 |
+
attn_output, attn_weights = attention_interface(
|
| 430 |
+
self,
|
| 431 |
+
query_states,
|
| 432 |
+
key_states,
|
| 433 |
+
value_states,
|
| 434 |
+
attention_mask,
|
| 435 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 436 |
+
scaling=self.scaling,
|
| 437 |
+
**kwargs,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 441 |
+
attn_output = self.o_proj(attn_output)
|
| 442 |
+
return attn_output, attn_weights
|
| 443 |
+
```
|
| 444 |
+
|
| 445 |
+
### **Untied Weights**
|
| 446 |
+
|
| 447 |
+
<Sidenote>
|
| 448 |
+
|
| 449 |
+
Weight tying between the embeddings and the LM head is a common trick popularized by [Press & Wolf (2016)](https://arxiv.org/abs/1608.05859); nanochat unties them to give the model extra flexibility.
|
| 450 |
+
|
| 451 |
+
</Sidenote>
|
| 452 |
+
|
| 453 |
+
Karpathy's implementation deliberately unties the weights between the token embedding and the language model head to provide the model with more flexibility. In gpt.py, these are initialized as two completely separate modules:
|
| 454 |
+
|
| 455 |
+
```py
|
| 456 |
+
class GPT(nn.Module):
|
| 457 |
+
def __init__(self, config):
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.config = config
|
| 460 |
+
self.transformer = nn.ModuleDict({
|
| 461 |
+
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
| 462 |
+
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
| 463 |
+
})
|
| 464 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 465 |
+
# ... (rest of init)
|
| 466 |
+
```
|
| 467 |
+
|
| 468 |
+
<Sidenote>
|
| 469 |
+
|
| 470 |
+
[Gemma 2](https://huggingface.co/docs/transformers/en/model_doc/gemma2) supports both tied and untied weight configurations via the model config.
|
| 471 |
+
|
| 472 |
+
</Sidenote>
|
| 473 |
+
|
| 474 |
+
In the modular implementation, we inherit from `Gemma2ForCausalLM`. This is a powerful simplification—Gemma 2 also supports untied weights and advanced output structures. By simply inheriting the class, we pull in all the necessary machinery for causal generation, while the configuration object (defined elsewhere) ensures the weights remain untied:
|
| 475 |
+
|
| 476 |
+
```py
|
| 477 |
+
class NanoChatForCausalLM(Gemma2ForCausalLM):
|
| 478 |
+
def forward(self, **super_kwargs) -> CausalLMOutputWithPast:
|
| 479 |
+
super().forward(**super_kwargs)
|
| 480 |
+
```
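
The untying itself is driven by the configuration rather than the class definition. A minimal sketch, assuming `NanoChatConfig` follows the usual transformers convention of exposing `tie_word_embeddings` (here using the checkpoint from the tutorials below):

```py
from transformers import AutoConfig

# With tie_word_embeddings=False, `wte` and `lm_head` keep separate weight tensors.
config = AutoConfig.from_pretrained("karpathy/nanochat-d32", revision="refs/pr/1")
print(config.tie_word_embeddings)  # expected: False
```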
|
| 481 |
+
|
| 482 |
+
### **ReLU² Activation**
|
| 483 |
+
|
| 484 |
+
<Sidenote>
|
| 485 |
+
|
| 486 |
+
The [Primer paper](https://arxiv.org/abs/2109.08668) showed squared ReLU can match or exceed GELU performance with lower compute.
|
| 487 |
+
|
| 488 |
+
</Sidenote>
|
| 489 |
+
|
| 490 |
+
The original implementation replaces the standard GELU activation with ReLU², which is simply ReLU squared. This provides a faster alternative without performance loss. In gpt.py, this is hardcoded into the MLP block:
|
| 491 |
+
|
| 492 |
+
```py
|
| 493 |
+
class MLP(nn.Module):
|
| 494 |
+
def __init__(self, config):
|
| 495 |
+
super().__init__()
|
| 496 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
| 497 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
| 498 |
+
def forward(self, x):
|
| 499 |
+
x = self.c_fc(x)
|
| 500 |
+
x = F.relu(x).square()
|
| 501 |
+
x = self.c_proj(x)
|
| 502 |
+
return x
|
| 503 |
+
```
|
| 504 |
+
|
| 505 |
+
<Sidenote>
|
| 506 |
+
|
| 507 |
+
[CLIP](https://huggingface.co/docs/transformers/en/model_doc/clip) provides a clean MLP structure that nanochat extends with the ReLU² activation.
|
| 508 |
+
|
| 509 |
+
</Sidenote>
|
| 510 |
+
|
| 511 |
+
In the modular file, we see another surprising inheritance: `CLIPMLP`. The CLIP architecture uses a structure that fits our needs perfectly, so we inherit the structural definition from CLIP and let the configuration drive the specific activation function (ReLU²):
|
| 512 |
+
|
| 513 |
+
```py
|
| 514 |
+
class NanoChatMLP(CLIPMLP):
|
| 515 |
+
def __init__(self, config):
|
| 516 |
+
super().__init__(config)
|
| 517 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
| 518 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
| 519 |
+
```
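
`CLIPMLP` looks up its activation from the config via `ACT2FN[config.hidden_act]`, so the ReLU² behaviour presumably comes from the configuration pointing at transformers' squared-ReLU entry. A minimal sketch, assuming the `"relu2"` key in the activation registry:

```py
import torch
from transformers.activations import ACT2FN

# Squared ReLU: relu(x) ** 2
act = ACT2FN["relu2"]
x = torch.tensor([-1.0, 0.5, 2.0])
print(act(x))  # tensor([0.0000, 0.2500, 4.0000])
```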
|
| 520 |
+
|
| 521 |
+
### **Multi-Query Attention (MQA)**
|
| 522 |
+
|
| 523 |
+
<Sidenote>
|
| 524 |
+
|
| 525 |
+
The [GQA paper](https://arxiv.org/abs/2305.13245) explains how grouped-query attention reduces memory while maintaining quality.
|
| 526 |
+
|
| 527 |
+
</Sidenote>
|
| 528 |
+
|
| 529 |
+
NanoChat shrinks the KV-cache memory footprint by sharing key/value heads across query heads, multi-query style: the default config uses 10 query heads but only 4 key/value heads.
|
| 530 |
+
|
| 531 |
+
<Sidenote>
|
| 532 |
|
| 533 |
+
PyTorch's [scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) handles GQA broadcasting automatically via `enable_gqa`.
|
| 534 |
|
| 535 |
+
</Sidenote>
|
| 536 |
|
| 537 |
+
In gpt.py, this logic is handled by passing distinct head counts and relying on PyTorch's functional attention to handle the broadcasting (or explicitly handling it during inference):
|
| 538 |
|
| 539 |
+
```py
|
| 540 |
+
class CausalSelfAttention(nn.Module):
|
| 541 |
+
# ...
|
| 542 |
+
def forward(self, x, cos_sin, kv_cache):
|
| 543 |
+
# ...
|
| 544 |
+
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
| 545 |
+
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
|
| 546 |
+
if kv_cache is None or Tq == Tk:
|
| 547 |
+
# During training (no KV cache), attend as usual with causal attention
|
| 548 |
+
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
| 549 |
+
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
| 550 |
+
elif Tq == 1:
|
| 551 |
+
# During inference but with a single query in this forward pass:
|
| 552 |
+
# The query has to attend to all the keys/values in the cache
|
| 553 |
+
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
| 554 |
+
else:
|
| 555 |
+
# During inference AND we have a chunk of queries in this forward pass:
|
| 556 |
+
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
| 557 |
+
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
| 558 |
+
prefix_len = Tk - Tq
|
| 559 |
+
if prefix_len > 0: # can't be negative but could be zero
|
| 560 |
+
attn_mask[:, :prefix_len] = True
|
| 561 |
+
# Then, causal attention within this chunk
|
| 562 |
+
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
| 563 |
+
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
| 564 |
+
# ...
|
| 565 |
+
```
|
| 566 |
|
| 567 |
+
In `modular_nanochat.py`, we don't need to write this logic at all. As seen in the QK Normalization section above, `NanoChatAttention` inherits from `Qwen3Attention`. The Qwen3 implementation is robust and fully supports GQA/MQA out of the box. By using this parent class, we get a production-grade attention implementation "for free," allowing us to focus solely on the unique normalizations required by NanoChat.
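
For intuition, here is a minimal, self-contained sketch of what that broadcasting amounts to (illustrative head counts; the `enable_gqa` flag needs a recent PyTorch):

```py
import torch
import torch.nn.functional as F

# 8 query heads share 4 key/value heads: each kv head serves a group of 2 query heads.
B, Hq, Hkv, T, D = 1, 8, 4, 16, 32
q = torch.randn(B, Hq, T, D)
k = torch.randn(B, Hkv, T, D)
v = torch.randn(B, Hkv, T, D)

# PyTorch broadcasts the kv heads across the query-head groups for us.
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(y.shape)  # torch.Size([1, 8, 16, 32])
```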
|
| 568 |
|
| 569 |
+
## Conclusion
|
| 570 |
|
| 571 |
+
Andrej Karpathy's implementation offers far more to learn from, line by line, than the transformers version, which inherits almost entirely from existing models and features. That said, the modular implementation teaches its own lesson: Llama, Llama 4, Gemma 2, Qwen3, and CLIP are all reused to create a genuinely canonical implementation of a transformer.
|
| 572 |
|
| 573 |
+
# Hands-on Tutorial
|
| 574 |
|
| 575 |
+
OK, let's cut the philosophy and see what we can do with `nanochat` in transformers.
|
| 576 |
|
| 577 |
+
<Inference />
|
| 578 |
|
| 579 |
+
<SFT />
|
| 580 |
app/src/content/assets/image/nanochat-banner.png ADDED (Git LFS)
app/src/content/assets/image/tweet.png ADDED (Git LFS)
app/src/content/chapters/grpo.mdx
ADDED
|
@@ -0,0 +1,406 @@
| 1 |
+
# [BONUS 3] Group Relative Policy Optimization in `torch`
|
| 2 |
+
|
| 3 |
+
- [GRPO](https://colab.research.google.com/#fileId=https%3A//huggingface.co/datasets/nanochat-students/notebooks/blob/main/grpo.ipynb)
|
| 4 |
+
|
| 5 |
+
This chapter demonstrates Group Relative Policy Optimization (GRPO) training for the NanoChat model—a reinforcement learning approach for improving model responses based on reward signals.
|
| 6 |
+
|
| 7 |
+
## Import model and tokenizer
|
| 8 |
+
|
| 9 |
+
```python
|
| 10 |
+
import torch
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
model_id = "karpathy/nanochat-d32"
|
| 17 |
+
revision = "refs/pr/1"
|
| 18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
|
| 22 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 23 |
+
model_id,
|
| 24 |
+
revision=revision,
|
| 25 |
+
torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
|
| 26 |
+
).to(device)
|
| 27 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 28 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Setup LoRA
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
from peft import LoraConfig, get_peft_model
|
| 35 |
+
|
| 36 |
+
lora_config = LoraConfig(
|
| 37 |
+
r=1,
|
| 38 |
+
lora_alpha=2,
|
| 39 |
+
lora_dropout=0.00,
|
| 40 |
+
task_type="CAUSAL_LM",
|
| 41 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
model = get_peft_model(model, lora_config)
|
| 45 |
+
model.print_trainable_parameters()
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Demo the model
|
| 53 |
+
|
| 54 |
+
Test with a plain autoregressive prompt:
|
| 55 |
+
|
| 56 |
+
```python
|
| 57 |
+
print("=" * 80)
|
| 58 |
+
print("TEST 1: Plain Autoregressive Prompt")
|
| 59 |
+
print("=" * 80)
|
| 60 |
+
prompt = "The Eiffel Tower stands in Paris and"
|
| 61 |
+
test_inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
test_outputs = model.generate(
|
| 66 |
+
**test_inputs,
|
| 67 |
+
max_new_tokens=64,
|
| 68 |
+
do_sample=False,
|
| 69 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
generated_tokens = test_outputs[0, test_inputs["input_ids"].shape[1] :]
|
| 73 |
+
print(f"Prompt: {prompt}")
|
| 74 |
+
print(f"\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}")
|
| 75 |
+
print("=" * 80)
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
```
|
| 79 |
+
================================================================================
|
| 80 |
+
TEST 1: Plain Autoregressive Prompt
|
| 81 |
+
================================================================================
|
| 82 |
+
Prompt: The Eiffel Tower stands in Paris and
|
| 83 |
+
|
| 84 |
+
Generated: is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters
|
| 85 |
+
================================================================================
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
And with the chat template:
|
| 89 |
+
|
| 90 |
+
```python
|
| 91 |
+
print("=" * 80)
|
| 92 |
+
print("TEST 2: Chat Template")
|
| 93 |
+
print("="*80)
|
| 94 |
+
conversation = [
|
| 95 |
+
{"role": "user", "content": "What is the capital of France?"},
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
inputs = tokenizer.apply_chat_template(
|
| 99 |
+
conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
| 100 |
+
).to(device)
|
| 101 |
+
|
| 102 |
+
print(f"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}")
|
| 103 |
+
print(f"Input IDs: {inputs['input_ids'][0].tolist()}")
|
| 104 |
+
|
| 105 |
+
with torch.no_grad():
|
| 106 |
+
outputs = model.generate(
|
| 107 |
+
**inputs,
|
| 108 |
+
max_new_tokens=64,
|
| 109 |
+
do_sample=False
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
|
| 113 |
+
print(f"\nGenerated: {tokenizer.decode(generated_tokens)}")
|
| 114 |
+
print("=" * 80)
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
```
|
| 118 |
+
================================================================================
|
| 119 |
+
TEST 2: Chat Template
|
| 120 |
+
================================================================================
|
| 121 |
+
Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>
|
| 122 |
+
Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]
|
| 123 |
+
|
| 124 |
+
Generated: The capital of France is Paris.<|assistant_end|>
|
| 125 |
+
================================================================================
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## Dataset
|
| 129 |
+
|
| 130 |
+
We use the OpenR1-Math dataset for math reasoning tasks:
|
| 131 |
+
|
| 132 |
+
```python
|
| 133 |
+
raw_dataset = load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train")
|
| 134 |
+
splits = raw_dataset.train_test_split(test_size=0.1, seed=42)
|
| 135 |
+
train_dataset = splits["train"]
|
| 136 |
+
eval_dataset = splits["test"]
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
## Training Configuration
|
| 140 |
+
|
| 141 |
+
```python
|
| 142 |
+
max_train_steps = 50
|
| 143 |
+
prompt_batch_size = 1
|
| 144 |
+
num_generations = 4
|
| 145 |
+
max_new_tokens = 128
|
| 146 |
+
temperature = 1.0
|
| 147 |
+
top_k = 50
|
| 148 |
+
learning_rate = 5e-6
|
| 149 |
+
weight_decay = 0.0
|
| 150 |
+
epsilon = 0.2
|
| 151 |
+
gradient_accumulation_steps = 1
|
| 152 |
+
warmup_ratio = 0.1
|
| 153 |
+
logging_frequency = 5
|
| 154 |
+
max_train_samples = 1000
|
| 155 |
+
max_eval_samples = 100
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
## Reward Functions
|
| 159 |
+
|
| 160 |
+
GRPO requires reward functions to guide the policy optimization. We define several:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
import re
|
| 164 |
+
import numpy as np
|
| 165 |
+
import torch.nn.functional as F
|
| 166 |
+
from contextlib import nullcontext
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def think_format_reward(completions):
|
| 170 |
+
"""
|
| 171 |
+
Reward function that checks that the completion closes its reasoning with a </think> tag and does not re-open a <think> tag (the opening tag is assumed to be part of the prompt).
|
| 172 |
+
Returns 1.0 if the format is correct, otherwise 0.0.
|
| 173 |
+
"""
|
| 174 |
+
pattern = r"^(?!.*<think>)(.*?)</think>.*$"
|
| 175 |
+
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
|
| 176 |
+
return [1.0 if match else 0.0 for match in matches]
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def accuracy_reward(completions, solutions):
|
| 180 |
+
"""
|
| 181 |
+
Reward function that checks if the completion matches the solution.
|
| 182 |
+
For simplicity, we'll do basic string matching here.
|
| 183 |
+
"""
|
| 184 |
+
rewards = []
|
| 185 |
+
for completion, solution in zip(completions, solutions):
|
| 186 |
+
# Simple string matching (normalized)
|
| 187 |
+
reward = 1.0 if solution.strip().lower() in completion.strip().lower() else 0.0
|
| 188 |
+
rewards.append(reward)
|
| 189 |
+
return rewards
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def min_length_reward(completions, min_length=10):
|
| 193 |
+
"""
|
| 194 |
+
Reward function that checks if the completion is at least a certain length.
|
| 195 |
+
Returns 1.0 if the length is greater than or equal to the minimum length, otherwise 0.0.
|
| 196 |
+
"""
|
| 197 |
+
return [1.0 if len(completion) >= min_length else 0.0 for completion in completions]
|
| 198 |
+
|
| 199 |
+
def combined_reward(completions, solutions):
|
| 200 |
+
"""
|
| 201 |
+
Combines format and accuracy rewards with equal weight.
|
| 202 |
+
"""
|
| 203 |
+
format_rewards = think_format_reward(completions)
|
| 204 |
+
accuracy_rewards = accuracy_reward(completions, solutions)
|
| 205 |
+
min_length_rewards = min_length_reward(completions)
|
| 206 |
+
return [np.mean([f, a, m]) for f, a, m in zip(format_rewards, accuracy_rewards, min_length_rewards)]
|
| 207 |
+
```
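
As a quick sanity check, here is what these functions return on a couple of toy completions:

```python
completions = [
    "The reasoning goes here.</think>The answer is 42.",
    "I think the answer is seven.",
]
solutions = ["42", "42"]

print(think_format_reward(completions))         # [1.0, 0.0]
print(accuracy_reward(completions, solutions))  # [1.0, 0.0]
print(min_length_reward(completions))           # [1.0, 1.0]
print(combined_reward(completions, solutions))  # [1.0, 0.333...]
```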
|
| 208 |
+
|
| 209 |
+
## Helper Functions
|
| 210 |
+
|
| 211 |
+
```python
|
| 212 |
+
def per_token_log_probs(logits, labels):
|
| 213 |
+
logits = logits.float()
|
| 214 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 215 |
+
return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def prepare_prompt(example, problem_key="problem", solution_key="solution"):
|
| 219 |
+
# Extract the messages (should be a list of dicts with 'role' and 'content')
|
| 220 |
+
prompt = example.get(problem_key, "")
|
| 221 |
+
messages = [{"role": "user", "content": prompt}]
|
| 222 |
+
|
| 223 |
+
formatted = tokenizer.apply_chat_template(
|
| 224 |
+
messages,
|
| 225 |
+
add_generation_prompt=True,
|
| 226 |
+
truncation=True,
|
| 227 |
+
max_length=2048,
|
| 228 |
+
padding=False,
|
| 229 |
+
return_dict=True,
|
| 230 |
+
return_tensors="pt",
|
| 231 |
+
)
|
| 232 |
+
return formatted["input_ids"], formatted["attention_mask"]
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
if device.type == "cuda":
|
| 236 |
+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 237 |
+
else:
|
| 238 |
+
autocast_ctx = nullcontext()
|
| 239 |
+
```
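
Before the loop, it is worth sanity-checking the shape these helpers produce. A toy example (illustrative only):

```python
dummy_logits = torch.randn(2, 5, 32)         # (batch, seq, vocab)
dummy_labels = torch.randint(0, 32, (2, 5))  # target token ids
print(per_token_log_probs(dummy_logits, dummy_labels).shape)  # torch.Size([2, 5])
```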
|
| 240 |
+
|
| 241 |
+
## Optimizer and Scheduler
|
| 242 |
+
|
| 243 |
+
```python
|
| 244 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
| 245 |
+
total_update_steps = max_train_steps // gradient_accumulation_steps
|
| 246 |
+
warmup_steps = max(1, int(total_update_steps * warmup_ratio))
|
| 247 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_update_steps)
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
## The Training Loop
|
| 251 |
+
|
| 252 |
+
The GRPO training loop generates multiple completions per prompt, computes rewards, and updates the policy using a clipped objective similar to PPO:
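
Concretely, each prompt gets $G$ sampled completions; their rewards are normalized within the group to form advantages, and the policy is updated with a PPO-style clipped surrogate. A sketch of the objective implemented below (omitting the optional KL penalty):

$$
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G)},
\qquad
\rho_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_\mathrm{old}}(o_{i,t} \mid q, o_{i,<t})}
$$

$$
\mathcal{L}(\theta) = -\frac{1}{\sum_{i,t} m_{i,t}} \sum_{i,t} m_{i,t}\,
\min\!\Big(\rho_{i,t}\,\hat{A}_i,\; \operatorname{clip}(\rho_{i,t},\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_i\Big)
$$

where $m_{i,t}$ masks out prompt and padding tokens, matching `valid_mask` in the code.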
|
| 253 |
+
|
| 254 |
+
```python
|
| 255 |
+
# Sample dataset if needed
|
| 256 |
+
if max_train_samples is not None and len(train_dataset) > max_train_samples:
|
| 257 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
| 258 |
+
if max_eval_samples is not None and len(eval_dataset) > max_eval_samples:
|
| 259 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
| 260 |
+
|
| 261 |
+
model.train()
|
| 262 |
+
train_index = 0
|
| 263 |
+
global_step = 0
|
| 264 |
+
running_reward = 0.0
|
| 265 |
+
running_loss = 0.0
|
| 266 |
+
|
| 267 |
+
for step in range(1, max_train_steps + 1):
|
| 268 |
+
example = train_dataset[train_index % len(train_dataset)]
|
| 269 |
+
train_index += 1
|
| 270 |
+
|
| 271 |
+
prompt_ids, prompt_mask = prepare_prompt(example)
|
| 272 |
+
prompt_ids = prompt_ids.to(device)
|
| 273 |
+
prompt_mask = prompt_mask.to(device)
|
| 274 |
+
prompt_length = prompt_ids.shape[1]
|
| 275 |
+
|
| 276 |
+
prompt_repeat = prompt_ids.repeat(num_generations, 1)
|
| 277 |
+
mask_repeat = prompt_mask.repeat(num_generations, 1)
|
| 278 |
+
|
| 279 |
+
# Generate completions
|
| 280 |
+
model.eval()
|
| 281 |
+
with torch.no_grad():
|
| 282 |
+
generated = model.generate(
|
| 283 |
+
input_ids=prompt_repeat,
|
| 284 |
+
attention_mask=mask_repeat,
|
| 285 |
+
max_new_tokens=max_new_tokens,
|
| 286 |
+
do_sample=True,
|
| 287 |
+
temperature=temperature,
|
| 288 |
+
top_k=top_k,
|
| 289 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 290 |
+
)
|
| 291 |
+
model.train()
|
| 292 |
+
|
| 293 |
+
sequences = generated
|
| 294 |
+
attention_mask = (sequences != tokenizer.pad_token_id).long()
|
| 295 |
+
completion_mask = attention_mask.clone()
|
| 296 |
+
completion_mask[:, :prompt_length] = 0
|
| 297 |
+
|
| 298 |
+
completion_tokens = sequences[:, prompt_length:]
|
| 299 |
+
completion_texts = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
|
| 300 |
+
|
| 301 |
+
# Get solution
|
| 302 |
+
solution = example.get("solution", example.get("answer", ""))
|
| 303 |
+
solutions = [solution] * num_generations
|
| 304 |
+
|
| 305 |
+
# Compute rewards
|
| 306 |
+
rewards = combined_reward(completion_texts, solutions)
|
| 307 |
+
rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
|
| 308 |
+
running_reward += rewards.mean().item()
|
| 309 |
+
|
| 310 |
+
rewards_view = rewards.view(prompt_batch_size, num_generations)
|
| 311 |
+
mean_rewards = rewards_view.mean(dim=1, keepdim=True)
|
| 312 |
+
std_rewards = rewards_view.std(dim=1, keepdim=True)
|
| 313 |
+
std_rewards = torch.where(std_rewards > 0, std_rewards, torch.ones_like(std_rewards))
|
| 314 |
+
advantages = ((rewards_view - mean_rewards) / std_rewards).view(-1)
|
| 315 |
+
|
| 316 |
+
labels = sequences[:, 1:].clone()
|
| 317 |
+
labels[attention_mask[:, 1:] == 0] = tokenizer.pad_token_id
|
| 318 |
+
|
| 319 |
+
# Compute old log probs
|
| 320 |
+
with torch.no_grad():
|
| 321 |
+
with (autocast_ctx if device.type == "cuda" else nullcontext()):
|
| 322 |
+
old_outputs = model(
|
| 323 |
+
input_ids=sequences,
|
| 324 |
+
attention_mask=attention_mask,
|
| 325 |
+
use_cache=False,
|
| 326 |
+
)
|
| 327 |
+
old_log_probs = per_token_log_probs(old_outputs.logits[:, :-1], labels)
|
| 328 |
+
|
| 329 |
+
valid_mask = (completion_mask[:, 1:] == 1) & (labels != tokenizer.pad_token_id)
|
| 330 |
+
|
| 331 |
+
# Compute loss
|
| 332 |
+
optimizer.zero_grad(set_to_none=True)
|
| 333 |
+
with (autocast_ctx if device.type == "cuda" else nullcontext()):
|
| 334 |
+
outputs = model(
|
| 335 |
+
input_ids=sequences,
|
| 336 |
+
attention_mask=attention_mask,
|
| 337 |
+
use_cache=False,
|
| 338 |
+
)
|
| 339 |
+
log_probs = per_token_log_probs(outputs.logits[:, :-1], labels)
|
| 340 |
+
|
| 341 |
+
ratio = (log_probs - old_log_probs).exp()
|
| 342 |
+
ratio = torch.where(valid_mask, ratio, torch.ones_like(ratio))
|
| 343 |
+
clipped_ratio = ratio.clamp(1.0 - epsilon, 1.0 + epsilon)
|
| 344 |
+
|
| 345 |
+
adv = advantages.unsqueeze(1)
|
| 346 |
+
loss_unclipped = ratio * adv
|
| 347 |
+
loss_clipped = clipped_ratio * adv
|
| 348 |
+
per_token_loss = -torch.min(loss_unclipped, loss_clipped)
|
| 349 |
+
per_token_loss = torch.where(valid_mask, per_token_loss, torch.zeros_like(per_token_loss))
|
| 350 |
+
|
| 351 |
+
denom = valid_mask.sum().clamp(min=1)
|
| 352 |
+
loss = per_token_loss.sum() / denom
|
| 353 |
+
|
| 354 |
+
loss.backward()
|
| 355 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 356 |
+
optimizer.step()
|
| 357 |
+
scheduler.step()
|
| 358 |
+
|
| 359 |
+
global_step += 1
|
| 360 |
+
running_loss += loss.item()
|
| 361 |
+
|
| 362 |
+
if step % logging_frequency == 0:
|
| 363 |
+
avg_reward = running_reward / logging_frequency
|
| 364 |
+
avg_loss = running_loss / logging_frequency
|
| 365 |
+
current_lr = scheduler.get_last_lr()[0]
|
| 366 |
+
print(
|
| 367 |
+
f"step={step:04d} | loss={avg_loss:.4f} | avg_reward={avg_reward:.4f} | lr={current_lr:.2e}"
|
| 368 |
+
)
|
| 369 |
+
running_reward = 0.0
|
| 370 |
+
running_loss = 0.0
|
| 371 |
+
|
| 372 |
+
# Sample evaluation
|
| 373 |
+
model.eval()
|
| 374 |
+
eval_example = eval_dataset[0]
|
| 375 |
+
prompt_ids, prompt_mask = prepare_prompt(eval_example)
|
| 376 |
+
with torch.no_grad():
|
| 377 |
+
eval_sequences = model.generate(
|
| 378 |
+
input_ids=prompt_ids.to(device),
|
| 379 |
+
attention_mask=prompt_mask.to(device),
|
| 380 |
+
max_new_tokens=max_new_tokens,
|
| 381 |
+
do_sample=True,
|
| 382 |
+
top_k=top_k,
|
| 383 |
+
temperature=temperature,
|
| 384 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 385 |
+
)
|
| 386 |
+
model.train()
|
| 387 |
+
completion = eval_sequences[0, prompt_ids.shape[1] :]
|
| 388 |
+
print("Sample eval completion:", tokenizer.decode(completion, skip_special_tokens=True)[:100])
|
| 389 |
+
|
| 390 |
+
print("Training complete.")
|
| 391 |
+
```
|
| 392 |
+
|
| 393 |
+
```
|
| 394 |
+
step=0005 | loss=0.0000 | avg_reward=0.4000 | lr=0.00e+00
|
| 395 |
+
Sample eval completion: 3^4 - 11 and 3^6 - 17
|
| 396 |
+
step=0010 | loss=0.0000 | avg_reward=0.3333 | lr=0.00e+00
|
| 397 |
+
Sample eval completion: 11.
|
| 398 |
+
|
| 399 |
+
This statement refers to an optimization problem where we seek to find the smallest prime \( p
|
| 400 |
+
step=0015 | loss=0.0000 | avg_reward=0.4667 | lr=0.00e+00
|
| 401 |
+
Sample eval completion: What number has two prime factors, 1 and itself, without additional restrictions? One possible combi
|
| 402 |
+
step=0020 | loss=-0.0983 | avg_reward=0.4500 | lr=0.00e+00
|
| 403 |
+
...
|
| 404 |
+
Training complete.
|
| 405 |
+
```
|
| 406 |
+
|
app/src/content/chapters/inference.mdx
ADDED
|
@@ -0,0 +1,97 @@
|
| 1 |
+
## Inference on `nanochat` in Transformers
|
| 2 |
+
|
| 3 |
+
This first bonus tutorial shows how to do basic inference in `transformers`:
|
| 4 |
+
|
| 5 |
+
```py
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import AutoTokenizer, NanoChatForCausalLM
|
| 8 |
+
|
| 9 |
+
tokenizer = AutoTokenizer.from_pretrained("nanochat-students/nanochat-d20")
|
| 10 |
+
model = NanoChatForCausalLM.from_pretrained("nanochat-students/nanochat-d20")
|
| 11 |
+
|
| 12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
model = model.to(device)
|
| 14 |
+
|
| 15 |
+
prompt = "Hello, how are you?"
|
| 16 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 17 |
+
inputs.pop("token_type_ids", None)
|
| 18 |
+
outputs = model.generate(**inputs, max_new_tokens=100)
|
| 19 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### Inference in `transformers` with `vLLM`
|
| 23 |
+
|
| 24 |
+
Next, let's use `transformers` as a backend for `vLLM` to serve the model for optimized inference.
|
| 25 |
+
|
| 26 |
+
We'll need to install `transformers` from `main` (vLLM's transformers backend picks up the `nanochat` architecture from it):
|
| 27 |
+
|
| 28 |
+
```sh
|
| 29 |
+
pip install git+https://github.com/huggingface/transformers.git@main
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Then we can start a `vLLM` server like so:
|
| 33 |
+
|
| 34 |
+
```sh
|
| 35 |
+
vllm serve nanochat-students/nanochat-d20 --enforce-eager --revision refs/pr/1
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Finally, we can call the server like so:
|
| 39 |
+
|
| 40 |
+
```sh
|
| 41 |
+
curl -X POST "http://localhost:8000/v1/completions" \
|
| 42 |
+
-H "Content-Type: application/json" \
|
| 43 |
+
--data '{
|
| 44 |
+
"model": "nanochat-students/nanochat-d20",
|
| 45 |
+
"prompt": "Once upon a time,",
|
| 46 |
+
"max_tokens": 512,
|
| 47 |
+
"temperature": 0.5
|
| 48 |
+
}'
|
| 49 |
+
```
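
If you prefer Python over curl, the server speaks the OpenAI-compatible API, so a sketch like the following should also work (assuming the `openai` client is installed):

```py
from openai import OpenAI

# vLLM exposes an OpenAI-compatible API on localhost:8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.completions.create(
    model="nanochat-students/nanochat-d20",
    prompt="Once upon a time,",
    max_tokens=512,
    temperature=0.5,
)
print(response.choices[0].text)
```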
|
| 50 |
+
|
| 51 |
+
### Inference on your trained `nanochat` weights
|
| 52 |
+
|
| 53 |
+
Let's say you've followed the nanochat repo and used it to train a model. Then you can add transformers compatibility to your checkpoint and use it in other libraries.
|
| 54 |
+
|
| 55 |
+
1. Download any `nanochat` checkpoint from the Hub. Here we use Karpathy's, but this could be yours:
|
| 56 |
+
|
| 57 |
+
```sh
|
| 58 |
+
hf download karpathy/nanochat-d34 --local-dir nanochat-d34
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
2. convert the checkpoint to transformers format using the conversion scripts:
|
| 62 |
+
|
| 63 |
+
```sh
|
| 64 |
+
uv run \
|
| 65 |
+
--with "transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
| 66 |
+
--with "tiktoken>=0.12.0" \
|
| 67 |
+
https://raw.githubusercontent.com/huggingface/transformers/main/src/transformers/models/nanochat/convert_nanochat_checkpoints.py \
|
| 68 |
+
--input_dir ./nanochat-d34 \
|
| 69 |
+
  --output_dir ./nanochat-d34-hf
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
3. (optional) Upload the checkpoint to the Hugging Face Hub
|
| 73 |
+
|
| 74 |
+
```sh
|
| 75 |
+
hf upload <username>/nanochat-d34 ./nanochat-d34-hf
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
4. As above, you can generate with your model in `transformers`.
|
| 79 |
+
|
| 80 |
+
```py
|
| 81 |
+
import torch
|
| 82 |
+
from transformers import AutoTokenizer, NanoChatForCausalLM
|
| 83 |
+
|
| 84 |
+
tokenizer = AutoTokenizer.from_pretrained("./nanochat-d34-hf")
|
| 85 |
+
model = NanoChatForCausalLM.from_pretrained("./nanochat-d34-hf")
|
| 86 |
+
|
| 87 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 88 |
+
model = model.to(device)
|
| 89 |
+
|
| 90 |
+
prompt = "Hello, how are you?"
|
| 91 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 92 |
+
inputs.pop("token_type_ids", None)
|
| 93 |
+
outputs = model.generate(**inputs, max_new_tokens=100)
|
| 94 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
|
app/src/content/chapters/sft.mdx
ADDED
|
@@ -0,0 +1,400 @@
|
| 1 |
+
import Sidenote from '../../components/Sidenote.astro'
|
| 2 |
+
import Note from '../../components/Note.astro'
|
| 3 |
+
|
| 4 |
+
# [BONUS 2] Supervised Fine-tuning in `torch`
|
| 5 |
+
|
| 6 |
+
<Sidenote>
|
| 7 |
+
|
| 8 |
+
[](https://colab.research.google.com/#fileId=https%3A//huggingface.co/datasets/nanochat-students/notebooks/blob/main/sft.ipynb)
|
| 9 |
+
|
| 10 |
+
</Sidenote>
|
| 11 |
+
|
| 12 |
+
Supervised Fine-Tuning (SFT) is the process of adapting a pre-trained language model to follow instructions by training it on curated input-output pairs. Unlike pre-training which learns general language patterns from massive text corpora, SFT teaches the model *how* to respond—following a specific format, tone, or task structure.
|
| 13 |
+
|
| 14 |
+
In this tutorial, we'll fine-tune the NanoChat model using pure PyTorch, giving you complete visibility into every step of the training process.
|
| 15 |
+
|
| 16 |
+
<Note>
|
| 17 |
+
|
| 18 |
+
**Want a production-ready solution?** TRL is Hugging Face's reinforcement learning library with battle-tested SFT implementations. Check out the [SFT notebook](https://github.com/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb) to use it with your nanochat checkpoint.
|
| 19 |
+
|
| 20 |
+
</Note>
|
| 21 |
+
|
| 22 |
+
## Import model and tokenizer
|
| 23 |
+
|
| 24 |
+
<Sidenote>
|
| 25 |
+
|
| 26 |
+
Learn more about [AutoModelForCausalLM](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) and the [from_pretrained](https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.from_pretrained) method.
|
| 27 |
+
|
| 28 |
+
</Sidenote>
|
| 29 |
+
|
| 30 |
+
We start by loading the pre-trained NanoChat model and its tokenizer. The `revision` parameter points to a specific model version—useful when models are updated frequently or you want reproducible results.
|
| 31 |
+
|
| 32 |
+
```python
|
| 33 |
+
import torch
|
| 34 |
+
from torch.utils.data import DataLoader
|
| 35 |
+
from datasets import load_dataset
|
| 36 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
model_id = "karpathy/nanochat-d32"
|
| 40 |
+
revision = "refs/pr/1"
|
| 41 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
|
| 45 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 46 |
+
model_id,
|
| 47 |
+
revision=revision,
|
| 48 |
+
torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
|
| 49 |
+
).to(device)
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
We use `bfloat16` precision on GPU to reduce memory usage while maintaining training stability. On CPU, we fall back to `float32` for compatibility.
|
| 53 |
+
|
| 54 |
+
## Setup LoRA
|
| 55 |
+
|
| 56 |
+
<Sidenote>
|
| 57 |
+
|
| 58 |
+
Read the [LoRA paper](https://arxiv.org/abs/2106.09685) or explore [PEFT documentation](https://huggingface.co/docs/peft) for a deeper understanding of low-rank adaptation.
|
| 59 |
+
|
| 60 |
+
</Sidenote>
|
| 61 |
+
|
| 62 |
+
Training all 1.8B parameters would require significant GPU memory and risk catastrophic forgetting. Instead, we use **LoRA (Low-Rank Adaptation)** which freezes the original weights and injects small trainable matrices into specific layers.
|
| 63 |
+
|
| 64 |
+
The key parameters:
|
| 65 |
+
- **`r=1`**: The rank of the low-rank matrices. Lower = fewer parameters, but potentially less expressiveness
|
| 66 |
+
- **`lora_alpha=2`**: Scaling factor for LoRA updates (typically `2 * r`)
|
| 67 |
+
- **`target_modules`**: Which layers to adapt—we target all attention projections and the MLP
|
| 68 |
+
|
| 69 |
+
```python
|
| 70 |
+
from peft import LoraConfig, get_peft_model
|
| 71 |
+
|
| 72 |
+
lora_config = LoraConfig(
|
| 73 |
+
r=1,
|
| 74 |
+
lora_alpha=2,
|
| 75 |
+
lora_dropout=0.00,
|
| 76 |
+
task_type="CAUSAL_LM",
|
| 77 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
model = get_peft_model(model, lora_config)
|
| 81 |
+
model.print_trainable_parameters()
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
With LoRA, we're only training **0.06%** of the model's parameters—just over 1 million weights instead of 1.8 billion. This makes fine-tuning feasible on consumer hardware.
|
| 89 |
+
|
| 90 |
+
## Demo the model
|
| 91 |
+
|
| 92 |
+
<Sidenote>
|
| 93 |
+
|
| 94 |
+
The [generate](https://huggingface.co/docs/transformers/main_classes/text_generation) method supports many decoding strategies: greedy, beam search, sampling, and more.
|
| 95 |
+
|
| 96 |
+
</Sidenote>
|
| 97 |
+
|
| 98 |
+
Before training, let's verify the model works correctly. We'll test two modes: raw text completion and chat-formatted generation.
|
| 99 |
+
|
| 100 |
+
**Plain autoregressive completion** continues text naturally:
|
| 101 |
+
|
| 102 |
+
```python
|
| 103 |
+
print("=" * 80)
|
| 104 |
+
print("TEST 1: Plain Autoregressive Prompt")
|
| 105 |
+
print("=" * 80)
|
| 106 |
+
prompt = "The Eiffel Tower stands in Paris and"
|
| 107 |
+
test_inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
test_outputs = model.generate(
|
| 112 |
+
**test_inputs,
|
| 113 |
+
max_new_tokens=64,
|
| 114 |
+
do_sample=False,
|
| 115 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
generated_tokens = test_outputs[0, test_inputs["input_ids"].shape[1] :]
|
| 119 |
+
print(f"Prompt: {prompt}")
|
| 120 |
+
print(f"\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}")
|
| 121 |
+
print("=" * 80)
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
```
|
| 125 |
+
================================================================================
|
| 126 |
+
TEST 1: Plain Autoregressive Prompt
|
| 127 |
+
================================================================================
|
| 128 |
+
Prompt: The Eiffel Tower stands in Paris and
|
| 129 |
+
|
| 130 |
+
Generated: is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters
|
| 131 |
+
================================================================================
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
<Sidenote>
|
| 135 |
+
|
| 136 |
+
Chat templates ensure consistent formatting. See [chat templating guide](https://huggingface.co/docs/transformers/en/chat_templating) for details on how different models structure conversations.
|
| 137 |
+
|
| 138 |
+
</Sidenote>
|
| 139 |
+
|
| 140 |
+
The chat template wraps the input in special tokens that the model learned during instruction tuning:
|
| 141 |
+
|
| 142 |
+
```python
|
| 143 |
+
print("=" * 80)
|
| 144 |
+
print("TEST 2: Chat Template")
|
| 145 |
+
print("="*80)
|
| 146 |
+
conversation = [
|
| 147 |
+
{"role": "user", "content": "What is the capital of France?"},
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
inputs = tokenizer.apply_chat_template(
|
| 151 |
+
conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
| 152 |
+
).to(device)
|
| 153 |
+
|
| 154 |
+
print(f"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}")
|
| 155 |
+
print(f"Input IDs: {inputs['input_ids'][0].tolist()}")
|
| 156 |
+
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
outputs = model.generate(
|
| 159 |
+
**inputs,
|
| 160 |
+
max_new_tokens=64,
|
| 161 |
+
do_sample=False
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
|
| 165 |
+
print(f"\nGenerated: {tokenizer.decode(generated_tokens)}")
|
| 166 |
+
print("=" * 80)
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
```
|
| 170 |
+
================================================================================
|
| 171 |
+
TEST 2: Chat Template
|
| 172 |
+
================================================================================
|
| 173 |
+
Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>
|
| 174 |
+
Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]
|
| 175 |
+
|
| 176 |
+
Generated: The capital of France is Paris.<|assistant_end|>
|
| 177 |
+
================================================================================
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
Notice the special tokens: `<|bos|>`, `<|user_start|>`, `<|assistant_start|>`, etc. These delimiters help the model understand conversation structure.
|
| 181 |
+
|
| 182 |
+
## Dataset
|
| 183 |
+
|
| 184 |
+
<Sidenote>
|
| 185 |
+
|
| 186 |
+
Explore the [smoltalk2 dataset](https://huggingface.co/datasets/HuggingFaceTB/smoltalk2) on the Hub. Its OpenThoughts3 split contains instruction-following examples with chain-of-thought reasoning.
|
| 187 |
+
|
| 188 |
+
</Sidenote>
|
| 189 |
+
|
| 190 |
+
For SFT, we need high-quality instruction-response pairs. We'll use OpenThoughts, a dataset designed for training models to reason step-by-step before answering.
|
| 191 |
+
|
| 192 |
+
```python
|
| 193 |
+
raw_dataset = load_dataset("HuggingFaceTB/smoltalk2", "SFT", split="OpenThoughts3_1.2M_think")
|
| 194 |
+
splits = raw_dataset.train_test_split(test_size=0.1, seed=42)
|
| 195 |
+
train_dataset = splits["train"]
|
| 196 |
+
eval_dataset = splits["test"]
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
### Process the Dataset
|
| 200 |
+
|
| 201 |
+
<Sidenote>
|
| 202 |
+
|
| 203 |
+
The [datasets map](https://huggingface.co/docs/datasets/process#map) function applies transformations efficiently with caching and multiprocessing support.
|
| 204 |
+
|
| 205 |
+
</Sidenote>
|
| 206 |
+
|
| 207 |
+
Raw examples contain message lists that need to be converted into token sequences. The `apply_chat_template` method handles this conversion, inserting the appropriate special tokens.
|
| 208 |
+
|
| 209 |
+
We limit examples to 2048 tokens and cap the dataset size to make training tractable on limited hardware:
|
| 210 |
+
|
| 211 |
+
```python
|
| 212 |
+
max_length = 2048
|
| 213 |
+
max_train_examples = 20000
|
| 214 |
+
max_eval_examples = 1000
|
| 215 |
+
|
| 216 |
+
def format_example(example):
|
| 217 |
+
formatted = tokenizer.apply_chat_template(
|
| 218 |
+
example["messages"],
|
| 219 |
+
add_generation_prompt=False,
|
| 220 |
+
truncation=True,
|
| 221 |
+
max_length=max_length,
|
| 222 |
+
padding=False,
|
| 223 |
+
return_dict=True,
|
| 224 |
+
return_tensors="pt",
|
| 225 |
+
)
|
| 226 |
+
return {
|
| 227 |
+
"input_ids": formatted["input_ids"][0].tolist(),
|
| 228 |
+
"attention_mask": formatted["attention_mask"][0].tolist(),
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
train_dataset = train_dataset.select(range(min(len(train_dataset), max_train_examples)))
|
| 233 |
+
train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
|
| 234 |
+
|
| 235 |
+
eval_dataset = eval_dataset.select(range(min(len(eval_dataset), max_eval_examples)))
|
| 236 |
+
eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
## Training Configuration
|
| 240 |
+
|
| 241 |
+
These hyperparameters control the training dynamics. We use conservative values that work well across different hardware:
|
| 242 |
+
|
| 243 |
+
```python
|
| 244 |
+
train_batch_size = 2
|
| 245 |
+
eval_batch_size = 2
|
| 246 |
+
num_epochs = 1
|
| 247 |
+
gradient_accumulation_steps = 4
|
| 248 |
+
learning_rate = 1e-5
|
| 249 |
+
weight_decay = 0.0
|
| 250 |
+
warmup_ratio = 0.03
|
| 251 |
+
logging_frequency = 10
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
<Sidenote>
|
| 255 |
+
|
| 256 |
+
**Gradient accumulation** simulates larger batch sizes by accumulating gradients over multiple forward passes before updating weights. Effective batch size = `train_batch_size × gradient_accumulation_steps` = 8.
|
| 257 |
+
|
| 258 |
+
</Sidenote>
|
| 259 |
+
|
| 260 |
+
Key configuration choices include using a low learning rate (`1e-5`), as LoRA generally requires smaller learning rates given that the base model weights are kept frozen. Additionally, gradient accumulation is employed to enable larger effective batch sizes, which helps when training on GPUs with limited memory.
|
| 261 |
+
|
| 262 |
+
## Create a DataLoader
|
| 263 |
+
|
| 264 |
+
<Sidenote>
|
| 265 |
+
|
| 266 |
+
PyTorch's [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) handles batching, shuffling, and parallel data loading automatically.
|
| 267 |
+
|
| 268 |
+
</Sidenote>
|
| 269 |
+
|
| 270 |
+
The collate function pads variable-length sequences to the same length within each batch and creates the labels tensor for loss computation:
|
| 271 |
+
|
| 272 |
+
```python
|
| 273 |
+
def collate_fn(batch):
|
| 274 |
+
batch_dict = {
|
| 275 |
+
"input_ids": [record["input_ids"] for record in batch],
|
| 276 |
+
"attention_mask": [record["attention_mask"] for record in batch],
|
| 277 |
+
}
|
| 278 |
+
padded = tokenizer.pad(batch_dict, padding=True, return_tensors="pt")
|
| 279 |
+
labels = padded["input_ids"].clone()
|
| 280 |
+
labels[padded["attention_mask"] == 0] = -100
|
| 281 |
+
padded["labels"] = labels
|
| 282 |
+
return padded
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)
|
| 286 |
+
eval_loader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn)
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
Setting padding tokens to `-100` in labels tells PyTorch's cross-entropy loss to ignore them—we don't want to penalize the model for not predicting padding.
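
A tiny standalone check of that mechanism (illustrative numbers only):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 8)                # (batch, seq, vocab)
labels = torch.tensor([[3, 5, -100, -100]])  # last two positions are padding

# Positions labelled -100 contribute nothing to the loss.
loss = F.cross_entropy(logits.view(-1, 8), labels.view(-1), ignore_index=-100)
print(loss)
```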
|
| 290 |
+
|
| 291 |
+
## Optimizer
|
| 292 |
+
|
| 293 |
+
<Sidenote>
|
| 294 |
+
|
| 295 |
+
[AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html) decouples weight decay from the gradient update, improving regularization behavior compared to L2 regularization in standard Adam.
|
| 296 |
+
|
| 297 |
+
</Sidenote>
|
| 298 |
+
|
| 299 |
+
AdamW is the standard optimizer for transformer fine-tuning. It combines Adam's adaptive learning rates with proper weight decay:
|
| 300 |
+
|
| 301 |
+
```python
|
| 302 |
+
optimizer = torch.optim.AdamW(
|
| 303 |
+
model.parameters(),
|
| 304 |
+
lr=learning_rate,
|
| 305 |
+
weight_decay=weight_decay,
|
| 306 |
+
)
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
## Learning Rate Scheduler
|
| 310 |
+
|
| 311 |
+
<Sidenote>
|
| 312 |
+
|
| 313 |
+
Warmup prevents early instability when the model hasn't yet adapted to the new task. See [this explanation](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules#transformers.get_linear_schedule_with_warmup) for more details.
|
| 314 |
+
|
| 315 |
+
</Sidenote>
|
| 316 |
+
|
| 317 |
+
A linear schedule with warmup gradually increases the learning rate at the start of training (warmup), then linearly decreases it to zero. This helps stabilize early training and improves final performance:
|
| 318 |
+
|
| 319 |
+
```python
|
| 320 |
+
num_update_steps_per_epoch = max(len(train_loader) // gradient_accumulation_steps, 1)
|
| 321 |
+
max_train_steps = num_epochs * num_update_steps_per_epoch
|
| 322 |
+
warmup_steps = max(1, int(max_train_steps * warmup_ratio))
|
| 323 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, max_train_steps)
|
| 324 |
+
```
|
| 325 |
+
|
| 326 |
+
## The Training Loop
|
| 327 |
+
|
| 328 |
+
<Sidenote>
|
| 329 |
+
|
| 330 |
+
For distributed training across multiple GPUs, consider [Accelerate](https://huggingface.co/docs/accelerate/index) which wraps this loop with minimal code changes.
|
| 331 |
+
|
| 332 |
+
</Sidenote>
|
| 333 |
+
|
| 334 |
+
Now we bring everything together. The training loop follows the standard PyTorch pattern with gradient accumulation:
|
| 335 |
+
|
| 336 |
+
1. **Forward pass**: Compute loss on a mini-batch
|
| 337 |
+
2. **Backward pass**: Accumulate gradients
|
| 338 |
+
3. **Optimizer step**: Update weights (every `gradient_accumulation_steps` batches)
|
| 339 |
+
4. **Logging**: Track loss and learning rate
|
| 340 |
+
5. **Evaluation**: Measure validation loss after each epoch
|
| 341 |
+
|
| 342 |
+
```python
|
| 343 |
+
model.train()
|
| 344 |
+
global_step = 0
|
| 345 |
+
running_loss = 0.0
|
| 346 |
+
running_steps = 0
|
| 347 |
+
|
| 348 |
+
for epoch in range(num_epochs):
|
| 349 |
+
print(f"Epoch {epoch + 1}/{num_epochs}")
|
| 350 |
+
optimizer.zero_grad(set_to_none=True)
|
| 351 |
+
    for step, batch in enumerate(train_loader, start=1):
|
| 352 |
+
batch = {key: value.to(device) for key, value in batch.items()}
|
| 353 |
+
outputs = model(**batch)
|
| 354 |
+
loss = outputs.loss / gradient_accumulation_steps
|
| 355 |
+
loss.backward()
|
| 356 |
+
|
| 357 |
+
running_loss += outputs.loss.float().item()
|
| 358 |
+
running_steps += 1
|
| 359 |
+
|
| 360 |
+
        if step % gradient_accumulation_steps == 0 or step == len(train_loader):
|
| 361 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 362 |
+
optimizer.step()
|
| 363 |
+
scheduler.step()
|
| 364 |
+
optimizer.zero_grad(set_to_none=True)
|
| 365 |
+
global_step += 1
|
| 366 |
+
|
| 367 |
+
if global_step % logging_frequency == 0:
|
| 368 |
+
current_lr = scheduler.get_last_lr()[0]
|
| 369 |
+
mean_loss = running_loss / running_steps
|
| 370 |
+
print(f"step={global_step:05d} | loss={mean_loss:.4f} | lr={current_lr:.2e}")
|
| 371 |
+
running_loss = 0.0
|
| 372 |
+
running_steps = 0
|
| 373 |
+
|
| 374 |
+
train_loss = running_loss / running_steps if running_steps > 0 else float("nan")
|
| 375 |
+
print(f"Training loss after epoch {epoch + 1}: {train_loss:.4f}")
|
| 376 |
+
|
| 377 |
+
model.eval()
|
| 378 |
+
losses = []
|
| 379 |
+
with torch.no_grad():
|
| 380 |
+
        for _, batch in enumerate(eval_loader, start=1):
|
| 381 |
+
batch = {key: value.to(device) for key, value in batch.items()}
|
| 382 |
+
loss = model(**batch).loss
|
| 383 |
+
losses.append(loss.float().item())
|
| 384 |
+
model.train()
|
| 385 |
+
val_loss = sum(losses) / len(losses) if losses else float("nan")
|
| 386 |
+
|
| 387 |
+
print(f"Validation loss after epoch {epoch + 1}: {val_loss:.4f}")
|
| 388 |
+
|
| 389 |
+
print("Training complete.")
|
| 390 |
+
```
|
| 391 |
+
|
| 392 |
+
```
|
| 393 |
+
Epoch 1/1
|
| 394 |
+
step=00010 | loss=1.7586 | lr=1.33e-06
|
| 395 |
+
step=00020 | loss=1.8188 | lr=2.67e-06
|
| 396 |
+
step=00030 | loss=1.8235 | lr=4.00e-06
|
| 397 |
+
step=00040 | loss=1.7935 | lr=5.33e-06
|
| 398 |
+
step=00050 | loss=1.8029 | lr=6.67e-06
|
| 399 |
+
...
|
| 400 |
+
```
|
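One optional last step, not shown above: persisting what we just trained. Since the model was wrapped with `get_peft_model`, saving stores only the small adapter weights; the directory name below is just a placeholder, a minimal sketch using PEFT's standard save/load API:

```python
adapter_dir = "nanochat-d32-sft-lora"   # placeholder output path
model.save_pretrained(adapter_dir)      # writes only the LoRA adapter weights
tokenizer.save_pretrained(adapter_dir)

# To reuse it later, attach the adapter to a freshly loaded base model:
# from peft import PeftModel
# base = AutoModelForCausalLM.from_pretrained(model_id, revision=revision)
# model = PeftModel.from_pretrained(base, adapter_dir)
```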
grpo.ipynb ADDED
@@ -0,0 +1,654 @@
# NanoChat Easy - GRPO Training

## Import model and tokenizer

```python
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup


model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
).to(device)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
```

```
/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
```

## Setup LoRA

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=1,
    lora_alpha=2,
    lora_dropout=0.00,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

```
trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627
```

## Demo the model

```python
print("=" * 80)
print("TEST 1: Plain Autoregressive Prompt")
print("=" * 80)
prompt = "The Eiffel Tower stands in Paris and"
test_inputs = tokenizer(prompt, return_tensors="pt").to(device)


with torch.no_grad():
    test_outputs = model.generate(
        **test_inputs,
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

generated_tokens = test_outputs[0, test_inputs["input_ids"].shape[1] :]
print(f"Prompt: {prompt}")
print(f"\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}")
print("=" * 80)
```

```
TEST 1: Plain Autoregressive Prompt
Prompt: The Eiffel Tower stands in Paris and

Generated: is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters
```

```python
print("=" * 80)
print("TEST 2: Chat Template")
print("="*80)
conversation = [
    {"role": "user", "content": "What is the capital of France?"},
]

inputs = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(device)

print(f"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}")
print(f"Input IDs: {inputs['input_ids'][0].tolist()}")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False
    )

generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
print(f"\nGenerated: {tokenizer.decode(generated_tokens)}")
print("=" * 80)
```

```
TEST 2: Chat Template
Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>
Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]

Generated: The capital of France is Paris.<|assistant_end|>
```

## Dataset

```python
raw_dataset = load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train")
splits = raw_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = splits["train"]
eval_dataset = splits["test"]
```

```
Generating train split: 100%|██████████| 52736/52736 [00:00<00:00, 1058243.18 examples/s]
```

## Training Configuration

```python
max_train_steps = 50
prompt_batch_size = 1
num_generations = 4
max_new_tokens = 128
temperature = 1.0
top_k = 50
learning_rate = 5e-6
weight_decay = 0.0
epsilon = 0.2
gradient_accumulation_steps = 1
warmup_ratio = 0.1
logging_frequency = 5
max_train_samples = 1000
max_eval_samples = 100
```

## Reward Functions

```python
import re
import numpy as np
import torch.nn.functional as F
from contextlib import nullcontext


def think_format_reward(completions):
    """
    Reward function that checks if the reasoning process is enclosed within <think> and </think> tags.
    Returns 1.0 if the format is correct, otherwise 0.0.
    """
    pattern = r"^(?!.*<think>)(.*?)</think>.*$"
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
    return [1.0 if match else 0.0 for match in matches]


def accuracy_reward(completions, solutions):
    """
    Reward function that checks if the completion matches the solution.
    For simplicity, we'll do basic string matching here.
    """
    rewards = []
    for completion, solution in zip(completions, solutions):
        # Simple string matching (normalized)
        reward = 1.0 if solution.strip().lower() in completion.strip().lower() else 0.0
        rewards.append(reward)
    return rewards


def min_length_reward(completions, min_length=10):
    """
    Reward function that checks if the completion is at least a certain length.
    Returns 1.0 if the length is greater than or equal to the minimum length, otherwise 0.0.
    """
    return [1.0 if len(completion) >= min_length else 0.0 for completion in completions]


def combined_reward(completions, solutions):
    """
    Combines format and accuracy rewards with equal weight.
    """
    format_rewards = think_format_reward(completions)
    accuracy_rewards = accuracy_reward(completions, solutions)
    min_length_rewards = min_length_reward(completions)
    return [np.mean([f, a, m]) for f, a, m in zip(format_rewards, accuracy_rewards, min_length_rewards)]
```

## Helper Functions

```python
def per_token_log_probs(logits, labels):
    logits = logits.float()
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


def prepare_prompt(example, problem_key="problem", solution_key="solution"):
    # Extract the messages (should be a list of dicts with 'role' and 'content')
    prompt = example.get(problem_key, "")
    messages = [{"role": "user", "content": prompt}]

    formatted = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        truncation=True,
        max_length=2048,
        padding=False,
        return_dict=True,
        return_tensors="pt",
    )
    return formatted["input_ids"], formatted["attention_mask"]


if device.type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
else:
    autocast_ctx = nullcontext()
```

## Optimizer and Scheduler

```python
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
total_update_steps = max_train_steps // gradient_accumulation_steps
warmup_steps = max(1, int(total_update_steps * warmup_ratio))
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_update_steps)
```

# The Training Loop

```python
# Sample dataset if needed
if max_train_samples is not None and len(train_dataset) > max_train_samples:
    train_dataset = train_dataset.select(range(max_train_samples))
if max_eval_samples is not None and len(eval_dataset) > max_eval_samples:
    eval_dataset = eval_dataset.select(range(max_eval_samples))

model.train()
train_index = 0
global_step = 0
running_reward = 0.0
running_loss = 0.0

for step in range(1, max_train_steps + 1):
    example = train_dataset[train_index % len(train_dataset)]
    train_index += 1

    prompt_ids, prompt_mask = prepare_prompt(example)
    prompt_ids = prompt_ids.to(device)
    prompt_mask = prompt_mask.to(device)
    prompt_length = prompt_ids.shape[1]

    prompt_repeat = prompt_ids.repeat(num_generations, 1)
    mask_repeat = prompt_mask.repeat(num_generations, 1)

    # Generate completions
    model.eval()
    with torch.no_grad():
        generated = model.generate(
            input_ids=prompt_repeat,
            attention_mask=mask_repeat,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
        )
    model.train()

    sequences = generated
    attention_mask = (sequences != tokenizer.pad_token_id).long()
    completion_mask = attention_mask.clone()
    completion_mask[:, :prompt_length] = 0

    completion_tokens = sequences[:, prompt_length:]
    completion_texts = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)

    # Get solution
    solution = example.get("solution", example.get("answer", ""))
    solutions = [solution] * num_generations

    # Compute rewards
    rewards = combined_reward(completion_texts, solutions)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
    running_reward += rewards.mean().item()

    rewards_view = rewards.view(prompt_batch_size, num_generations)
    mean_rewards = rewards_view.mean(dim=1, keepdim=True)
    std_rewards = rewards_view.std(dim=1, keepdim=True)
    std_rewards = torch.where(std_rewards > 0, std_rewards, torch.ones_like(std_rewards))
    advantages = ((rewards_view - mean_rewards) / std_rewards).view(-1)

    labels = sequences[:, 1:].clone()
    labels[attention_mask[:, 1:] == 0] = tokenizer.pad_token_id

    # Compute old log probs
    with torch.no_grad():
        with (autocast_ctx if device.type == "cuda" else nullcontext()):
            old_outputs = model(
                input_ids=sequences,
                attention_mask=attention_mask,
                use_cache=False,
            )
        old_log_probs = per_token_log_probs(old_outputs.logits[:, :-1], labels)

    valid_mask = (completion_mask[:, 1:] == 1) & (labels != tokenizer.pad_token_id)

    # Compute loss
    optimizer.zero_grad(set_to_none=True)
    with (autocast_ctx if device.type == "cuda" else nullcontext()):
        outputs = model(
            input_ids=sequences,
            attention_mask=attention_mask,
            use_cache=False,
        )
    log_probs = per_token_log_probs(outputs.logits[:, :-1], labels)

    ratio = (log_probs - old_log_probs).exp()
    ratio = torch.where(valid_mask, ratio, torch.ones_like(ratio))
    clipped_ratio = ratio.clamp(1.0 - epsilon, 1.0 + epsilon)

    adv = advantages.unsqueeze(1)
    loss_unclipped = ratio * adv
    loss_clipped = clipped_ratio * adv
    per_token_loss = -torch.min(loss_unclipped, loss_clipped)
    per_token_loss = torch.where(valid_mask, per_token_loss, torch.zeros_like(per_token_loss))

    denom = valid_mask.sum().clamp(min=1)
    loss = per_token_loss.sum() / denom

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    global_step += 1
    running_loss += loss.item()

    if step % logging_frequency == 0:
        avg_reward = running_reward / logging_frequency
        avg_loss = running_loss / logging_frequency
        current_lr = scheduler.get_last_lr()[0]
        print(
            f"step={step:04d} | loss={avg_loss:.4f} | avg_reward={avg_reward:.4f} | lr={current_lr:.2e}"
        )
        running_reward = 0.0
        running_loss = 0.0

        # Sample evaluation
        model.eval()
        eval_example = eval_dataset[0]
        prompt_ids, prompt_mask = prepare_prompt(eval_example)
        with torch.no_grad():
            eval_sequences = model.generate(
                input_ids=prompt_ids.to(device),
                attention_mask=prompt_mask.to(device),
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_k=top_k,
                temperature=temperature,
                pad_token_id=tokenizer.pad_token_id,
            )
        model.train()
        completion = eval_sequences[0, prompt_ids.shape[1] :]
        print("Sample eval completion:", tokenizer.decode(completion, skip_special_tokens=True)[:100])

print("Training complete.")
```

```
step=0005 | loss=0.0000 | avg_reward=0.4000 | lr=0.00e+00
Sample eval completion: 3^4 - 11 and 3^6 - 17
step=0010 | loss=0.0000 | avg_reward=0.3333 | lr=0.00e+00
Sample eval completion: 11.

This statement refers to an optimization problem where we seek to find the smallest prime \( p
step=0015 | loss=0.0000 | avg_reward=0.4667 | lr=0.00e+00
Sample eval completion: What number has two prime factors, 1 and itself, without additional restrictions? One possible combi
step=0020 | loss=-0.0983 | avg_reward=0.4500 | lr=0.00e+00
Sample eval completion: \[\begin{bmatrix} 2 & 3\\ 6 & 11\end{bmatrix} \]\[3^{a}-2^{b}\left(\frac{1^{a}}{a}\right) \left(\fra
step=0025 | loss=-0.0979 | avg_reward=0.3333 | lr=0.00e+00
Sample eval completion: Let's examine the smallest prime \( p \) for which there do not exist non-negative integers \( a, b
step=0030 | loss=-0.0000 | avg_reward=0.3667 | lr=0.00e+00
Sample eval completion:
Since \( p = 23^2 + 7 \) or \( p \ge 23^3 + 63 \), and \( p > 23 \), we find that \( p \ge 9223 \).
step=0035 | loss=0.0431 | avg_reward=0.4167 | lr=0.00e+00
Sample eval completion: \[11 \] = \((3^5)\), for all \( a, b \).
[asy]
import random;
import numpy as np;

unitsize(1cm);

d
step=0040 | loss=-0.0702 | avg_reward=0.5000 | lr=0.00e+00
Sample eval completion: 3^4 - 7
step=0045 | loss=0.0000 | avg_reward=0.3333 | lr=0.00e+00
Sample eval completion: 7.
step=0050 | loss=0.0000 | avg_reward=0.4000 | lr=0.00e+00
Sample eval completion: Here is the answer:

The smallest prime \( p \) (where \( p > 3 \)) for which there do not exist non
Training complete.
```

_Notebook metadata: Python 3.10.18, `.venv` kernel, Colab provenance._
sft.ipynb ADDED
@@ -0,0 +1,591 @@
# NanoChat Easy - SFT Training

## Import model and tokenizer

```python
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup


model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
).to(device)
```

```
/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
```

## Setup LoRA

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=1,
    lora_alpha=2,
    lora_dropout=0.00,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

```
trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627
```

## Demo the model

```python
print("=" * 80)
print("TEST 1: Plain Autoregressive Prompt")
print("=" * 80)
prompt = "The Eiffel Tower stands in Paris and"
test_inputs = tokenizer(prompt, return_tensors="pt").to(device)


with torch.no_grad():
    test_outputs = model.generate(
        **test_inputs,
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

generated_tokens = test_outputs[0, test_inputs["input_ids"].shape[1] :]
print(f"Prompt: {prompt}")
print(f"\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}")
print("=" * 80)
```

```
TEST 1: Plain Autoregressive Prompt
Prompt: The Eiffel Tower stands in Paris and

Generated: is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters
```

```python
print("=" * 80)
print("TEST 2: Chat Template")
print("="*80)
conversation = [
    {"role": "user", "content": "What is the capital of France?"},
]

inputs = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(device)

print(f"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}")
print(f"Input IDs: {inputs['input_ids'][0].tolist()}")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False
    )

generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
print(f"\nGenerated: {tokenizer.decode(generated_tokens)}")
print("=" * 80)
```

```
TEST 2: Chat Template
Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>
Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]

Generated: The capital of France is Paris.<|assistant_end|>
```

## Dataset

```python
raw_dataset = load_dataset("HuggingFaceTB/smoltalk2", "SFT", split="OpenThoughts3_1.2M_think")
splits = raw_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = splits["train"]
eval_dataset = splits["test"]
```

### Process the Dataset

```python
max_length = 2048
max_train_examples = 20000
max_eval_examples = 1000

def format_example(example):
    formatted = tokenizer.apply_chat_template(
        example["messages"],
        add_generation_prompt=False,
        truncation=True,
        max_length=max_length,
        padding=False,
        return_dict=True,
        return_tensors="pt",
    )
    return {
        "input_ids": formatted["input_ids"][0].tolist(),
        "attention_mask": formatted["attention_mask"][0].tolist(),
    }


if max_train_examples is not None:
    train_dataset = train_dataset.select(range(min(len(train_dataset), max_train_examples)))
    train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
else:
    train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)

if max_eval_examples is not None:
    eval_dataset = eval_dataset.select(range(min(len(eval_dataset), max_eval_examples)))
    eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)
else:
    eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)
```

```
Map: 100%|██████████| 20000/20000 [06:27<00:00, 51.68 examples/s]
Map: 100%|██████████| 1000/1000 [00:19<00:00, 52.12 examples/s]
```

## Training Configuration

```python
train_batch_size = 2
eval_batch_size = 2
num_epochs = 1
gradient_accumulation_steps = 4
learning_rate = 1e-5
weight_decay = 0.0
warmup_ratio = 0.03
logging_frequency = 10
```

## Create a `DataLoader` 👴

```python
def collate_fn(batch):
    batch_dict = {
        "input_ids": [record["input_ids"] for record in batch],
        "attention_mask": [record["attention_mask"] for record in batch],
    }
    padded = tokenizer.pad(batch_dict, padding=True, return_tensors="pt")
    labels = padded["input_ids"].clone()
    labels[padded["attention_mask"] == 0] = -100
    padded["labels"] = labels
    return padded


TrainLoader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)
EvalLoader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn)
```

## Optimizer

```python
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
```

# Learning Rate Scheduler

```python
num_update_steps_per_epoch = max(len(TrainLoader) // gradient_accumulation_steps, 1)
max_train_steps = num_epochs * num_update_steps_per_epoch
warmup_steps = max(1, int(max_train_steps * warmup_ratio))
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, max_train_steps)
```

# The Training Loop

```
You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Epoch 1/1
step=00010 | loss=1.7586 | lr=1.33e-06
step=00020 | loss=1.8188 | lr=2.67e-06
step=00030 | loss=1.8235 | lr=4.00e-06
step=00040 | loss=1.7935 | lr=5.33e-06
step=00050 | loss=1.8029 | lr=6.67e-06
step=00060 | loss=1.8433 | lr=8.00e-06
step=00070 | loss=1.8616 | lr=9.33e-06
step=00080 | loss=1.8238 | lr=9.98e-06
step=00090 | loss=1.7774 | lr=9.94e-06
step=00100 | loss=1.8081 | lr=9.90e-06
step=00110 | loss=1.7437 | lr=9.86e-06
step=00120 | loss=1.7830 | lr=9.81e-06
step=00130 | loss=1.8064 | lr=9.77e-06
step=00140 | loss=1.8541 | lr=9.73e-06
step=00150 | loss=1.8301 | lr=9.69e-06
step=00160 | loss=1.7725 | lr=9.65e-06
step=00170 | loss=1.7635 | lr=9.61e-06
step=00180 | loss=1.7963 | lr=9.57e-06
step=00190 | loss=1.7563 | lr=9.53e-06
step=00200 | loss=1.6950 | lr=9.48e-06
step=00210 | loss=1.7680 | lr=9.44e-06
step=00220 | loss=1.8906 | lr=9.40e-06
step=00230 | loss=1.7120 | lr=9.36e-06
step=00240 | loss=1.8390 | lr=9.32e-06
step=00250 | loss=1.7180 | lr=9.28e-06
step=00260 | loss=1.7709 | lr=9.24e-06
step=00270 | loss=1.7598 | lr=9.20e-06
step=00280 | loss=1.7981 | lr=9.15e-06
step=00290 | loss=1.7540 | lr=9.11e-06
step=00300 | loss=1.7695 | lr=9.07e-06
step=00310 | loss=1.7468 | lr=9.03e-06
```

```
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[14], line 11
      9 for step, batch in enumerate(TrainLoader, start=1):
     10     batch = {key: value.to(device) for key, value in batch.items()}
---> 11     outputs = model(**batch)
     12     loss = outputs.loss / gradient_accumulation_steps
     13     loss.backward()
File .../site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
File .../site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
File .../site-packages/peft/peft_model.py:1850, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, ...)
File .../site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
File .../site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
File .../site-packages/peft/tuners/tuners_utils.py:222, in BaseTuner.forward(self, *args, **kwargs)
```
|
| 497 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/utils/generic.py:757\u001b[0m, in \u001b[0;36mcan_return_tuple.<locals>.wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_dict_passed \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 756\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict_passed\n\u001b[0;32m--> 757\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 758\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 759\u001b[0m output \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39mto_tuple()\n",
|
| 498 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:474\u001b[0m, in \u001b[0;36mNanoChatForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[38;5;129m@can_return_tuple\u001b[39m\n\u001b[1;32m 436\u001b[0m \u001b[38;5;129m@auto_docstring\u001b[39m\n\u001b[1;32m 437\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Unpack[TransformersKwargs],\n\u001b[1;32m 449\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m CausalLMOutputWithPast:\n\u001b[1;32m 450\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 451\u001b[0m \u001b[38;5;124;03m Example:\u001b[39;00m\n\u001b[1;32m 452\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;124;03m >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)\u001b[39;00m\n\u001b[1;32m 473\u001b[0m \u001b[38;5;124;03m ```\"\"\"\u001b[39;00m\n\u001b[0;32m--> 474\u001b[0m outputs: BaseModelOutputWithPast \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 475\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 476\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 477\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 478\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 479\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 480\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 481\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 482\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 483\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlast_hidden_state\n\u001b[1;32m 486\u001b[0m slice_indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mslice\u001b[39m(\u001b[38;5;241m-\u001b[39mlogits_to_keep, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(logits_to_keep, \u001b[38;5;28mint\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m logits_to_keep\n",
|
| 499 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 500 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
|
| 501 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/utils/generic.py:927\u001b[0m, in \u001b[0;36mcheck_model_inputs.<locals>.wrapped_fn.<locals>.wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 924\u001b[0m monkey_patched_layers\u001b[38;5;241m.\u001b[39mappend((module, original_forward))\n\u001b[1;32m 926\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 927\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 928\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m original_exception:\n\u001b[1;32m 929\u001b[0m \u001b[38;5;66;03m# If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.\u001b[39;00m\n\u001b[1;32m 930\u001b[0m \u001b[38;5;66;03m# Get a TypeError even after removing the recordable kwargs -> re-raise the original exception\u001b[39;00m\n\u001b[1;32m 931\u001b[0m \u001b[38;5;66;03m# Otherwise -> we're probably missing `**kwargs` in the decorated function\u001b[39;00m\n\u001b[1;32m 932\u001b[0m kwargs_without_recordable \u001b[38;5;241m=\u001b[39m {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m recordable_keys}\n",
|
| 502 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:401\u001b[0m, in \u001b[0;36mNanoChatModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 398\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minitial_norm(hidden_states)\n\u001b[1;32m 400\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m decoder_layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m--> 401\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 402\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 403\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 404\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 405\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 406\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 407\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 408\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 409\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 410\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 412\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm(hidden_states)\n\u001b[1;32m 414\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m BaseModelOutputWithPast(\n\u001b[1;32m 415\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mhidden_states,\n\u001b[1;32m 416\u001b[0m past_key_values\u001b[38;5;241m=\u001b[39mpast_key_values \u001b[38;5;28;01mif\u001b[39;00m use_cache \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 417\u001b[0m )\n",
|
| 503 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/modeling_layers.py:94\u001b[0m, in \u001b[0;36mGradientCheckpointingLayer.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 91\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning_once(message)\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(partial(\u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs), \u001b[38;5;241m*\u001b[39margs)\n\u001b[0;32m---> 94\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 504 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 505 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
|
| 506 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:279\u001b[0m, in \u001b[0;36mNanoChatDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 268\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 269\u001b[0m hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 276\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Unpack[TransformersKwargs],\n\u001b[1;32m 277\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m 278\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[0;32m--> 279\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_layernorm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[1;32m 281\u001b[0m hidden_states, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mself_attn(\n\u001b[1;32m 282\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mhidden_states,\n\u001b[1;32m 283\u001b[0m attention_mask\u001b[38;5;241m=\u001b[39mattention_mask,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 290\u001b[0m )\n",
|
| 507 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 508 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
|
| 509 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:53\u001b[0m, in \u001b[0;36mNanoChatRMSNorm.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 53\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mtype_as(x)\n",
|
| 510 |
+
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:50\u001b[0m, in \u001b[0;36mNanoChatRMSNorm._norm\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_norm\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 50\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mrsqrt(\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeepdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39meps)\n",
|
| 511 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 512 |
+
]
|
| 513 |
+
}
|
| 514 |
+
],
|
| 515 |
+
"source": [
"# Standard PyTorch training loop with gradient accumulation and per-epoch evaluation.\n",
"model.train()\n",
"global_step = 0\n",
"running_loss = 0.0\n",
"running_steps = 0\n",
"\n",
"for epoch in range(num_epochs):\n",
"    print(f\"Epoch {epoch + 1}/{num_epochs}\")\n",
"    optimizer.zero_grad(set_to_none=True)\n",
"    for step, batch in enumerate(TrainLoader, start=1):\n",
"        batch = {key: value.to(device) for key, value in batch.items()}\n",
"        outputs = model(**batch)\n",
"        # Scale the loss so accumulated gradients average over the effective batch.\n",
"        loss = outputs.loss / gradient_accumulation_steps\n",
"        loss.backward()\n",
"\n",
"        running_loss += outputs.loss.float().item()\n",
"        running_steps += 1\n",
"\n",
"        # Only step the optimizer every `gradient_accumulation_steps` micro-batches.\n",
"        if step % gradient_accumulation_steps == 0 or step == len(TrainLoader):\n",
"            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
"            optimizer.step()\n",
"            scheduler.step()\n",
"            optimizer.zero_grad(set_to_none=True)\n",
"            global_step += 1\n",
"\n",
"            if global_step % logging_frequency == 0:\n",
"                current_lr = scheduler.get_last_lr()[0]\n",
"                mean_loss = running_loss / running_steps\n",
"                print(f\"step={global_step:05d} | loss={mean_loss:.4f} | lr={current_lr:.2e}\")\n",
"                running_loss = 0.0\n",
"                running_steps = 0\n",
"\n",
"    train_loss = running_loss / running_steps if running_steps > 0 else float(\"nan\")\n",
"    print(f\"Training loss after epoch {epoch + 1}: {train_loss:.4f}\")\n",
"\n",
"    # Evaluate on the held-out set at the end of each epoch.\n",
"    model.eval()\n",
"    losses = []\n",
"    with torch.no_grad():\n",
"        for _, batch in enumerate(EvalLoader, start=1):\n",
"            batch = {key: value.to(device) for key, value in batch.items()}\n",
"            loss = model(**batch).loss\n",
"            losses.append(loss.float().item())\n",
"    model.train()\n",
"    val_loss = sum(losses) / len(losses) if losses else float(\"nan\")\n",
"\n",
"    print(f\"Validation loss after epoch {epoch + 1}: {val_loss:.4f}\")\n",
"\n",
"print(\"Training complete.\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
},
"colab": {
"provenance": []
}
},
"nbformat": 4,
"nbformat_minor": 5
}