A Brief Tour of my NanoGPT

words by Joe Holmes · code by Claude Code · tutorial by Andrej Karpathy

Introduction

I recently finished Andrej Karpathy's 1-2 punch of GPT from scratch, developing both his pico-sized GPT with character-level tokens and Tiny Shakespeare, as well as the heavier duty reproduction of GPT-2 124M, on 10B tokens of educational data and with a variety of production grade optimizations.

Both were a treat to study and a gift to world knowledge. The first one was mostly fun and gentle, while the second one (the subject of this post) felt like a true rite of passage.

I've spent the past couple days banging my head on the desk while my training runs crash and overfit all while my expensive GPU servers burn the dollars away. Was it worth it? Learning is always worth it.

How else could we taste such sweet, sweet victory? Here you see my training run complete and my GPT-2 validation loss at the same levels as Karpathy's.

Rites of Passage

This crucible reminds me of my experiences with Nand2Tetris, back when I was studying the teachyourselfcs.com learning track (real ambrosia for my fellow autodidacts). If you're unfamiliar, Nand2Tetris starts with the student writing code for elementary logic gates (Not-And, or NAND, among them) and proceeds all the way to coding Tetris in a high-level language.

For months after work I'd try and get my CPU passing tests, my tiny computer working, my assembly codes compiling into binary. It was often extremely frustrating trying to get these things to work and not knowing what to do.

I don't know how useful building a computer from first principles ended up being. While a few large and important intuitions remain, I only remember small parts of the overall Nand2Tetris pipeline. But what's important in a rite of passage is not quantifiable, and the knowledge gained is non-technical: having endured and stood victorious over its challenges added some important confidence to the way I've since approached the computer.

In the same way, I don't think I'm moving into a career of pretraining LLMs, but getting up close and personal with optimized training runs (at least, the 2019 version) feels like I've proven something important to myself.

Try it out

Anyway, here's the model itself! Amazing how easy it is to embed this Hugging Face Space. You'll notice the outputs are very bad—this was cutting edge only 6, 7 years ago. We are alive in a time of miracles.

One of the main failure modes is that it repeats itself very easily, which it will probably do with your prompt.

Probably the most useful way to think about these outputs is to look at this repetitive nonsense and imagine it like a puppyslugged TV-static-esque image of pure noise. Out of this noise, as these parameters scale up 100s and 1000s and 10,000s of times, suddenly a picture emerges. But the fundamentals stay the same. There is something miraculous about this.

Architecture

Throughout this blog post, we'll be plucking out sections of my reproduction of Karpathy's training script. I don't intend to give a full tutorial, but instead want to point out especially weird, challenging, or interesting parts of the code.

And the first, here on lines 122-128: the chewy center of a large language model.

The story goes like this: a string of text is tokenized and those tokens become a collection of learned embeddings (wte.) Fascinatingly, in GPT-2 information about where these embeddings are in the input sequence is passed on via totally learned weights added to the embeddings (wpe).

In earlier transformers, these position encodings were hardcoded in (for me) very hard to understand sinusoidal patterns. In later models, these position encodings "rotate" the embeddings, which is also confusing. But wholly learned weights that figure out how to encode positional information? Weird and cool.
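To make the learned-position idea concrete, here is a minimal sketch that mirrors the wte and wpe lookups from the forward pass in the listing. The tiny sizes are made up for illustration; only the shapes matter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, block_size, n_embd = 100, 16, 8  # toy sizes, not GPT-2's

wte = nn.Embedding(vocab_size, n_embd)  # one learned vector per token id
wpe = nn.Embedding(block_size, n_embd)  # one learned vector per position

idx = torch.randint(0, vocab_size, (2, 5))             # (B=2, T=5) token ids
pos = torch.arange(0, idx.size(1), dtype=torch.long)   # (T,) positions 0..4

tok_emb = wte(idx)       # (B, T, n_embd)
pos_emb = wpe(pos)       # (T, n_embd), shared by every row in the batch
x = tok_emb + pos_emb    # broadcasts over the batch dimension

print(x.shape)
```

Nothing about wpe knows it encodes "position": it is just a second lookup table that receives gradients like any other weight.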

The other cool thing is that Karpathy reuses the same embeddings in the final classifier head of the model as are used to embed the tokens (wte). This is represented by a red dashed line connecting the wte and lm_head embeddings. This means the model only learns a single understanding of how its vocabulary of tokens maps to the high dimensional space in which the next token is predicted. A single translator, from the world of humankind's language to the private, incomprehensible language of the machine.
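Here is a small sketch of what that sharing means in PyTorch terms, using toy sizes rather than the real config. It follows the same assignment as line 131 of the listing; the checks afterward are my own additions for illustration.

```python
import torch
import torch.nn as nn

vocab_size, n_embd = 100, 8  # toy sizes

wte = nn.Embedding(vocab_size, n_embd)               # weight shape (vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # weight shape (vocab_size, n_embd)

# Tie the two: both modules now point at the same Parameter object.
wte.weight = lm_head.weight

same_storage = wte.weight.data_ptr() == lm_head.weight.data_ptr()

# An in-place change made through one module is visible through the other.
with torch.no_grad():
    lm_head.weight[0].fill_(1.0)
row_matches = torch.equal(wte.weight[0], torch.ones(n_embd))

print(same_storage, row_matches)
```

The tie works without any transposing because nn.Linear stores its weight as (out_features, in_features), which is the same shape as the embedding table.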

Blocks!

In line 125, we initialize a list of blocks. What's in them? Why, the chewy center of the chewy center of the LLM.

The simplicity of this system—self-attention, feed forward, rinse and repeat—is why the original paper was called "Attention Is All You Need." That causal self-attention mechanism in self.attn, scaled up to billions of parameters, is by far and away the most complex component of the model, and yet is simpler than many of the competitors of its time. Despite this simplicity, it ended up being the one true architecture for the breakthroughs we've witnessed in the past few years.

One thing I found particularly elegant is the way the residual connections are implemented in the forward pass. Notice that the forward pass simply calls the functions initialized in the constructor. The only novelty it introduces is adding the output of each mechanism back to its input. That simple x + represents such a valuable idea, as it saves the model from the effort of reinventing the whole representation from scratch in each new layer.
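A toy sketch of that x + pattern, in plain Python lists rather than tensors. The function names here are hypothetical stand-ins, not anything from the training script.

```python
def residual(sublayer, x):
    """Return x + sublayer(x), elementwise, for a plain list of floats."""
    return [xi + si for xi, si in zip(x, sublayer(x))]

def lazy_sublayer(x):
    # Stand-in for an attention or MLP branch that has learned nothing yet.
    return [0.0 for _ in x]

def small_edit(x):
    # Stand-in for a branch that nudges only the first feature.
    return [0.5] + [0.0] * (len(x) - 1)

x = [1.0, -2.0, 3.0]
unchanged = residual(lazy_sublayer, x)                      # input passes straight through
edited = residual(small_edit, residual(lazy_sublayer, x))   # only the nudge is added

print(unchanged, edited)
```

A branch that outputs zeros leaves the representation untouched, so each layer only has to learn what to add, not how to rebuild everything it was given.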

Though it concerns a totally different architecture, my introduction to residual connections was in Serena Yeung's course on CNNs and ResNet. I found the explanation super clear and nice.

Attention

Following the most exciting section of the Block to its definition, we see the scariest, most abstract portion of the architecture, for causal self-attention.

The best way to understand what's happening is to study Karpathy's preceding video, in which he implemented self-attention in less optimized code. And if by chance that feels too far in the deep end for you, I made an annotated bibliography of self-taught ML from total scratch, with a lot of time spent on transformers.

The thing that's unique—and that took me a long time to figure out—is how the weights for each attention head are stacked, then rearranged, to take maximum advantage of the GPU's parallelism.

The hardest part to grok is the tensor reshaping. In the forward pass, the code uses .view() and .transpose() to move the number of attention heads (n_head) into the second dimension. Why? This sets up the Query and Key tensors for a highly optimized batched matrix multiplication. By shaping them as (Batch, Heads, Sequence, Features), PyTorch can treat each attention head as a separate problem to solve in parallel, massively speeding up the core computation. This works because PyTorch's matrix multiplication treats every leading dimension as a batch dimension.
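To convince myself the reshaping really is "one problem per head," here is a sketch comparing the batched version against an explicit loop over heads. It uses toy sizes and only computes the raw query-key scores (no scaling, masking, or softmax), so it is an illustration of the layout rather than a full attention implementation.

```python
import torch

torch.manual_seed(0)

B, T, n_head, hs = 2, 4, 3, 5   # toy sizes; C = n_head * hs
C = n_head * hs

q = torch.randn(B, T, C)
k = torch.randn(B, T, C)

# Split the channel dimension into heads, then move heads next to batch.
qh = q.view(B, T, n_head, hs).transpose(1, 2)   # (B, nh, T, hs)
kh = k.view(B, T, n_head, hs).transpose(1, 2)   # (B, nh, T, hs)

# One batched matmul: the leading (B, nh) dims are treated as batch dims.
att_batched = qh @ kh.transpose(-2, -1)         # (B, nh, T, T)

# The same scores computed one head at a time, for comparison.
att_loop = torch.empty(B, n_head, T, T)
for b in range(B):
    for h in range(n_head):
        q_bh = q[b, :, h * hs:(h + 1) * hs]     # (T, hs) slice belonging to head h
        k_bh = k[b, :, h * hs:(h + 1) * hs]
        att_loop[b, h] = q_bh @ k_bh.T

print(att_batched.shape, torch.allclose(att_batched, att_loop, atol=1e-5))
```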

The other big optimization is a little easier to understand: instead of computing Q, K, and V with three separate projections, the weights for all three are stacked into one mega layer on line 38, and the result is split back into Q, K, and V afterward. One large matrix multiplication is simply faster than three smaller ones.
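A quick sketch of why the stacked projection is equivalent to three separate ones: slicing the big weight matrix into thirds reproduces q, k, and v exactly. The sizes are toy values and the q_sep, k_sep, v_sep names are mine.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, T, C = 2, 4, 6  # toy sizes

c_attn = nn.Linear(C, 3 * C)   # one projection producing q, k and v together
x = torch.randn(B, T, C)

q, k, v = c_attn(x).split(C, dim=2)   # each is (B, T, C)

# Equivalent: three separate projections using row-slices of the same weights.
W, b = c_attn.weight, c_attn.bias     # W is (3C, C), b is (3C,)
q_sep = x @ W[:C].T + b[:C]
k_sep = x @ W[C:2 * C].T + b[C:2 * C]
v_sep = x @ W[2 * C:].T + b[2 * C:]

same = all(
    torch.allclose(a, s, atol=1e-5)
    for a, s in [(q, q_sep), (k, k_sep), (v, v_sep)]
)
print(q.shape, same)
```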

Optimizations

After dialing in the basic architecture, Karpathy spends much of the 4 hours of lecture slowly speeding up the model's ability to train. This section is full of tips, technical details, and weird quirks, and I found it exciting to witness. On the right sidebar is a graph of my original logs from each training run, as we walked through reducing unnecessary precision in the weights and gradually decreasing the time it took to process our batches.

We get a baseline with the CPU, then head to the GPU. First, we use full 32-bit precision, then the fancy TensorFloat-32 (TF32) format, which lops off some of the bits in the mantissa (the low-order bits that are least significant). Then, we use a half-precision 16-bit format that lops off even more bits.
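To see what "lopping off mantissa bits" does to a number, here is a standard-library sketch. It assumes the 16-bit format is bfloat16 (7 mantissa bits, same exponent range as float32), and it truncates where real hardware would round, so treat it as an illustration of the idea rather than an emulation of the GPU. The truncate_mantissa helper is hypothetical.

```python
import struct

def truncate_mantissa(x: float, keep_bits: int) -> float:
    """Keep only the top `keep_bits` of float32's 23 mantissa bits (no rounding)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    drop = 23 - keep_bits
    bits &= ~((1 << drop) - 1) & 0xFFFFFFFF   # zero out the low-order mantissa bits
    (out,) = struct.unpack("<f", struct.pack("<I", bits))
    return out

x = 1 / 3
as_fp32 = truncate_mantissa(x, 23)   # float32: all 23 mantissa bits
as_tf32 = truncate_mantissa(x, 10)   # TF32: 10 mantissa bits
as_bf16 = truncate_mantissa(x, 7)    # bfloat16 (assumed): 7 mantissa bits

for name, value in [("fp32", as_fp32), ("tf32", as_tf32), ("bf16", as_bf16)]:
    print(f"{name}: {value!r}  error={abs(x - value):.2e}")
```

Each step down throws away more of the fine detail while leaving the overall magnitude intact, which is the trade the training run makes for speed.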

And when it's time for what Karpathy calls "the heavy artillery," we compile the PyTorch model with torch.compile, which intelligently fuses a bunch of related operations so that they all take place in one operation without any round trips to the GPU's memory. This was pretty spellbinding to see.

Another cool quirk that isn't shown: Karpathy also mentions (in a mischievous tone) that models just tend to like powers of 2, the round numbers of binary. So we pad our vocabulary size up to a number divisible by a large power of 2 (50304 instead of 50257), keep our batch sizes and sequence lengths at powers of 2, and so on. It's kind of silly, and yet profound at the same time. The models want to learn, but to help them, you should speak in the language of Computer.
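The arithmetic behind the vocabulary padding, as a small sketch. Rounding up to a multiple of 128 is my assumption about how to arrive at 50304 (the comments in the listing only say it is divisible by 8 and 16), and both helper names are hypothetical.

```python
def pad_to_multiple(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple

def largest_power_of_two_divisor(n: int) -> int:
    """Largest power of 2 that divides n (isolates the lowest set bit)."""
    return n & -n

original_vocab = 50257
padded_vocab = pad_to_multiple(original_vocab, 128)

print(padded_vocab)
print(largest_power_of_two_divisor(original_vocab))
print(largest_power_of_two_divisor(padded_vocab))
```

The extra 47 token ids are never produced by the tokenizer; they exist only so the embedding and classifier matrices have GPU-friendly dimensions.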

Gradient Accumulation

For a long section after the first speed-ups, Karpathy investigates a variety of ways we can be totally sure we're reproducing GPT-2 accurately. Because there's not as much published information about GPT-2's hyperparameters, Karpathy sometimes borrows GPT-3's hyperparameters instead.

This leads to some tricky new concepts such as gradient accumulation. Why must the gradients be accumulated? GPT-3's batch size was a positively gargantuan number. And since there's no way we could fit all that in memory for one pass, we loop through a bunch of passes until we collect as many gradients as we would've if we had the tens of thousands of GPUs OpenAI did when training GPT-3. That is, we "accumulate" gradients until a single optimizer step is done on a massive amount of data.
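Here is a sketch of both halves of the idea: the step-count arithmetic from the listing (524288 tokens per optimizer step, B=64, T=1024), and a toy check that averaging micro-batch gradients matches the gradient of one big batch. The one-parameter regression and its data are made up for illustration; the real loop does the same scaling by dividing the loss by the number of accumulation steps.

```python
def grad_accum_steps(total_batch_size, B, T, world_size):
    """Micro-steps each GPU runs before a single optimizer step."""
    tokens_per_micro_step = B * T * world_size
    assert total_batch_size % tokens_per_micro_step == 0
    return total_batch_size // tokens_per_micro_step

def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

steps_on_8_gpus = grad_accum_steps(524288, 64, 1024, 8)
steps_on_1_gpu = grad_accum_steps(524288, 64, 1024, 1)

# Toy data, made up for illustration.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5

full_batch_grad = grad_mse(w, xs, ys)

micro_batch = 2
n_micro = len(xs) // micro_batch
accumulated = 0.0
for i in range(n_micro):
    lo, hi = i * micro_batch, (i + 1) * micro_batch
    # Scale each micro-batch by 1 / n_micro so the sum is a mean, not a total.
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / n_micro

print(steps_on_8_gpus, steps_on_1_gpu)
print(full_batch_grad, accumulated)
```

Forgetting the division by n_micro makes the accumulated gradient n_micro times too large, and stepping the optimizer inside the loop shrinks the effective batch; either mistake changes what the run is actually optimizing.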

I found out how useful this is when I screwed it up and was somehow updating weights after an order of magnitude fewer training examples than intended. The result: ghastly, horrifying overfitting (2.5 train loss, 5ish val loss, which is very, very bad).

These reproduction challenges were pretty painstaking and didn't quite light my fire, so we're going to move on to figuring out DDP, distributed data parallelism: a crucial part of training on multiple GPUs at once.

DDP Is Tricky

DDP is a distributed training technique that allows us to train models on multiple GPUs or machines. However, it can be tricky to get right.

While until now we'd been running the code as a plain Python script, we now need to run a special terminal command, torchrun, to initialize the multi-process environment. For example:

torchrun --nproc_per_node=x main.py

One of the biggest changes we make in our code to hook our training up to multiple GPUs at once is in the DataLoaderLite. We grab batches of data and add each to a different GPU's 'rank'. Note the use of master_process to ensure only one of the processes prints logs and other one-person jobs.
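The striding logic is easier to see in isolation. This sketch lists which token offsets each rank would read from a single shard: every rank starts at B * T * rank and jumps ahead by B * T * world_size, as in the listing. The batch_starts helper is hypothetical, and its end-of-shard check is simplified compared to the real loader.

```python
def batch_starts(num_tokens, B, T, rank, world_size):
    """Start offsets of the batches one rank reads from a shard of num_tokens tokens."""
    starts = []
    pos = B * T * rank
    # Each batch needs B*T inputs plus one extra token for the shifted targets.
    while pos + B * T + 1 <= num_tokens:
        starts.append(pos)
        pos += B * T * world_size
    return starts

B, T, world_size = 2, 4, 2          # toy sizes: 8 tokens per batch, 2 GPUs
rank0 = batch_starts(100, B, T, rank=0, world_size=world_size)
rank1 = batch_starts(100, B, T, rank=1, world_size=world_size)

print(rank0)
print(rank1)
```

The two ranks interleave through the shard without ever reading the same batch, so no coordination between processes is needed at load time.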

The other one is in the training loop, where we use the dist module in PyTorch to reduce all the loss values across all GPUs. This is done using the all_reduce function, which sums up the values across all GPUs and then divides by the number of GPUs to get the average loss.
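Since a real all_reduce needs several processes, here is a single-process simulation of the averaging step. The all_reduce_avg function and the loss values are made up; the point is only that averaging the per-rank means recovers the global mean when every rank sees the same number of samples.

```python
def all_reduce_avg(per_rank_values):
    """Simulate an averaging all_reduce: every rank ends up holding the same mean."""
    mean = sum(per_rank_values) / len(per_rank_values)
    return [mean for _ in per_rank_values]

# Hypothetical per-sample losses seen by 4 ranks, 2 samples each.
samples = [[3.2, 3.0], [2.9, 3.1], [3.4, 3.3], [2.8, 3.1]]

local_loss = [sum(s) / len(s) for s in samples]   # what each rank computes alone
synced = all_reduce_avg(local_loss)               # what each rank holds afterward
global_mean = sum(sum(s) for s in samples) / sum(len(s) for s in samples)

print(local_loss)
print(synced[0], global_mean)
```

Without this step, the master process would log only its own slice of the data, which is a noisier and slightly misleading loss curve.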

All of this was pretty amazing, but also kind of scary—to see this code in action requires you to be renting multiple GPUs while SSH'd into a server. So when it breaks (which happened a lot for me), you'll be on the clock and bleeding money.

How to Accidentally Spend $150 on 10 hours of 8 A100s

As you'll know if you spend any time reading news about AI, training models on lots of GPUs is very expensive. No one but the extraordinarily wealthy owns the high end GPUs used for training. Mortals rent from the cloud, and I chose to use Prime Intellect, an aggregator of GPU cloud servers that picks out the cheapest ones available.

But by no means was the process cheap! This is mostly because I kept screwing up. SSHing into remote servers for long processes like the training (which took about 3 hours) is something I didn't have much experience in. I learned some valuable stuff—use tmux to keep processes running even if you disconnect! I also wrote some checkpointing code that saved the model checkpoints and reloaded them whenever it crashed.
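The resume logic boils down to "find the highest-numbered checkpoint file." This sketch mirrors the filename parsing at the end of the listing; the model_<step>.pt naming in the example list is an assumption based on that parsing, and latest_checkpoint is a hypothetical helper.

```python
def latest_checkpoint(filenames):
    """Return (filename, step) for the highest-numbered model_<step> file, or None."""
    checkpoints = [f for f in filenames if f.startswith("model_")]
    if not checkpoints:
        return None

    def step_of(name):
        # "model_05000.pt" -> 5000
        return int(name.split("_")[1].split(".")[0])

    latest = max(checkpoints, key=step_of)
    return latest, step_of(latest)

files = ["log.txt", "model_00500.pt", "model_05000.pt", "model_01000.pt"]

print(latest_checkpoint(files))
print(latest_checkpoint(["log.txt"]))
```

Comparing by the parsed integer rather than by string keeps the ordering right even if the step numbers are not zero-padded.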

All told though, this was one of those confidence-boosting activities that made a simpler, saner project (fine-tuning a small open-source LLM with RL someday) feel a lot more approachable. If you know what you're doing you will not spend as much as I did. I wrote this post to give the pain some meaning.

Hope you enjoyed exploring Karpathy's 4 hour tutorial masterpiece with me! Next time I'm in the mood for deep ML study, I think I'll follow along with Raschka's guide to instruction-tuning a from-scratch model like this one. If you enjoyed, follow me on LinkedIn.

1from huggingface_hub import HfApi, create_repo, upload_file
2from torch.distributed import init_process_group, destroy_process_group
3from torch.nn.parallel import DistributedDataParallel as DDP
4import time
5import torch.distributed as dist
6from datetime import timedelta
7from dataclasses import dataclass
8import inspect
9import math
10import torch
11import torch.nn as nn
12from torch.nn import functional as F
13from torch.optim import optimizer
14import os
15
16from hellaswag import render_example, iterate_examples
17import tiktoken
18enc = tiktoken.get_encoding("gpt2")
19# more forgiving timeouts
20os.environ["NCCL_TIMEOUT"] = "7200"  # 2hr
21os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"  # disable async error checking
22os.environ["NCCL_BLOCKING_WAIT"] = "1"  # more stable, slightly slower
23os.environ["NCCL_DEBUG"] = "WARN"  # reduce debug spam
24os.environ["NCCL_IB_DISABLE"] = "1"  # disable infiniband if causing issues
25os.environ["NCCL_P2P_DISABLE"] = "1"  # disable p2p if causing issues
26# TCP store timeout for distributed init
27os.environ["TORCH_DISTRIBUTED_INIT_TIMEOUT"] = "1800"  # 30 minutes
28os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
29os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
30
31# ---------------------------------------------------------------
32class CausalSelfAttention(nn.Module):
33
34    def __init__(self, config):
35        super().__init__()
36        assert config.n_embd % config.n_head == 0
37        # key, query, value projections for all heads, but in a batch
38        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
39        # output projection
40        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
41        self.c_proj.NANOGPT_SCALE_INIT = 1 # normalize residual stream
42        # regularization
43        self.n_head = config.n_head
44        self.n_embd = config.n_embd
45        # this is a mask, not a bias, but following openai naming conventions
46        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
47                                     .view(1, 1, config.block_size, config.block_size))
48
49    def forward(self, x):
50        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - size returns tuple of shapes
51        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
52        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
53        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
54        qkv = self.c_attn(x)
55        q, k, v = qkv.split(self.n_embd, dim=2)
56        # these 3 view operations move the heads to the second dimension so they
57        # are computed in parallel. this is the internal of pt: first 2 dimensions
58        # are auto computed in parallel
59        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
61        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
62
63        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
64
65        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
66        # output projection
67        y = self.c_proj(y)
68        return y
69
70
71
72class MLP(nn.Module):
73
74    def __init__(self,config):
75        super().__init__()
76        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
77        # gaussian error linear units
78        # a slightly smoother relu
79        # transformers seem to prefer smooth activations of sharper ones (like relu)
80        self.gelu = nn.GELU(approximate='tanh') # historical quirk to use approximate tanh
81        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
82        self.c_proj.NANOGPT_SCALE_INIT = 1 # normalize residual stream
83
84
85    def forward(self, x):
86        x = self.c_fc(x)
87        x = self.gelu(x)
88        x = self.c_proj(x)
89        return x
90
91class Block(nn.Module):
92
93    def __init__(self,config):
94        super().__init__()
95        self.ln_1 = nn.LayerNorm(config.n_embd)
96        self.attn = CausalSelfAttention(config)
97        self.ln_2 = nn.LayerNorm(config.n_embd)
98        self.mlp = MLP(config)
99
100    def forward(self, x):
101        x = x + self.attn(self.ln_1(x))
102        x = x + self.mlp(self.ln_2(x))
103        return x
104
105
106@dataclass
107class GPTConfig:
108    block_size: int = 1024 # max sequence length
109    vocab_size: int = 50257 # number of toks: 50k bpe merges, 256 bytes tokens, 1 <eos> token
110    # 50257 is ugly. odd. we want powers of 2.
111    # 50304 is divisible by 8, 16. better.
112    n_layer: int = 12
113    n_head: int = 12
114    n_embd: int = 768
115    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
116
117class GPT(nn.Module):
118    def __init__(self,config):
119        super().__init__()
120        self.config = config
121
122        self.transformer = nn.ModuleDict(dict(
123            wte = nn.Embedding(config.vocab_size, config.n_embd),
124            wpe = nn.Embedding(config.block_size, config.n_embd),
125            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # h stands for hidden
126            ln_f = nn.LayerNorm(config.n_embd), # final layer norm
127        ))
128        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # final classifier
129
130        # weight sharing
131        self.transformer.wte.weight = self.lm_head.weight
132
133        # init params
134        self.apply(self._init_weights)
135
136    def _init_weights(self, module):
137        if isinstance(module, nn.Linear):
138            std = 0.02
139            if hasattr(module, 'NANOGPT_SCALE_INIT'):
140                std *= (2 * self.config.n_layer) ** -0.5 # 1 over square root of num_layers,
141               # keeps residual stream from ballooning
142               # it's 2x bc attn and ffn both add to the residual pathway
143            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
144            if module.bias is not None:
145                torch.nn.init.zeros_(module.bias)
146        elif isinstance(module, nn.Embedding):
147            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
148
149    @classmethod
150    def from_pretrained(cls, model_type, override_args=None):
151        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
152        override_args = override_args or {} # default to empty dict
153        # only dropout can be overridden see more notes below
154        assert all(k == 'dropout' for k in override_args)
155        from transformers import GPT2LMHeadModel
156        print("loading weights from pretrained gpt: %s" % model_type)
157
158        # n_layer, n_head and n_embd are determined from model_type
159        config_args = {
160            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
161            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
162            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
163            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
164        }[model_type]
165        print("forcing vocab_size=50257, block_size=1024, bias=True")
166        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
167        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
168        config_args['bias'] = True # always True for GPT model checkpoints
169        # we can override the dropout rate, if desired
170        if 'dropout' in override_args:
171            print(f"overriding dropout rate to {override_args['dropout']}")
172            config_args['dropout'] = override_args['dropout']
173        # create a from-scratch initialized minGPT model
174        config = GPTConfig(**config_args)
175        model = GPT(config)
176        sd = model.state_dict()
177        sd_keys = sd.keys()
178        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
179
180        # init a huggingface/transformers model
181        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
182        sd_hf = model_hf.state_dict()
183
184        # copy while ensuring all of the parameters are aligned and match in names and shapes
185        sd_keys_hf = sd_hf.keys()
186        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
187        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
188        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
189        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
190        # this means that we have to transpose these weights when we import them
191        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
192        for k in sd_keys_hf:
193            if any(k.endswith(w) for w in transposed):
194                # special treatment for the Conv1D weights we need to transpose
195                assert sd_hf[k].shape[::-1] == sd[k].shape
196                with torch.no_grad():
197                    sd[k].copy_(sd_hf[k].t())
198            else:
199                # vanilla copy over the other parameters
200                assert sd_hf[k].shape == sd[k].shape
201                with torch.no_grad():
202                    sd[k].copy_(sd_hf[k])
203
204        return model
205
206    def forward(self, idx, targets=None):
207        B, T = idx.size()
208        assert T <= self.config.block_size, f"Cannot fwd seq of length {T}, block size {self.config.block_size}"
209
210        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
211        pos_emb = self.transformer.wpe(pos) # pos embs of shape (T, n_embd)
212        tok_emb = self.transformer.wte(idx) # token embs of shape (B, T, n_embd)
213        x = tok_emb + pos_emb
214
215        for block in self.transformer.h:
216            x = block(x)
217
218        x = self.transformer.ln_f(x)
219        logits = self.lm_head(x) # (B,T, vocab_size)
220        loss = None
221        if targets is not None:
222
223            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
224        return logits, loss
225
226    def configure_optimizers(self, weight_decay, learning_rate, device_type): # 2:31:50
227        # this weight decay forces info to go across many smaller channels instead of one big one
228        # start with all of the candidate parameters (that require grad)
229        param_dict = {pn: p for pn, p in self.named_parameters()}
230        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
231        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
232        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
233        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] # decay weights and sometimes embs
234        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] # no decay biases and layernorms
235        optim_groups = [
236            {'params': decay_params, 'weight_decay': weight_decay},
237            {'params': nodecay_params, 'weight_decay': 0.0}
238        ]
239        num_decay_params = sum(p.numel() for p in decay_params)
240        num_nodecay_params = sum(p.numel() for p in nodecay_params)
241        if master_process: # this is so if u have gpu clusters it doesn't print 8 times
242            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
243            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
244        # Create AdamW optimizer and use the fused version if it is available
245        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
246        use_fused = fused_available and device_type == "cuda"
247        if master_process:
248            print(f"using fused AdamW: {use_fused}") # fused is a newer performance optimization
249        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
250        return optimizer
251
252# --- setting up DDP (distributed data parallels)
253# torchrun command sets the env variables RANK, LOCAL_RANK and WORLD_SIZE
254
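255ddp = int(os.environ.get('RANK', -1)) != -1 # True only when launched via torchrun; a raw rank would be falsy on rank 0 and truthy (-1) outside DDP
256if ddp: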
257    assert torch.cuda.is_available()
258    # Retry DDP initialization with exponential backoff
259    max_retries = 5
260    for attempt in range(max_retries):
261        try:
262            if attempt > 0:
263                if int(os.environ['RANK']) == 0:
264                    print(f"DDP init attempt {attempt + 1}/{max_retries}")
265                time.sleep(2 ** attempt)  # exponential backoff: 2, 4, 8, 16 seconds
266            init_process_group(backend='nccl', timeout=timedelta(seconds=1800))
267            break
268        except Exception as e:
269            if attempt == max_retries - 1:
270                print(f"Failed to initialize DDP after {max_retries} attempts: {e}")
271                raise e
272            else:
273                print(f"DDP init attempt {attempt + 1} failed: {e}, retrying...")
274                if dist.is_initialized():
275                    dist.destroy_process_group()
276
277    ddp_rank = int(os.environ['RANK'])
278    ddp_local_rank = int(os.environ['LOCAL_RANK'])
279    ddp_world_size = int(os.environ['WORLD_SIZE'])
280    device = f"cuda:{ddp_local_rank}"
281    torch.cuda.set_device(device)
282    master_process = ddp_rank == 0
283else:
284    ddp_rank = 0
285    ddp_local_rank = 0
286    ddp_world_size = 1
287    master_process = True
288    device = "cpu"
289    if torch.cuda.is_available():
290        device = "cuda"
291    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
292        device = "mps"
293    print(f"using device: {device}")
294torch.set_float32_matmul_precision('high')
295torch.manual_seed(1337)
296if torch.cuda.is_available():
297    torch.cuda.manual_seed(1337)
298
299# -----------------------------------------------------------------------------
300# helper function for HellaSwag eval
301# takes tokens, mask, and logits, returns the index of the completion with the lowest loss
302
303def get_most_likely_row(tokens, mask, logits):
304    # evaluate the autoregressive loss at all positions
305    shift_logits = (logits[..., :-1, :]).contiguous()
306    shift_tokens = (tokens[..., 1:]).contiguous()
307    flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
308    flat_shift_tokens = shift_tokens.view(-1)
309    shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
310    shift_losses = shift_losses.view(tokens.size(0), -1)
311    # now get the average loss just for the completion region (where mask == 1), in each row
312    shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
313    masked_shift_losses = shift_losses * shift_mask
314    # sum and divide by the number of 1s in the mask
315    sum_loss = masked_shift_losses.sum(dim=1)
316    avg_loss = sum_loss / shift_mask.sum(dim=1)
317    # now we have a loss for each of the 4 completions
318    # the one with the lowest loss should be the most likely
319    pred_norm = avg_loss.argmin().item()
320    return pred_norm
321
322
323torch.manual_seed(1337)
324if torch.cuda.is_available():
325    torch.cuda.manual_seed(1337)
326
327# --- gradient accumulation 2:36:00 ---
328total_batch_size = 524288 # ~0.5M tokens per gpt-3 small in its paper
329
330B = 64
331T = 1024
332assert total_batch_size % (B * T * ddp_world_size) == 0 # make sure total batch size is divisible by B*T * world_size (number of total gpus)
333grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
334if master_process:
335    print(f"total desired batch size: {total_batch_size}")
336    print(f"=> calculated grad accum steps: {grad_accum_steps}")
337
338
339
340# -------------------------------
341# dataloader
342
343# -----------------------------------------------------------------------------
344import numpy as np
345import tiktoken
346
347def load_tokens(filename):
348    try:
349        npt = np.load(filename)
350        # npt = npt.astype(np.int32)  # Convert uint16 to int32 for torch compatibility
351        ptt = torch.tensor(npt, dtype=torch.long)
352        return ptt
353    except Exception as e:
354        print(f"Error loading {filename}: {e}")
355        # Try to peek at the file content
356        with open(filename, 'rb') as f:
357            header = f.read(16)
358            print(f"File header (first 16 bytes): {header}")
359        raise e
360
361class DataLoaderLite:
362    def __init__(self, B, T, process_rank, num_processes, split):
363        self.B = B
364        self.T = T
365        self.process_rank = process_rank
366        self.num_processes = num_processes
367        assert split in {'train', 'val'}
368
369        # get the shard filenames
370        data_root = "../data/edu_fineweb10B"
371        shards = os.listdir(data_root)
372        shards = [s for s in shards if split in s]
373        shards = sorted(shards)
374        shards = [os.path.join(data_root, s) for s in shards]
375        self.shards = shards
376        assert len(shards) > 0, f"no shards found for split {split}"
377        if master_process:
378            print(f"found {len(shards)} shards for split {split}")
379
380        # state, init at shard zero
381        self.current_shard = 0
382        self.tokens = load_tokens(self.shards[self.current_shard])
383        self.current_position = self.B * self.T * self.process_rank
384        self.reset()
385
386    def reset(self):
387        self.current_shard = 0
388        self.tokens = load_tokens(self.shards[self.current_shard])
389        self.current_position = self.B * self.T * self.process_rank
390
391    def next_batch(self):
392        B, T = self.B, self.T
393        buf = self.tokens[self.current_position : self.current_position+B*T+1]
394        x = (buf[:-1]).view(B, T) # inputs
395        y = (buf[1:]).view(B, T) # targets
396        # advance the position in the tensor
397        self.current_position += B * T * self.num_processes
398        # if loading the next batch would be out of bounds, advance to next shard
399        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
400            self.current_shard = (self.current_shard + 1) % len(self.shards)
401            self.tokens = load_tokens(self.shards[self.current_shard])
402            self.current_position = B * T * self.process_rank
403        return x, y
404
405
406train_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train")
407val_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val")
408
409
410# -------------------------------
411# run the training loop
412
413
414
415num_return_sequences = 5
416max_length = 30
417
418# model = GPT.from_pretrained('gpt2')
419# overriding the ugly vocab size with 50304, which is divisible by many powers of 2
420# when doing distributed training, WORLD_SIZE models get created now
421# they all have the same seed so they're all identical (2:59)
422model = GPT(GPTConfig(vocab_size=50304))
423model.to(device)
424use_compile = False # torch.compile interferes with HellaSwag eval and Generation. TODO fix
425if use_compile:
426    model = torch.compile(model)
427if ddp:
428    model = DDP(model, device_ids=[ddp_local_rank])
429raw_model = model.module if ddp else model # this contains the configure optimizers func we wanna call
430# -- logging --
431# create the log directory we will write checkpoints to and log to
432log_dir = "log"
433os.makedirs(log_dir, exist_ok=True)
434log_file = os.path.join(log_dir, f"log.txt")
435with open(log_file, "w") as f: # open for writing to clear the file
436    pass
437
# --- cosine decay lr ----
max_lr = 6e-4 # gpt-3 small LR per their paper (gpt-2 doesn't specify)
min_lr = max_lr * 0.1
warmup_steps = 715 # these two hps map to the gpt-3 schedule, adapted for fineweb10b
max_steps = 19073
def get_lr(it):
    # 1) linear warmup for warmup_steps steps
    if it < warmup_steps:
        return max_lr * (it+1) / warmup_steps
    # 2) if it > max_steps, return min lr
    if it > max_steps:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # goes from 1 down to 0
    return min_lr + coeff * (max_lr - min_lr)
# --- cosine decay lr end ----


device_type = 'cuda' if 'cuda' in device else 'cpu'

# checkpoint loading logic
initial_iter = 0
if os.path.exists(log_dir):
    checkpoints = [f for f in os.listdir(log_dir) if f.startswith('model_')]
    if checkpoints:
        # pick the checkpoint with the highest step number in its filename
        latest = max(checkpoints, key=lambda x: int(x.split('_')[1].split('.')[0]))
        checkpoint_path = os.path.join(log_dir, latest)
        # the checkpoint holds more than tensors (config, RNG state), so a weights-only load is not enough
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        raw_model.load_state_dict(checkpoint['model'])
        initial_iter = checkpoint['step'] + 1
        if master_process:
            print(f"resuming from step {initial_iter}")

optimizer = raw_model.configure_optimizers(weight_decay=0.1, learning_rate=max_lr, device_type=device_type)

# load optimizer state if checkpoint exists
if initial_iter > 0:
    try:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if 'rng_state' in checkpoint:
            rng_state = checkpoint['rng_state']
            if not isinstance(rng_state, torch.ByteTensor):
                if master_process:
                    print(f"Warning: Converting RNG state from {type(rng_state)} to ByteTensor")
                # set_rng_state needs a CPU uint8 tensor; map_location may have moved it to the GPU
                if isinstance(rng_state, torch.Tensor):
                    rng_state = rng_state.cpu().to(torch.uint8)
            torch.set_rng_state(rng_state)
        if checkpoint.get('cuda_rng_state') is not None:
            cuda_rng_state = checkpoint['cuda_rng_state']
            if not isinstance(cuda_rng_state, torch.ByteTensor):
                if master_process:
                    print(f"Warning: Converting CUDA RNG state from {type(cuda_rng_state)} to ByteTensor")
                # same requirement here: a CPU uint8 tensor
                if isinstance(cuda_rng_state, torch.Tensor):
                    cuda_rng_state = cuda_rng_state.cpu().to(torch.uint8)
            torch.cuda.set_rng_state(cuda_rng_state)
        if 'loader_position' in checkpoint:
            # the saved position belongs to whichever rank wrote the checkpoint;
            # snap back to the start of that stride, then re-offset for this rank
            # (assumes B, T and the world size match the original run)
            chunk = train_loader.B * train_loader.T
            stride_start = (checkpoint['loader_position'] // chunk) // ddp_world_size * ddp_world_size
            train_loader.current_position = (stride_start + ddp_rank) * chunk
        if 'loader_shard' in checkpoint:
            train_loader.current_shard = checkpoint['loader_shard']
            train_loader.tokens = load_tokens(train_loader.shards[train_loader.current_shard])
        # reset lr scheduler
        for param_group in optimizer.param_groups:
            param_group['lr'] = get_lr(initial_iter)
        if master_process:
            print(f"Successfully loaded checkpoint from step {initial_iter}")
    except Exception as e:
        if master_process:
            print(f"Warning: Failed to load some checkpoint data: {e}")
            print("Continuing with partial checkpoint restore...")

prev_train_loss = None # set after each training step; read by the overfitting check below
for step in range(initial_iter, max_steps):
    t0 = time.time()
    last_step = (step == max_steps - 1)
    # occasionally find out the val loss
    if step % 500 == 0 or last_step:
        model.eval()
        val_loader.reset()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x, y = x.to(device), y.to(device)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(x, y)
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            val_loss_val = val_loss_accum.item()
            print(f"step {step}, val loss: {val_loss_val:.4f}")
            # Check for potential overfitting every eval, once a train loss has been recorded
            if step > 500 and prev_train_loss is not None:
                train_val_gap = val_loss_val - prev_train_loss
                if train_val_gap > 1.5:
                    print(f"Warning: Large train/val gap ({train_val_gap:.3f}) - possible overfitting")
            with open(log_file, "a") as f:
                f.write(f"{step} val {val_loss_val:.4f}\n")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # only the master process writes checkpoints, so ranks don't race on the same file
    if master_process and step > 0 and (step % 1000 == 0 or last_step):
        checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.pt")
        # ensure RNG states are ByteTensors for compatibility
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

        checkpoint = {
            'model': raw_model.state_dict(),
            'config': raw_model.config,
            'step': step,
            'val_loss': val_loss_accum.item(),
            'optimizer': optimizer.state_dict(),
            'rng_state': rng_state.byte() if not isinstance(rng_state, torch.ByteTensor) else rng_state,
            'cuda_rng_state': cuda_rng_state.byte() if cuda_rng_state is not None and not isinstance(cuda_rng_state, torch.ByteTensor) else cuda_rng_state,
            'loader_position': train_loader.current_position,
            'loader_shard': train_loader.current_shard,
        }
        torch.save(checkpoint, checkpoint_path)

    # once in a while evaluate hellaswag
    if (step % 500 == 0 or last_step) and (not use_compile):
        num_correct_norm = 0
        num_total = 0
        for i, example in enumerate(iterate_examples("val")):
            # only process examples where i % ddp_world_size == ddp_rank
            if i % ddp_world_size != ddp_rank:
                continue
            # render the example into tokens and labels
            _, tokens, mask, label = render_example(example)
            tokens = tokens.to(device)
            mask = mask.to(device)
            # get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(tokens)
                pred_norm = get_most_likely_row(tokens, mask, logits)
            num_total += 1
            num_correct_norm += int(pred_norm == label)
        # reduce the stats across all processes
        if ddp:
            num_total = torch.tensor(num_total, dtype=torch.long, device=device)
            num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
            dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
            num_total = num_total.item()
            num_correct_norm = num_correct_norm.item()
        acc_norm = num_correct_norm / num_total
        if master_process:
            print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} hella {acc_norm:.4f}\n")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    # once in a while generate a few samples from the model
    if ((step > 0 and step % 500 == 0) or last_step) and (not use_compile):
        model.eval()
        num_return_sequences = 4
        max_length = 32
        tokens = enc.encode("Hello, I'm a language model,")
        tokens = torch.tensor(tokens, dtype=torch.long)
        tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
        xgen = tokens.to(device)
        # separate generator so sampling does not disturb the training RNG
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(42 + ddp_rank)
        while xgen.size(1) < max_length:
            # forward the model to get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(xgen) # (B, T, vocab_size)
                # take the logits at the last position
                logits = logits[:, -1, :] # (B, vocab_size)
                # get the probabilities
                probs = F.softmax(logits, dim=-1)
                # do top-k sampling of 50 (huggingface pipeline default)
                # topk_probs here becomes (B, 50), topk_indices is (B, 50)
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
                # select a token from the top-k probabilities
                # note: multinomial does not demand the input to sum to 1
                ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
                # gather the corresponding indices
                xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
                # append to the sequence
                xgen = torch.cat((xgen, xcol), dim=1)
        # print the generated text
        for i in range(num_return_sequences):
            tokens = xgen[i, :max_length].tolist()
            decoded = enc.decode(tokens)
            print(f"rank {ddp_rank} sample {i}: {decoded}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    # gradient accumulation 2:39:23
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        if ddp:
            # should only share grads on the last micro-step
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16): # bfloat16 needs Ampere or newer GPUs
            logits, loss = model(x, y)
        loss = loss / grad_accum_steps # 2:44, otherwise losses sum over the accumulated passes
        loss_accum += loss.detach()
        loss.backward()
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG) # average the loss across all gpus
    # final gradient clipping after accumulation
    # 2:18:00
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # lr scheduler is cosine decay (2:22)
    lr = get_lr(step)
    # PyTorch stores hyperparameters per param group, so set the lr on each one
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device_type == 'cuda':
        torch.cuda.synchronize() # wait for all queued kernels to finish before timing
    t1 = time.time()
    dt = (t1 - t0)*1000 # step time in milliseconds
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
    tokens_per_sec = tokens_processed / (t1 - t0)
    # keep an eye on gradient norms, signal of problems if anomalies
    if master_process:
        train_loss = loss_accum.item()
        print(f"step {step}, loss: {train_loss:.6f}, lr: {lr:.4e}, norm: {norm:.4f}, dt: {dt:.2f}ms, tok/sec: {tokens_per_sec:.2f}")
        # Store for overfitting detection
        prev_train_loss = train_loss
        with open(log_file, "a") as f:
            f.write(f"{step} train {train_loss:.6f}\n")

if ddp:
    destroy_process_group()

# only master saves & uploads
if master_process:
    # assumes the repo already exists on the Hub (create it once beforehand)

    # save the raw model (not ddp wrapper)
    torch.save(raw_model.state_dict(), 'model.pt')
    api = HfApi()
    api.upload_file(
        path_or_fileobj="model.pt",
        path_in_repo="model.pt",
        repo_id="bathrobe/my-gpt2"
    )
    print("model uploaded to hf")

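The learning rate schedule in the script is easy to sanity-check on its own, without a GPU or any of the training code. What follows is a minimal standalone sketch: the constants and the body of `get_lr` are copied from the listing above, while the `midpoint` variable and the particular steps it probes are assumptions chosen only for illustration.

```python
import math

# constants copied from the training script above
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 715
max_steps = 19073

def get_lr(it):
    # linear warmup up to max_lr
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # past the end of the schedule, stay at the floor
    if it > max_steps:
        return min_lr
    # cosine decay from max_lr down to min_lr in between
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

# hypothetical probe points: start, end of warmup, middle of decay, last step
midpoint = (warmup_steps + max_steps) // 2
for it in (0, warmup_steps - 1, warmup_steps, midpoint, max_steps - 1):
    print(f"step {it:5d}: lr = {get_lr(it):.4e}")
```

The warmup reaches `max_lr` on its last step (714) and the decay starts from the same value, so there is no jump at the boundary. At the midpoint of the decay the rate sits halfway between `max_lr` and `min_lr`.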