Commit f371351

Add nanogpt examples (#79)
* added nanogpt example
* set torch version >= 2.1.0 in pyproject.toml
* changed message when there are multiple cupti subscribers
* deleted comments in session.py - habitat section
* removed session.py from PR

---------

Co-authored-by: John Calderon <[email protected]>
1 parent d775976 commit f371351

3 files changed: +278 -1

examples/nanogpt/entry_point.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
import numpy as np
import torch
from torch import nn

from model import GPTConfig, GPT

# Context length (block size) of the model.
block_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

# model
n_layer = 16
n_head = 16
n_embd = 512
dropout = 0.0
vocab_size = 65
bias = False

# AdamW optimizer
learning_rate = 6e-4
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95


# optimizer factory: separate weight-decay and no-decay parameter groups
def configure_optimizer(model, weight_decay, learning_rate, betas):
    param_dict = {pn: p for pn, p in model.named_parameters()}
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

    return optimizer


def deepview_model_provider():
    # model init
    # ---------------------------------------------
    # Enable flash attention
    enable_flash_attention = False
    model_args = dict(
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
        block_size=block_size,
        bias=bias,
        vocab_size=vocab_size,
        dropout=dropout,
        enable_flash_attention=enable_flash_attention,
    )
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    return model.to(device)


def deepview_input_provider(batch_size=48):
    data = np.random.randint(vocab_size, size=(batch_size, block_size + 1))
    x = torch.stack(
        [torch.from_numpy((data[i, :-1]).astype(np.int64)) for i in range(batch_size)]
    )
    y = torch.stack(
        [torch.from_numpy((data[i, 1:]).astype(np.int64)) for i in range(batch_size)]
    )

    return (x.to(device), y.to(device))


def deepview_iteration_provider(model):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, betas=(beta1, beta2)
    )

    def iteration(inputs, targets):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        loss.backward()
        optimizer.step()

    return iteration
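
The three deepview_* providers above are the example's entry points; how the profiler loads them is not part of this diff. For a quick manual smoke test they can also be wired together by hand. A minimal sketch, assuming it is run from examples/nanogpt/ so that model.py is importable; the __main__ driver below is hypothetical and not part of the commit:

# Hypothetical smoke-test driver; not part of the committed example.
if __name__ == "__main__":
    model = deepview_model_provider()                         # GPT moved to `device`
    inputs, targets = deepview_input_provider(batch_size=4)   # random (x, y) token batches
    iteration = deepview_iteration_provider(model)            # closure doing one training step

    for _ in range(3):
        iteration(inputs, targets)                            # forward, backward, optimizer step
    print("ran 3 training iterations on", device)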

examples/nanogpt/model.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
"""
Model Reference: https://github.com/karpathy/nanoGPT
Full definition of a GPT Language Model, all of it in this single file.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
def new_gelu(x):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') if config.enable_flash_attention else False
        if not self.flash:
            # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = new_gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    enable_flash_attention: bool = False
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits
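
GPT.forward above maps a batch of token indices straight to per-position logits; the loss is computed outside the model (see entry_point.py). A minimal standalone sketch of the forward pass; the tiny config values here are illustrative and smaller than the ones the entry point uses:

import torch
from model import GPTConfig, GPT

# Tiny illustrative config (not the example's settings).
config = GPTConfig(block_size=32, vocab_size=65, n_layer=2, n_head=2,
                   n_embd=64, dropout=0.0, bias=False,
                   enable_flash_attention=False)
model = GPT(config)

idx = torch.randint(0, config.vocab_size, (4, config.block_size))  # (batch, sequence) token ids
logits = model(idx)
print(logits.shape)  # torch.Size([4, 32, 65]) -> (batch, sequence, vocab_size)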

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ python = "^3.7"
 pyyaml = "*"
 protobuf = "3.19.6"
 numpy = "^1.15.2"
-torch = ">=1.13.1"
+torch = ">=2.1.0"
 nvidia-ml-py3 = "*"
 toml = "^0.10.2"
 pyRAPL = "^0.2.3"
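
The dependency floor moves from torch 1.13.1 to 2.1.0. For context, CausalSelfAttention in model.py gates its fast path on torch.nn.functional.scaled_dot_product_attention (its comment notes flash attention needs PyTorch >= 2.0). A small illustrative probe, not part of the commit:

import torch

print("torch version:", torch.__version__)
# Same capability check CausalSelfAttention performs before enabling flash attention
print("scaled_dot_product_attention available:",
      hasattr(torch.nn.functional, "scaled_dot_product_attention"))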
