model

```python
import math
from dataclasses import dataclass
from typing import Dict

import torch as tc
import torch.nn as nn
from torch.nn import functional as F

TS_DICT = Dict[str, tc.Tensor]


@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embed: int = 768
    dropout: float = 0.0
    bias: bool = True


class Attention(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.conf = c
        assert c.n_embed % c.n_head == 0
        E = c.n_embed
        s.c_attn = nn.Linear(E, 3 * E, c.bias)
        s.c_proj = nn.Linear(E, E, c.bias)
        s.resid_dropout = nn.Dropout(c.dropout)

    def forward(s, x: tc.Tensor):
        B, T, E = x.shape
        H = s.conf.n_head
        y1: tc.Tensor = s.c_attn(x)
        # split into q, k, v and reshape to (B, H, T, E // H)
        q, k, v = [z.view(B, T, H, E // H).transpose(1, 2) for z in y1.split(E, dim=2)]
        drop = s.conf.dropout if s.training else 0
        # fused causal attention (uses a flash kernel when available)
        y2 = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=True)
        y3 = y2.transpose(1, 2).contiguous().view(B, T, E)
        return s.resid_dropout(s.c_proj(y3))


class MLP(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        E = c.n_embed
        s.c_fc = nn.Linear(E, 4 * E, c.bias)
        s.gelu = nn.GELU()
        s.c_proj = nn.Linear(4 * E, E, c.bias)
        s.dropout = nn.Dropout(c.dropout)

    def forward(s, x):
        return s.dropout(s.c_proj(s.gelu(s.c_fc(x))))


class TransLayer(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.ln_1 = nn.LayerNorm(c.n_embed, bias=c.bias)
        s.attn = Attention(c)
        s.ln_2 = nn.LayerNorm(c.n_embed, bias=c.bias)
        s.mlp = MLP(c)

    def forward(s, x):
        x = x + s.attn(s.ln_1(x))
        return x + s.mlp(s.ln_2(x))


class GPT(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.conf = c
        E = c.n_embed
        s.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(c.vocab_size, E),
                wpe=nn.Embedding(c.block_size, E),
                drop=nn.Dropout(c.dropout),
                h=nn.ModuleList([TransLayer(c) for _ in range(c.n_layer)]),
                ln_f=nn.LayerNorm(E, bias=c.bias),
            )
        )
        s.lm_head = nn.Linear(E, c.vocab_size, bias=False)
        s.transformer.wte.weight = s.lm_head.weight  # weight-tying
        s.apply(s._init_weights)
        # scaled init for residual projections, per GPT-2
        for k, p in s.named_parameters():
            if k.endswith("c_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * c.n_layer))
        print(f"n_params: {s.n_params() / 1e6:.2f}M")

    def n_params(s):
        return sum(p.numel() for p in s.parameters())

    def _init_weights(s, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(s, x: tc.Tensor, y0: tc.Tensor = None):
        B, T = x.shape
        c = s.conf
        assert T <= c.block_size
        tok = s.transformer.wte(x)
        pos = s.transformer.wpe(tc.arange(0, T, dtype=tc.long, device=x.device))
        x = s.transformer.drop(tok + pos)
        for layer in s.transformer.h:
            x = layer(x)
        x = s.transformer.ln_f(x)
        if y0 is not None:
            logits: tc.Tensor = s.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), y0.view(-1), ignore_index=-1
            )
        else:
            # inference: only the last position is needed
            logits = s.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss

    def crop_block_size(s, n):
        assert n <= s.conf.block_size
        s.conf.block_size = n
        wpe = s.transformer.wpe
        wpe.weight = nn.Parameter(wpe.weight[:n])
        for layer in s.transformer.h:
            if hasattr(layer.attn, "bias"):
                layer.attn.bias = layer.attn.bias[:, :, :n, :n]

    @classmethod
    @tc.no_grad()
    def from_pretrained(s, type, drop=0.0):
        assert type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"}
        from transformers import GPT2LMHeadModel

        args = {
            "gpt2": dict(n_layer=12, n_head=12, n_embed=768),  # 124M params
            "gpt2-medium": dict(n_layer=24, n_head=16, n_embed=1024),  # 350M params
            "gpt2-large": dict(n_layer=36, n_head=20, n_embed=1280),  # 774M params
            "gpt2-xl": dict(n_layer=48, n_head=25, n_embed=1600),  # 1558M params
        }[type]
        args.update(dict(vocab_size=50257, block_size=1024, bias=True, dropout=drop))
        m = GPT(GPTConfig(**args))
        sd: TS_DICT = m.state_dict()
        ignore = (".attn.bias", ".attn.masked_bias")
        keys = [k for k in sd if not k.endswith(ignore)]
        m2 = GPT2LMHeadModel.from_pretrained(type)
        sd2: TS_DICT = m2.state_dict()
        keys2 = [k for k in sd2 if not k.endswith(ignore)]
        assert len(keys) == len(keys2)
        # HF uses Conv1D modules, so these weights are transposed relative to nn.Linear
        trans = [
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.c_fc.weight",
            "mlp.c_proj.weight",
        ]
        for k in keys2:
            x = sd2[k].t() if any(k.endswith(w) for w in trans) else sd2[k]
            sd[k].copy_(x)
        return m

    @tc.no_grad()
    def generate(s, x, num, temp=1.0):
        B = s.conf.block_size
        for _ in range(num):
            # crop the running context to block_size
            x = x if x.size(1) <= B else x[:, -B:]
            logits: tc.Tensor = s(x)[0]
            probs = F.softmax(logits[:, -1, :] / temp, dim=-1)
            x2 = tc.multinomial(probs, num_samples=1)
            x = tc.cat((x, x2), dim=1)
        return x
```
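For a quick smoke test of the model block, here is a minimal sketch (not part of the code above): it assumes `transformers` and `tiktoken` are installed, pulls the GPT-2 small weights through `from_pretrained`, and samples a short continuation.

```python
import tiktoken
import torch as tc

# load HF GPT-2 small weights into the GPT class above (prints n_params)
m = GPT.from_pretrained("gpt2").eval()

# GPT-2 BPE encoding via tiktoken
enc = tiktoken.get_encoding("gpt2")
x = tc.tensor([enc.encode("Hello, my name is")], dtype=tc.long)

# append 20 sampled tokens at temperature 0.8 and decode
y = m.generate(x, num=20, temp=0.8)
print(enc.decode(y[0].tolist()))
```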
trainer (not fully tested -- I don't have a GPU)

```python
import os
from contextlib import nullcontext

import numpy as np
import torch as tc
import torch.nn as nn
import wandb
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP


class LLMTrainer:
    # ======== setup ====================
    backend = "nccl"
    grad_acc_steps = 40
    w_decay = 0.1
    lr = 6e-4
    betas = [0.9, 0.95]
    fused = True
    compile = True
    save_path = "llm_cp"
    log_id = None
    # ======== get_batch =================
    data_path = "llm_data"
    batch_size = 12
    block_size = 1024
    # ======= get_cos_lr ===============
    i_warmup = 2000
    # ========== train =================
    i_eval = 2000
    n_eval = 200
    grad_clip = 1.0
    i_max = 600000

    def setup(s, m: nn.Module):
        cuda_ok = tc.cuda.is_available()
        # is_bf16_supported() can raise on machines without CUDA, so guard it
        bf16_ok = cuda_ok and tc.cuda.is_bf16_supported()
        device = "cuda" if cuda_ok else "cpu"
        dtype = "bfloat16" if cuda_ok and bf16_ok else "float16"
        rank = int(os.environ.get("RANK", -1))
        s.is_master = rank in [-1, 0]
        if rank != -1:
            # DDP run (one process per GPU, e.g. launched via torchrun)
            init_process_group(s.backend)
            local_rank = int(os.environ["LOCAL_RANK"])
            device = f"cuda:{local_rank}"
            tc.cuda.set_device(device)
            n_proc = int(os.environ["WORLD_SIZE"])
            assert s.grad_acc_steps % n_proc == 0
            s.grad_acc_steps //= n_proc
        tc.manual_seed(1338 + rank)
        tc.backends.cuda.matmul.allow_tf32 = True
        tc.backends.cudnn.allow_tf32 = True
        s.ctx = (
            nullcontext()
            if device == "cpu"
            else tc.amp.autocast("cuda", getattr(tc, dtype))
        )
        s.scaler = tc.amp.grad_scaler.GradScaler(device, enabled=dtype == "float16")
        # ============================================
        # weight decay only for >= 2D params (matmuls/embeddings), not biases/norms
        g1 = [p for p in m.parameters() if p.dim() >= 2]
        g2 = [p for p in m.parameters() if p.dim() < 2]
        params = [
            {"params": g1, "weight_decay": s.w_decay},
            {"params": g2, "weight_decay": 0.0},
        ]
        s.opt = tc.optim.AdamW(params, lr=s.lr, betas=s.betas, fused=s.fused)
        if os.path.exists(s.save_path):
            r = tc.load(s.save_path)
            m.load_state_dict(r["model"])
            s.opt.load_state_dict(r["opt"])
            s.iter, s.best_loss = r["iter"], r["best_loss"]
        else:
            s.iter, s.best_loss = 0, np.inf
        # ============================================
        m.to(device)
        if rank != -1:
            m = DDP(m, device_ids=[local_rank])
        if s.compile:
            print("compiling")
            m.compile()
        if s.log_id:
            wandb.init(project="LLMTrainer", name=s.log_id)
        s.device = device
        s.model = m

    def get_batch(s, split="train"):
        B, T = s.batch_size, s.block_size
        data = np.memmap(s.data_path, np.uint16, mode="r").astype(np.int64)
        a = int(len(data) * 0.9)
        data = data[:a] if split == "train" else data[a:]
        idx = tc.randint(len(data) - T, (B,))
        x = tc.stack([tc.from_numpy(data[i : i + T]) for i in idx])
        y = tc.stack([tc.from_numpy(data[i + 1 : i + 1 + T]) for i in idx])
        if s.device != "cpu":
            x = x.pin_memory().to(s.device, non_blocking=True)
            y = y.pin_memory().to(s.device, non_blocking=True)
        return x, y

    @tc.no_grad()
    def get_losses(s):
        m = s.model
        m.eval()
        res = {}
        for split in ["train", "val"]:
            losses = tc.zeros(s.n_eval)
            for k in range(s.n_eval):
                x, y = s.get_batch(split)
                with s.ctx:
                    loss: tc.Tensor = m(x, y)[1]
                losses[k] = loss.item()
            res[split] = losses.mean()
        m.train()
        return res

    def get_cos_lr(s, i):
        # linear warmup, then cosine decay from lr down to 0.1 * lr
        min_lr = s.lr * 0.1
        if i < s.i_warmup:
            return s.lr * (i + 1) / (s.i_warmup + 1)
        if i > s.i_max:
            return min_lr
        r = (i - s.i_warmup) / (s.i_max - s.i_warmup)
        assert 0 <= r <= 1
        c = 0.5 * (1.0 + np.cos(np.pi * r))
        return min_lr + c * (s.lr - min_lr)

    def train(s):
        x, y = s.get_batch()
        m = s.model
        is_ddp = isinstance(m, DDP)
        # raw_m = m.module if is_ddp else m
        while s.iter < s.i_max:
            lr = s.get_cos_lr(s.iter)
            for p in s.opt.param_groups:
                p["lr"] = lr
            if s.iter % s.i_eval == 0 and s.is_master:
                losses = s.get_losses()
                print(f"{s.iter} {losses}")
                if s.log_id:
                    wandb.log(
                        {
                            "iter": s.iter,
                            "train/loss": losses["train"],
                            "val/loss": losses["val"],
                            "lr": lr,
                        }
                    )
                if losses["val"] < s.best_loss:
                    s.best_loss = losses["val"]
                    if s.iter:
                        obj = {
                            "model": m.state_dict(),
                            "opt": s.opt.state_dict(),
                            "iter": s.iter,
                            "best_loss": s.best_loss,
                        }
                        tc.save(obj, s.save_path)
            for j in range(s.grad_acc_steps):
                if is_ddp:
                    # only sync gradients on the last micro-step
                    m.require_backward_grad_sync = j == s.grad_acc_steps - 1
                with s.ctx:
                    logits, loss = m(x, y)
                    loss /= s.grad_acc_steps
                # fetch the next batch while the GPU works through the forward pass
                x, y = s.get_batch()
                s.scaler.scale(loss).backward()
            if s.grad_clip:
                s.scaler.unscale_(s.opt)
                nn.utils.clip_grad_norm_(m.parameters(), s.grad_clip)
            s.scaler.step(s.opt)
            s.scaler.update()
            s.opt.zero_grad()
            s.iter += 1
        if is_ddp:
            destroy_process_group()
```
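And a hypothetical driver showing how the two pieces might be wired together, assuming `GPT`/`GPTConfig` and `LLMTrainer` are importable in the same namespace; the file name `llm_data` and the overrides are placeholders, and it is untested like the trainer itself. A single process runs it directly; for DDP it would be launched with `torchrun --nproc_per_node=N`.

```python
# train_llm.py -- hypothetical driver script
# The trainer reads a flat uint16 token file at data_path and splits it 90/10
# into train/val, e.g. produced once with:
#   np.array(token_ids, dtype=np.uint16).tofile("llm_data")

t = LLMTrainer()
t.data_path = "llm_data"  # placeholder path to the token dump
t.log_id = None           # set a string run name to enable wandb logging

m = GPT(GPTConfig())      # GPT / GPTConfig from the model block above
t.setup(m)                # optimizer, AMP context, checkpoint resume, optional DDP wrap
t.train()                 # runs until i_max, checkpointing when val loss improves
```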