I recently finished Andrej Karpathy's one-two punch of GPT-from-scratch tutorials: first his pico-sized GPT trained on Tiny Shakespeare with character-level tokens, then the heavier-duty reproduction of GPT-2 124M, trained on 10B tokens of educational data with a variety of production-grade optimizations.
Both were a treat to study and a gift to world knowledge. The first one was mostly fun and gentle, while the second one (the subject of this post) felt like a true rite of passage.
I've spent the past couple days banging my head on the desk while my training runs crashed and overfit, all while my expensive GPU servers burned the dollars away. Was it worth it? Learning is always worth it.
How else could we taste such sweet, sweet victory? Here you see my training run complete and my GPT-2 validation loss at the same levels as Karpathy's.
This crucible reminds me of my experiences with Nand2Tetris, back when I was studying the teachyourselfcs.com learning track (real ambrosia for my fellow autodidacts). If you're unfamiliar, Nand2Tetris starts with the student writing code for elementary logic gates (Not And, or NAND, among them) and proceeds all the way to coding Tetris in a high-level language.
For months after work I'd try to get my CPU passing tests, my tiny computer working, my assembly code translating into binary. It was often extremely frustrating trying to get these things to work and not knowing what to do.
I don't know how useful building a computer from first principles ended up being. While a few large and important intuitions remain, I only remember small parts of the overall Nand2Tetris pipeline. But what's important in rites of passage is not quantifiable, and the knowledge gained is non-technical: having endured and stood victorious over its challenges gave me an important confidence in the way I've approached the computer ever since.
In the same way, I don't think I'm moving into a career of pretraining LLMs, but getting up close and personal with optimized training runs (at least, the 2019 version) feels like I've proven something important to myself.
Anyway, here's the model itself! Amazing how easy it is to embed this Hugging Face Space. You'll notice the outputs are very bad—this was cutting edge only 6, 7 years ago. We are alive in a time of miracles.
One of the main failure modes is that it repeats itself very easily, which it will probably do with your prompt.
Probably the most useful way to think about these outputs is to look at this repetitive nonsense and imagine it like a puppyslugged TV-static-esque image of pure noise. Out of this noise, as these parameters scale up 100s and 1000s and 10,000s of times, suddenly a picture emerges. But the fundamentals stay the same. There is something miraculous about this.
Throughout this blog post, we'll be plucking out sections of my reproduction of Karpathy's training script. I don't intend to give a full tutorial, but instead want to point out especially weird, challenging, or interesting parts of the code.
And the first, here on lines 122-128: the chewy center of a large language model.
The story goes like this: a string of text is tokenized, and those tokens become a collection of learned embeddings (wte). Fascinatingly, in GPT-2, information about where these embeddings sit in the input sequence is passed on via totally learned weights added to the embeddings (wpe).
In earlier transformers, these position encodings were hardcoded in (for me) very hard to understand sinusoidal patterns. In later models, these position encodings "rotate" the embeddings, which is also confusing. But wholly learned weights that figure out how to encode positional information? Weird and cool.
The other cool thing is that Karpathy reuses the same embeddings in the final classifier head of the model as are used to embed the tokens (wte). This is represented by a red dashed line connecting the wte and lm_head embeddings. This means the model only learns a single understanding of how its vocabulary of tokens maps to the high dimensional space in which the next token is predicted. A single translator, from the world of humankind's language to the private, incomprehensible language of the machine.
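Here's a minimal sketch of that weight sharing, using toy sizes I made up (a 10-token vocabulary and 4-dimensional embeddings) instead of GPT-2's real ones. The variable names mirror the listing, but this is an illustration rather than the training code itself.

```python
import torch
import torch.nn as nn

vocab_size, n_embd = 10, 4  # toy sizes, assumed for illustration only

wte = nn.Embedding(vocab_size, n_embd)               # token id -> vector
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # vector -> logits over tokens

# nn.Linear stores its weight as (out_features, in_features) = (vocab_size, n_embd),
# the same shape as the embedding table, so a single Parameter can serve both roles.
wte.weight = lm_head.weight

shared = wte.weight is lm_head.weight
unique_params = {id(p) for p in list(wte.parameters()) + list(lm_head.parameters())}
print(f"shared: {shared}, unique parameter tensors: {len(unique_params)}")
```

Because both modules point at the same tensor, a gradient step that improves the classifier head also moves the input embeddings, and vice versa.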
In line 125, we initialize a list of blocks. What's in them? Why, the chewy center of the chewy center of the LLM.
The simplicity of this system (self-attention, feed forward, rinse and repeat) is why the original paper was called "Attention Is All You Need." That causal self-attention mechanism in self.attn, scaled up to billions of parameters, is far and away the most complex component of the model, and yet it is simpler than many of its competitors at the time. Despite this simplicity, it ended up being the one true architecture for the breakthroughs we've witnessed in the past few years.
One thing I found particularly elegant is the way the residual connections are implemented in the forward pass. Notice that the forward pass simply calls the functions initialized in the constructor. The only novelty they introduce is in adding the output of each mechanism to the input itself. That simple x + represents such a valuable idea, as it saves the model from the effort of reinventing the representation all by itself in each new layer: every layer only has to learn an adjustment to what came before.
Though it concerns a totally different architecture, my introduction to residual connections was in Serena Yeung's course on CNNs and ResNet. I found the explanation super clear and nice.
Following the most exciting section of the Block to its definition, we see the scariest, most abstract portion of the architecture: causal self-attention.
The best way to understand what's happening is to study Karpathy's preceding video, in which he implemented self-attention in less optimized code. And if by chance that feels too far in the deep end for you, I made an annotated bibliography of self-taught ML from total scratch, with a lot of time spent on transformers.
The thing that's unique—and that took me a long time to figure out—is how the weights for each attention head are stacked, then rearranged, to take maximum advantage of the GPU's parallelism.
The hardest part to grok is the tensor reshaping. In the forward pass, the code uses .view() and .transpose() to move the number of attention heads (n_head) into the second dimension. Why? This sets up the Query and Key tensors for a highly optimized batched matrix multiplication. By shaping them as (Batch, Heads, Sequence, Features), PyTorch can treat each attention head as a separate problem to solve in parallel, massively speeding up the core computation. This works because a batched matmul treats every dimension before the last two as a batch dimension.
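To convince myself the reshape really does compute "one attention problem per head," here is a small sketch with toy sizes (all assumed, nothing like the real 12 heads of 64 features). It compares the batched version against an explicit loop over heads; it only covers the raw Q·K scores, not the masking, softmax, or the fused kernel the listing actually calls.

```python
import torch

B, T, n_head, hs = 2, 5, 3, 4  # toy sizes, assumed for illustration only
C = n_head * hs

q = torch.randn(B, T, C)
k = torch.randn(B, T, C)

# (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
qh = q.view(B, T, n_head, hs).transpose(1, 2)
kh = k.view(B, T, n_head, hs).transpose(1, 2)

# matmul treats the leading (B, nh) dims as batch dims:
# one (T, hs) @ (hs, T) product per batch element per head
scores = qh @ kh.transpose(-2, -1)  # (B, nh, T, T)

# the same scores computed one head at a time, for comparison
loop = torch.stack(
    [
        q[:, :, h * hs:(h + 1) * hs] @ k[:, :, h * hs:(h + 1) * hs].transpose(1, 2)
        for h in range(n_head)
    ],
    dim=1,
)
matches = torch.allclose(scores, loop, atol=1e-5)
print(tuple(scores.shape), matches)
```

Head h simply owns the contiguous slice of channels from h * hs to (h + 1) * hs, which is why a plain .view() is enough to carve the heads out.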
The other big optimization is a little easier to understand: instead of computing Q, K, and V with three separate projections, their weights are stacked into one mega linear layer on line 38, and its output is split back into Q, K, and V. One large matrix multiplication is simply faster than three smaller ones.
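A quick sketch of why the stacked projection is equivalent to three separate ones, again with toy sizes I picked for readability. The names c_attn and split follow the listing; the "separate" version is my own reconstruction for comparison, not something from the lecture code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C = 2, 5, 8  # toy sizes, assumed for illustration only

c_attn = nn.Linear(C, 3 * C)  # one fused projection for Q, K and V
x = torch.randn(B, T, C)

# fused: one matmul, then split the last dimension into three chunks of size C
q, k, v = c_attn(x).split(C, dim=2)

# separate: three matmuls using the corresponding row-slices of the same weight
W, b = c_attn.weight, c_attn.bias  # W: (3C, C), b: (3C,)
q_sep = x @ W[:C].T + b[:C]
k_sep = x @ W[C:2 * C].T + b[C:2 * C]
v_sep = x @ W[2 * C:].T + b[2 * C:]

same = all(
    torch.allclose(fused, sep, atol=1e-5)
    for fused, sep in [(q, q_sep), (k, k_sep), (v, v_sep)]
)
print(same)
```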
After dialing in the basic architecture, Karpathy spends much of the 4 hours of lecture steadily speeding up training. This section is full of tips, technical details, and weird quirks, and I found it exciting to witness. On the right sidebar is a graph of my original logs from each training run, as we walked through reducing unnecessary precision in the computation and gradually decreasing the time it took to process our batches.
We get a baseline with the CPU, then head to the GPU. First, we use full 32-bit precision, then the fancier TensorFloat-32 (TF32) format, which lops off some bits of the mantissa (the least significant bits of each value). Then, we use a 16-bit format (bfloat16 in the lecture) that lops off even more mantissa bits.
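To get a feel for what "lopping off mantissa bits" costs, here is a pure-Python sketch that takes a float32 and zeroes out its low mantissa bits. The bit widths (23 mantissa bits for fp32, 10 for TF32, 7 for bfloat16) are the standard ones; the truncation itself is my simplification, since real hardware generally rounds rather than truncates, and the helper name is made up.

```python
import struct

def truncate_fp32(x: float, mantissa_bits: int) -> float:
    """Keep the sign, the 8 exponent bits, and the top `mantissa_bits`
    of float32's 23 mantissa bits; zero out the rest (truncation, not rounding)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    drop = 23 - mantissa_bits
    bits &= ~((1 << drop) - 1) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]

x = 0.1
fp32 = truncate_fp32(x, 23)  # plain float32: nothing dropped
tf32 = truncate_fp32(x, 10)  # TF32 keeps 10 mantissa bits
bf16 = truncate_fp32(x, 7)   # bfloat16 keeps 7 mantissa bits

for name, val in [("fp32", fp32), ("tf32", tf32), ("bf16", bf16)]:
    print(f"{name}: {val:.10f}  abs error = {abs(x - val):.2e}")
```

All three formats keep the same 8 exponent bits, so the range of representable values stays the same; only the fine detail gets coarser.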
And when it's time for what Karpathy calls "the heavy artillery," we compile the PyTorch model with torch.compile, which fuses a bunch of related operations into single kernels so that intermediate results don't have to make round trips to the GPU's memory. This was pretty spellbinding to see.
Another cool quirk that isn't shown: Karpathy also (in a mischievous tone) mentions that models just tend to like numbers that are powers of 2, or at least divisible by large powers of 2. So we pad the vocabulary from 50257 up to 50304 (a multiple of 128), keep our batch sizes and sequence lengths at powers of 2, and so on. It's kind of silly, and yet profound at the same time. The models want to learn, but to help them, you should speak in the language of Computer.
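The arithmetic behind the padded vocabulary is easy to check. In this sketch the helper names are mine, and rounding up to a multiple of 128 is an assumption I chose because it reproduces the 50304 used in the listing.

```python
def pad_to_multiple(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple

def largest_power_of_two_divisor(n: int) -> int:
    """Largest power of 2 dividing n (n & -n isolates the lowest set bit)."""
    return n & -n

vocab = 50257
padded = pad_to_multiple(vocab, 128)

print(padded)                                # padded vocabulary size
print(largest_power_of_two_divisor(vocab))   # 50257 is odd
print(largest_power_of_two_divisor(padded))  # padded size divides cleanly
```

The extra 47 token ids never appear in the data, so the model just learns to give them vanishing probability.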
For a long section after the first speed-ups, Karpathy investigates a variety of ways we can be totally sure we're reproducing GPT-2 accurately. Because there's not as much public information about GPT-2's hyperparameters, Karpathy sometimes borrows GPT-3's hyperparameters instead.
This leads to some tricky new concepts such as gradient accumulation. Why must the gradients be accumulated? GPT-3's batch size was a positively gargantuan number. And since there's no way we could fit all that in memory for one pass, we loop through a bunch of passes until we collect as many gradients as we would've if we had the tens of thousands of GPUs OpenAI did when training GPT-3. That is, we "accumulate" gradients until a single optimizer step is done on a massive amount of data.
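The bookkeeping is just division. This sketch reuses the numbers from the listing (a 524288-token batch, B = 64, T = 1024); the function name and the list of GPU counts are made up for illustration.

```python
def grad_accum_steps(total_batch_tokens: int, B: int, T: int, world_size: int) -> int:
    """How many forward/backward passes each GPU runs before one optimizer step."""
    tokens_per_micro_step = B * T * world_size  # tokens consumed by one pass on all GPUs
    assert total_batch_tokens % tokens_per_micro_step == 0, "batch must divide evenly"
    return total_batch_tokens // tokens_per_micro_step

total = 524288  # 2**19 tokens per optimizer step
steps_by_gpus = {
    gpus: grad_accum_steps(total, B=64, T=1024, world_size=gpus)
    for gpus in (1, 2, 4, 8)
}
for gpus, steps in steps_by_gpus.items():
    print(f"{gpus} GPU(s): {steps} micro-steps per optimizer step")
```

With fewer GPUs you simply loop more times before calling the optimizer; the number of tokens behind each weight update stays the same.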
I found out how useful this is when I screwed it up and was somehow updating weights after an order of magnitude fewer training examples than intended. The result: ghastly, horrifying overfitting (2.5 train loss, 5ish val loss; very, very bad).
These reproduction challenges were pretty painstaking and didn't quite light my fire, so we're going to move on to figuring out DDP, distributed data parallelism: a crucial part of training on multiple GPUs at once.
DDP is a distributed training technique that allows us to train models on multiple GPUs or machines. However, it can be tricky to get right.
While until now we'd been running the code as a plain Python script, we now need to launch it with a special terminal command, torchrun, which initializes the multi-process environment. For example, with x being the number of processes to launch (one per GPU):
torchrun --nproc_per_node=x main.py
One of the biggest changes we make in our code to hook our training up to multiple GPUs at once is in the DataLoaderLite. Each process is identified by its 'rank', and each rank reads a different slice of the data so that no two GPUs train on the same batch. Note the use of master_process to ensure only one of the processes prints logs and handles other one-person jobs.
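Here is a stripped-down sketch of that striding pattern in pure Python, with tiny made-up sizes. It only tracks where each rank's batches start inside a single shard; the real DataLoaderLite also loads tokens and rolls over to the next shard, which I've left out, and the function name is hypothetical.

```python
def batch_starts(num_tokens: int, B: int, T: int, rank: int, world_size: int) -> list[int]:
    """Start offsets of the batches one rank reads from a single shard."""
    pos = B * T * rank  # each rank begins one batch further in
    starts = []
    while pos + B * T + 1 <= num_tokens:  # +1 because targets are shifted by one token
        starts.append(pos)
        pos += B * T * world_size  # jump over the batches owned by the other ranks
    return starts

rank0 = batch_starts(num_tokens=100, B=2, T=4, rank=0, world_size=2)
rank1 = batch_starts(num_tokens=100, B=2, T=4, rank=1, world_size=2)
print(rank0)
print(rank1)
```

The two ranks interleave like teeth on a zipper: together they cover the shard, and neither ever reads a batch the other one owns.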
The other one is in the training loop, where we use PyTorch's dist module to combine the loss values from all GPUs. This is done with the all_reduce function, which sums the values across all GPUs and divides by the number of GPUs, so every process ends up holding the average loss.
All of this was pretty amazing, but also kind of scary—to see this code in action requires you to be renting multiple GPUs while SSH'd into a server. So when it breaks (which happened a lot for me), you'll be on the clock and bleeding money.
As you'll know if you spend any time reading news about AI, training models on lots of GPUs is very expensive. No one but the extraordinarily wealthy owns the high end GPUs used for training. Mortals rent from the cloud, and I chose to use Prime Intellect, an aggregator of GPU cloud servers that picks out the cheapest ones available.
But by no means was the process cheap! This is mostly because I kept screwing up. SSHing into remote servers for long processes like the training (which took about 3 hours) is something I didn't have much experience in. I learned some valuable stuff: use tmux to keep processes running even if you disconnect! I also wrote some checkpointing code that saved model checkpoints and reloaded them whenever training crashed.
All told though, this was one of those confidence-boosting activities that made a simpler, saner project (fine-tuning a small open-source LLM with RL someday) feel a lot more approachable. If you know what you're doing you will not spend as much as I did. I wrote this post to give the pain some meaning.
Hope you enjoyed exploring Karpathy's 4 hour tutorial masterpiece with me! Next time I'm in the mood for deep ML study, I think I'll follow along with Raschka's guide to instruction-tuning a from-scratch model like this one. If you enjoyed, follow me on LinkedIn.
1from huggingface_hub import HfApi, create_repo, upload_file
2from torch.distributed import init_process_group, destroy_process_group
3from torch.nn.parallel import DistributedDataParallel as DDP
4import time
5import torch.distributed as dist
6from datetime import timedelta
7from dataclasses import dataclass
8import inspect
9import math
10import torch
11import torch.nn as nn
12from torch.nn import functional as F
13from torch.optim import optimizer
14import os
15
16from hellaswag import render_example, iterate_examples
17import tiktoken
18enc = tiktoken.get_encoding("gpt2")
19# more forgiving timeouts
20os.environ["NCCL_TIMEOUT"] = "7200" # 2hr
21os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" # disable async error checking
22os.environ["NCCL_BLOCKING_WAIT"] = "1" # more stable, slightly slower
23os.environ["NCCL_DEBUG"] = "WARN" # reduce debug spam
24os.environ["NCCL_IB_DISABLE"] = "1" # disable infiniband if causing issues
25os.environ["NCCL_P2P_DISABLE"] = "1" # disable p2p if causing issues
26# TCP store timeout for distributed init
27os.environ["TORCH_DISTRIBUTED_INIT_TIMEOUT"] = "1800" # 30 minutes
28os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
29os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
30
31# ---------------------------------------------------------------
32class CausalSelfAttention(nn.Module):
33
34 def __init__(self, config):
35 super().__init__()
36 assert config.n_embd % config.n_head == 0
37 # key, query, value projections for all heads, but in a batch
38 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
39 # output projection
40 self.c_proj = nn.Linear(config.n_embd, config.n_embd)
41 self.c_proj.NANOGPT_SCALE_INIT = 1 # normalize residual stream
42 # regularization
43 self.n_head = config.n_head
44 self.n_embd = config.n_embd
45 # this is a mask, not a bias, but following openai naming conventions
46 self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
47 .view(1, 1, config.block_size, config.block_size))
48
49 def forward(self, x):
50 B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - size returns tuple of shapes
51 # calculate query, key, values for all heads in batch and move head forward to be the batch dim
52 # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
53 # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
54 qkv = self.c_attn(x)
55 q, k, v = qkv.split(self.n_embd, dim=2)
56 # these 3 view operations move the heads to the second dimension so they
57 # are computed in parallel. this is the internal of pt: first 2 dimensions
58 # are auto computed in parallel
59 k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60 q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
61 v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
62
63 y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
64
65 y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
66 # output projection
67 y = self.c_proj(y)
68 return y
69
70
71
72class MLP(nn.Module):
73
74 def __init__(self,config):
75 super().__init__()
76 self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
77 # gaussian error linear units
78 # a slightly smoother relu
79 # transformers seem to prefer smooth activations over sharper ones (like relu)
80 self.gelu = nn.GELU(approximate='tanh') # historical quirk to use approximate tanh
81 self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
82 self.c_proj.NANOGPT_SCALE_INIT = 1 # normalize residual stream
83
84
85 def forward(self, x):
86 x = self.c_fc(x)
87 x = self.gelu(x)
88 x = self.c_proj(x)
89 return x
90
91class Block(nn.Module):
92
93 def __init__(self,config):
94 super().__init__()
95 self.ln_1 = nn.LayerNorm(config.n_embd)
96 self.attn = CausalSelfAttention(config)
97 self.ln_2 = nn.LayerNorm(config.n_embd)
98 self.mlp = MLP(config)
99
100 def forward(self, x):
101 x = x + self.attn(self.ln_1(x))
102 x = x + self.mlp(self.ln_2(x))
103 return x
104
105
106@dataclass
107class GPTConfig:
108 block_size: int = 1024 # max sequence length
109 vocab_size: int = 50257 # number of toks: 50k bpe merges, 256 bytes tokens, 1 <eos> token
110 # 50257 is ugly. odd. we want numbers divisible by big powers of 2.
111 # 50304 is divisible by 8, 16, ..., 128. better.
112 n_layer: int = 12
113 n_head: int = 12
114 n_embd: int = 768
115 bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
116
117class GPT(nn.Module):
118 def __init__(self,config):
119 super().__init__()
120 self.config = config
121
122 self.transformer = nn.ModuleDict(dict(
123 wte = nn.Embedding(config.vocab_size, config.n_embd),
124 wpe = nn.Embedding(config.block_size, config.n_embd),
125 h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # h stands for hidden
126 ln_f = nn.LayerNorm(config.n_embd), # final layer norm
127 ))
128 self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # final classifier
129
130 # weight sharing
131 self.transformer.wte.weight = self.lm_head.weight
132
133 # init params
134 self.apply(self._init_weights)
135
136 def _init_weights(self, module):
137 if isinstance(module, nn.Linear):
138 std = 0.02
139 if hasattr(module, 'NANOGPT_SCALE_INIT'):
140 std *= (2 * self.config.n_layer) ** -0.5 # 1 over square root of num_layers,
141 # keeps residual stream from ballooning
142 # it's 2x bc attn and ffn both add to the residual pathway
143 torch.nn.init.normal_(module.weight, mean=0.0, std=std)
144 if module.bias is not None:
145 torch.nn.init.zeros_(module.bias)
146 elif isinstance(module, nn.Embedding):
147 torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
148
149 @classmethod
150 def from_pretrained(cls, model_type, override_args=None):
151 assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
152 override_args = override_args or {} # default to empty dict
153 # only dropout can be overridden see more notes below
154 assert all(k == 'dropout' for k in override_args)
155 from transformers import GPT2LMHeadModel
156 print("loading weights from pretrained gpt: %s" % model_type)
157
158 # n_layer, n_head and n_embd are determined from model_type
159 config_args = {
160 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
161 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
162 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
163 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
164 }[model_type]
165 print("forcing vocab_size=50257, block_size=1024, bias=True")
166 config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
167 config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
168 config_args['bias'] = True # always True for GPT model checkpoints
169 # we can override the dropout rate, if desired
170 if 'dropout' in override_args:
171 print(f"overriding dropout rate to {override_args['dropout']}")
172 config_args['dropout'] = override_args['dropout']
173 # create a from-scratch initialized minGPT model
174 config = GPTConfig(**config_args)
175 model = GPT(config)
176 sd = model.state_dict()
177 sd_keys = sd.keys()
178 sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
179
180 # init a huggingface/transformers model
181 model_hf = GPT2LMHeadModel.from_pretrained(model_type)
182 sd_hf = model_hf.state_dict()
183
184 # copy while ensuring all of the parameters are aligned and match in names and shapes
185 sd_keys_hf = sd_hf.keys()
186 sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
187 sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
188 transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
189 # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
190 # this means that we have to transpose these weights when we import them
191 assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
192 for k in sd_keys_hf:
193 if any(k.endswith(w) for w in transposed):
194 # special treatment for the Conv1D weights we need to transpose
195 assert sd_hf[k].shape[::-1] == sd[k].shape
196 with torch.no_grad():
197 sd[k].copy_(sd_hf[k].t())
198 else:
199 # vanilla copy over the other parameters
200 assert sd_hf[k].shape == sd[k].shape
201 with torch.no_grad():
202 sd[k].copy_(sd_hf[k])
203
204 return model
205
206 def forward(self, idx, targets=None):
207 B, T = idx.size()
208 assert T <= self.config.block_size, f"Cannot fwd seq of length {T}, block size {self.config.block_size}"
209
210 pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
211 pos_emb = self.transformer.wpe(pos) # pos embs of shape (T, n_embd)
212 tok_emb = self.transformer.wte(idx) # token embs of shape (B, T, n_embd)
213 x = tok_emb + pos_emb
214
215 for block in self.transformer.h:
216 x = block(x)
217
218 x = self.transformer.ln_f(x)
219 logits = self.lm_head(x) # (B,T, vocab_size)
220 loss = None
221 if targets is not None:
222
223 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
224 return logits, loss
225
226 def configure_optimizers(self, weight_decay, learning_rate, device_type): # 2:31:50
227 # this weight decay forces info to go across many smaller channels instead of one big one
228 # start with all of the candidate parameters (that require grad)
229 param_dict = {pn: p for pn, p in self.named_parameters()}
230 param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
231 # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
232 # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
233 decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] # decay weights and sometimes embs
234 nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] # no decay biases and layernorms
235 optim_groups = [
236 {'params': decay_params, 'weight_decay': weight_decay},
237 {'params': nodecay_params, 'weight_decay': 0.0}
238 ]
239 num_decay_params = sum(p.numel() for p in decay_params)
240 num_nodecay_params = sum(p.numel() for p in nodecay_params)
241 if master_process: # this is so if u have gpu clusters it doesn't print 8 times
242 print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
243 print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
244 # Create AdamW optimizer and use the fused version if it is available
245 fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
246 use_fused = fused_available and device_type == "cuda"
247 if master_process:
248 print(f"using fused AdamW: {use_fused}") # fused is a newer performance optimization
249 optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
250 return optimizer
251
252# --- setting up DDP (distributed data parallel)
253# torchrun command sets the env variables RANK, LOCAL_RANK and WORLD_SIZE
254
255ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? must be a bool: `if ddp:` is used later, and rank 0 would be falsy as an int
256if ddp:
257 assert torch.cuda.is_available()
258 # Retry DDP initialization with exponential backoff
259 max_retries = 5
260 for attempt in range(max_retries):
261 try:
262 if attempt > 0:
263 if int(os.environ['RANK']) == 0:
264 print(f"DDP init attempt {attempt + 1}/{max_retries}")
265 time.sleep(2 ** attempt) # exponential backoff: 2, 4, 8, 16 seconds
266 init_process_group(backend='nccl', timeout=timedelta(seconds=1800))
267 break
268 except Exception as e:
269 if attempt == max_retries - 1:
270 print(f"Failed to initialize DDP after {max_retries} attempts: {e}")
271 raise e
272 else:
273 print(f"DDP init attempt {attempt + 1} failed: {e}, retrying...")
274 if dist.is_initialized():
275 dist.destroy_process_group()
276
277 ddp_rank = int(os.environ['RANK'])
278 ddp_local_rank = int(os.environ['LOCAL_RANK'])
279 ddp_world_size = int(os.environ['WORLD_SIZE'])
280 device = f"cuda:{ddp_local_rank}"
281 torch.cuda.set_device(device)
282 master_process = ddp_rank == 0
283else:
284 ddp_rank = 0
285 ddp_local_rank = 0
286 ddp_world_size = 1
287 master_process = True
288 device = "cpu"
289 if torch.cuda.is_available():
290 device = "cuda"
291 elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
292 device = "mps"
293 print(f"using device: {device}")
294torch.set_float32_matmul_precision('high')
295torch.manual_seed(1337)
296if torch.cuda.is_available():
297 torch.cuda.manual_seed(1337)
298
299# -----------------------------------------------------------------------------
300# helper function for HellaSwag eval
301# takes tokens, mask, and logits, returns the index of the completion with the lowest loss
302
303def get_most_likely_row(tokens, mask, logits):
304 # evaluate the autoregressive loss at all positions
305 shift_logits = (logits[..., :-1, :]).contiguous()
306 shift_tokens = (tokens[..., 1:]).contiguous()
307 flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
308 flat_shift_tokens = shift_tokens.view(-1)
309 shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
310 shift_losses = shift_losses.view(tokens.size(0), -1)
311 # now get the average loss just for the completion region (where mask == 1), in each row
312 shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
313 masked_shift_losses = shift_losses * shift_mask
314 # sum and divide by the number of 1s in the mask
315 sum_loss = masked_shift_losses.sum(dim=1)
316 avg_loss = sum_loss / shift_mask.sum(dim=1)
317 # now we have a loss for each of the 4 completions
318 # the one with the lowest loss should be the most likely
319 pred_norm = avg_loss.argmin().item()
320 return pred_norm
321
322
323torch.manual_seed(1337)
324if torch.cuda.is_available():
325 torch.cuda.manual_seed(1337)
326
327# --- gradient accumulation 2:36:00 ---
328total_batch_size = 524288 # 2**19, ~0.5M tokens per batch, as for gpt-3 small in its paper
329
330B = 64
331T = 1024
332assert total_batch_size % (B * T * ddp_world_size) == 0 # make sure total batch size is divisible by B*T * world_size (number of total gpus)
333grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
334if master_process:
335 print(f"total desired batch size: {total_batch_size}")
336 print(f"=> calculated grad accum steps: {grad_accum_steps}")
337
338
339
340# -------------------------------
341# dataloader
342
343# -----------------------------------------------------------------------------
344import numpy as np
345import tiktoken
346
347def load_tokens(filename):
348 try:
349 npt = np.load(filename)
350 # npt = npt.astype(np.int32) # Convert uint16 to int32 for torch compatibility
351 ptt = torch.tensor(npt, dtype=torch.long)
352 return ptt
353 except Exception as e:
354 print(f"Error loading {filename}: {e}")
355 # Try to peek at the file content
356 with open(filename, 'rb') as f:
357 header = f.read(16)
358 print(f"File header (first 16 bytes): {header}")
359 raise e
360
361class DataLoaderLite:
362 def __init__(self, B, T, process_rank, num_processes, split):
363 self.B = B
364 self.T = T
365 self.process_rank = process_rank
366 self.num_processes = num_processes
367 assert split in {'train', 'val'}
368
369 # get the shard filenames
370 data_root = "../data/edu_fineweb10B"
371 shards = os.listdir(data_root)
372 shards = [s for s in shards if split in s]
373 shards = sorted(shards)
374 shards = [os.path.join(data_root, s) for s in shards]
375 self.shards = shards
376 assert len(shards) > 0, f"no shards found for split {split}"
377 if master_process:
378 print(f"found {len(shards)} shards for split {split}")
379
380 # state, init at shard zero
381 self.current_shard = 0
382 self.tokens = load_tokens(self.shards[self.current_shard])
383 self.current_position = self.B * self.T * self.process_rank
384 self.reset()
385
386 def reset(self):
387 self.current_shard = 0
388 self.tokens = load_tokens(self.shards[self.current_shard])
389 self.current_position = self.B * self.T * self.process_rank
390
391 def next_batch(self):
392 B, T = self.B, self.T
393 buf = self.tokens[self.current_position : self.current_position+B*T+1]
394 x = (buf[:-1]).view(B, T) # inputs
395 y = (buf[1:]).view(B, T) # targets
396 # advance the position in the tensor
397 self.current_position += B * T * self.num_processes
398 # if loading the next batch would be out of bounds, advance to next shard
399 if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
400 self.current_shard = (self.current_shard + 1) % len(self.shards)
401 self.tokens = load_tokens(self.shards[self.current_shard])
402 self.current_position = B * T * self.process_rank
403 return x, y
404
405
406train_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train")
407val_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val")
408
409
410# -------------------------------
411# run the training loop
412
413
414
415num_return_sequences = 5
416max_length = 30
417
418# model = GPT.from_pretrained('gpt2')
419# overriding the ugly vocab size with 50304, which is divisible by 128 (not itself a power of 2)
420# when doing distributed training, WORLD_SIZE models get created now
421# they all have the same seed so they're all identical (2:59)
422model = GPT(GPTConfig(vocab_size=50304))
423model.to(device)
424use_compile = False # torch.compile interferes with HellaSwag eval and Generation. TODO fix
425if use_compile:
426 model = torch.compile(model)
427if ddp:
428 model = DDP(model, device_ids=[ddp_local_rank])
429raw_model = model.module if ddp else model # this contains the configure optimizers func we wanna call
430# -- logging --
431# create the log directory we will write checkpoints to and log to
432log_dir = "log"
433os.makedirs(log_dir, exist_ok=True)
434log_file = os.path.join(log_dir, f"log.txt")
435with open(log_file, "w") as f: # open for writing to clear the file
436 pass
437
438# --- cosine decay lr ----
439max_lr = 6e-4 # gpt-3 small LR per their paper (gpt-2 doesn't specify)
440min_lr = max_lr * 0.1
441warmup_steps = 715 # these two hps map to the gpt-3 schedule, adapted for fineweb10b
442max_steps = 19073
443def get_lr(it):
444 # 1) linear warmup for warmup_steps steps
445 if it < warmup_steps:
446 return max_lr * (it+1) / warmup_steps
447 # 2) if it > max_steps, return min lr
448 if it > max_steps:
449 return min_lr
450 # 3) in between, use cosine decay down to min learning rate
451 decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
452 assert 0 <= decay_ratio <= 1
453 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
454 return min_lr + coeff * (max_lr - min_lr)
455# --- cosine decay lr end ----
456
457
458device_type = 'cuda' if 'cuda' in device else 'cpu'
459
460# checkpoint loading logic
461initial_iter = 0
462if os.path.exists(log_dir):
463 checkpoints = [f for f in os.listdir(log_dir) if f.startswith('model_')]
464 if checkpoints:
465 latest = max(checkpoints, key=lambda x: int(x.split('_')[1].split('.')[0]))
466 checkpoint_path = os.path.join(log_dir, latest)
467 checkpoint = torch.load(checkpoint_path, map_location=device)
468 raw_model.load_state_dict(checkpoint['model'])
469 initial_iter = checkpoint['step'] + 1
470 if master_process:
471 print(f"resuming from step {initial_iter}")
472
473optimizer = raw_model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device_type)
474
# load optimizer state if checkpoint exists
if initial_iter > 0:
    try:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if 'rng_state' in checkpoint:
            rng_state = checkpoint['rng_state']
            # torch.ByteTensor means a CPU uint8 tensor; map_location may have moved the state to the GPU
            if not isinstance(rng_state, torch.ByteTensor):
                if master_process:
                    print(f"Warning: Converting RNG state from {type(rng_state)} to ByteTensor")
                # Convert to a CPU uint8 tensor, which is what set_rng_state expects
                if isinstance(rng_state, torch.Tensor):
                    rng_state = rng_state.to(device='cpu', dtype=torch.uint8)
            torch.set_rng_state(rng_state)
        if checkpoint.get('cuda_rng_state') is not None:
            cuda_rng_state = checkpoint['cuda_rng_state']
            if not isinstance(cuda_rng_state, torch.ByteTensor):
                if master_process:
                    print(f"Warning: Converting CUDA RNG state from {type(cuda_rng_state)} to ByteTensor")
                # Convert to a CPU uint8 tensor (the CUDA RNG state is also passed in as a CPU tensor)
                if isinstance(cuda_rng_state, torch.Tensor):
                    cuda_rng_state = cuda_rng_state.to(device='cpu', dtype=torch.uint8)
            torch.cuda.set_rng_state(cuda_rng_state)
        if 'loader_position' in checkpoint:
            train_loader.current_position = checkpoint['loader_position']
        if 'loader_shard' in checkpoint:
            train_loader.current_shard = checkpoint['loader_shard']
            train_loader.tokens = load_tokens(train_loader.shards[train_loader.current_shard])
        # reset lr scheduler
        for param_group in optimizer.param_groups:
            param_group['lr'] = get_lr(initial_iter)
        if master_process:
            print(f"Successfully loaded checkpoint from step {initial_iter}")
    except Exception as e:
        if master_process:
            print(f"Warning: Failed to load some checkpoint data: {e}")
            print("Continuing with partial checkpoint restore...")

prev_train_loss = None  # most recent train loss on the master process, used by the overfitting check
for step in range(initial_iter, max_steps):
    t0 = time.time()
    last_step = (step == max_steps - 1)
    # occasionally find out the val loss
    if step % 500 == 0 or last_step:
        model.eval()
        val_loader.reset()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x, y = x.to(device), y.to(device)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(x,y)
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            val_loss_val = val_loss_accum.item()
            print(f"step {step}, val loss: {val_loss_val:.4f}")
            # Check for potential overfitting every eval
            # (the old hasattr('prev_train_loss', '__self__') test was always False, so this never ran)
            if step > 500 and prev_train_loss is not None:
                train_val_gap = val_loss_val - prev_train_loss
                if train_val_gap > 1.5:
                    print(f"Warning: Large train/val gap ({train_val_gap:.3f}) - possible overfitting")
            with open(log_file, "a") as f:
                f.write(f"{step} val {val_loss_val:.4f}\n")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # only the master process writes, so the ranks don't race on the same checkpoint file
    if master_process and step > 0 and (step % 1000 == 0 or last_step):
        checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.pt")
        # ensure RNG states are ByteTensors for compatibility
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

        checkpoint = {
            'model': raw_model.state_dict(),
            'config': raw_model.config,
            'step': step,
            'val_loss': val_loss_accum.item(),
            'optimizer': optimizer.state_dict(),
            'rng_state': rng_state.byte() if not isinstance(rng_state, torch.ByteTensor) else rng_state,
            'cuda_rng_state': cuda_rng_state.byte() if cuda_rng_state is not None and not isinstance(cuda_rng_state, torch.ByteTensor) else cuda_rng_state,
            'loader_position': train_loader.current_position,
            'loader_shard': train_loader.current_shard,
        }
        torch.save(checkpoint, checkpoint_path)

    # once in a while evaluate hellaswag
    if (step % 500 == 0 or last_step) and (not use_compile):
        num_correct_norm = 0
        num_total = 0
        for i, example in enumerate(iterate_examples("val")):
            # only process examples where i % ddp_world_size == ddp_rank
            if i % ddp_world_size != ddp_rank:
                continue
            # render the example into tokens and labels
            _, tokens, mask, label = render_example(example)
            tokens = tokens.to(device)
            mask = mask.to(device)
            # get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(tokens)
                pred_norm = get_most_likely_row(tokens, mask, logits)
            num_total += 1
            num_correct_norm += int(pred_norm == label)
        # reduce the stats across all processes
        if ddp:
            num_total = torch.tensor(num_total, dtype=torch.long, device=device)
            num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
            dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
            num_total = num_total.item()
            num_correct_norm = num_correct_norm.item()
        acc_norm = num_correct_norm / num_total
        if master_process:
            print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} hella {acc_norm:.4f}\n")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    if ((step > 0 and step % 500 == 0) or last_step) and (not use_compile):
        model.eval()
        num_return_sequences = 4
        max_length = 32
        tokens = enc.encode("Hello, I'm a language model,")
        tokens = torch.tensor(tokens, dtype=torch.long)
        tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
        xgen = tokens.to(device)
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(42 + ddp_rank)
        # keep sampling until each sequence reaches max_length
        while xgen.size(1) < max_length:
            # forward the model to get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(xgen) # (B, T, vocab_size)
                # take the logits at the last position
                logits = logits[:, -1, :] # (B, vocab_size)
                # get the probabilities
                probs = F.softmax(logits, dim=-1)
                # do top-k sampling of 50 (huggingface pipeline default)
                # topk_probs here becomes (B, 50), topk_indices is (B, 50)
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
                # select a token from the top-k probabilities
                # note: multinomial does not demand the input to sum to 1
                ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
                # gather the corresponding indices
                xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
                # append to the sequence
                xgen = torch.cat((xgen, xcol), dim=1)
        # print the generated text
        for i in range(num_return_sequences):
            tokens = xgen[i, :max_length].tolist()
            decoded = enc.decode(tokens)
            print(f"rank {ddp_rank} sample {i}: {decoded}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    # gradient accumulation 2:39:23
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        if ddp:
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1) # should only share grads on last step
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16): # on cuda, bfloat16 needs ampere or newer gpus
            logits, loss = model(x,y)
        loss = loss / grad_accum_steps # 2:44, otherwise losses sum over the accumulated passes
        loss_accum += loss.detach()
        loss.backward()
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG) # average the loss across all gpus
    # final gradient clipping after accumulation
    # 2:18:00
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # lr scheduler is cosine decay (2:22)
    lr = get_lr(step)
    for param_group in optimizer.param_groups: # pt treats params as groups in optimization;
        # set the lr on every group (there can be more than one, e.g. decay vs no-decay)
        param_group['lr'] = lr
    optimizer.step()
    if device_type == 'cuda':
        torch.cuda.synchronize() # this awaits for all kernels to finish
    t1 = time.time()
    dt = (t1 - t0)*1000
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
    tokens_per_sec = tokens_processed / (t1 - t0)
    # keep an eye on gradient norms, signal of problems if anomalies
    if master_process:
        train_loss = loss_accum.item()
        print(f"step {step}, loss: {train_loss:.6f}, lr: {lr:.4e}, norm: {norm:.4f}, dt: {dt:.2f}ms, tok/sec: {tokens_per_sec:.2f}")
        # Store for overfitting detection
        prev_train_loss = train_loss
        with open(log_file, "a") as f:
            f.write(f"{step} train {train_loss:.6f}\n")

if ddp:
    destroy_process_group()

# only master saves & uploads
if master_process:
    # create repo (once) beforehand; upload_file expects it to exist

    # save the raw model (not ddp wrapper)
    torch.save(raw_model.state_dict(), 'model.pt')
    api = HfApi()
    api.upload_file(
        path_or_fileobj="model.pt",
        path_in_repo="model.pt",
        repo_id="bathrobe/my-gpt2"
    )
    print("model uploaded to hf")
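If you want to poke at the learning rate schedule without renting a GPU, here is a minimal standalone sketch of the same warmup-plus-cosine function, using the constants from the script above. The `describe_schedule` helper is a hypothetical name of mine for illustration; it is not part of the training script.

```python
import math

# constants copied from the training script
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 715
max_steps = 19073

def get_lr(it):
    # linear warmup: ramps up to max_lr over the first warmup_steps steps
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # past the end of the schedule: hold at the floor
    if it > max_steps:
        return min_lr
    # in between: cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

def describe_schedule(steps):
    # hypothetical helper: (step, lr) pairs for a handful of probe steps
    return [(s, get_lr(s)) for s in steps]

if __name__ == "__main__":
    probes = [0, warmup_steps - 1, warmup_steps, max_steps // 2, max_steps]
    for s, lr in describe_schedule(probes):
        print(f"step {s:5d}: lr {lr:.4e}")
```

One thing this makes visible: the training loop only visits steps 0 through `max_steps - 1`, so the `it > max_steps` branch is a safety net that a normal run never hits.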