attempt to refactor nanoGPT #591

Open · tesla-cat opened this issue Feb 2, 2025 · 0 comments
model

import math
from dataclasses import dataclass
from typing import Dict

import torch as tc
import torch.nn as nn
from torch.nn import functional as F

TS_DICT = Dict[str, tc.Tensor]


@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embed: int = 768
    dropout: float = 0.0
    bias: bool = True


class Attention(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.conf = c
        assert c.n_embed % c.n_head == 0
        E = c.n_embed
        s.c_attn = nn.Linear(E, 3 * E, c.bias)
        s.c_proj = nn.Linear(E, E, c.bias)
        s.resid_dropout = nn.Dropout(c.dropout)

    def forward(s, x: tc.Tensor):
        B, T, E = x.shape
        H = s.conf.n_head
        y1: tc.Tensor = s.c_attn(x)  # fused q, k, v projection: (B, T, 3E)
        q, k, v = [z.view(B, T, H, E // H).transpose(1, 2) for z in y1.split(E, dim=2)]  # each (B, H, T, E // H)
        drop = s.conf.dropout if s.training else 0
        y2 = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=True)  # flash attention
        y3 = y2.transpose(1, 2).contiguous().view(B, T, E)  # re-assemble all head outputs
        return s.resid_dropout(s.c_proj(y3))


class MLP(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        E = c.n_embed
        s.c_fc = nn.Linear(E, 4 * E, c.bias)
        s.gelu = nn.GELU()
        s.c_proj = nn.Linear(4 * E, E, c.bias)
        s.dropout = nn.Dropout(c.dropout)

    def forward(s, x):
        return s.dropout(s.c_proj(s.gelu(s.c_fc(x))))


class TransLayer(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.ln_1 = nn.LayerNorm(c.n_embed, bias=c.bias)
        s.attn = Attention(c)
        s.ln_2 = nn.LayerNorm(c.n_embed, bias=c.bias)
        s.mlp = MLP(c)

    def forward(s, x):
        x = x + s.attn(s.ln_1(x))
        return x + s.mlp(s.ln_2(x))


class GPT(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.conf = c
        E = c.n_embed
        s.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(c.vocab_size, E),
                wpe=nn.Embedding(c.block_size, E),
                drop=nn.Dropout(c.dropout),
                h=nn.ModuleList([TransLayer(c) for _ in range(c.n_layer)]),
                ln_f=nn.LayerNorm(E, bias=c.bias),
            )
        )
        s.lm_head = nn.Linear(E, c.vocab_size, bias=False)
        s.transformer.wte.weight = s.lm_head.weight  # weight-tying

        s.apply(s._init_weights)
        for k, p in s.named_parameters():
            if k.endswith("c_proj.weight"):
                # scaled init for the residual projections, per the GPT-2 paper
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * c.n_layer))
        print(f"n_params: {s.n_params() / 1e6:.2f}M")

    def n_params(s):
        return sum(p.numel() for p in s.parameters())

    def _init_weights(s, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(s, x: tc.Tensor, y0: tc.Tensor = None):
        B, T = x.shape
        c = s.conf
        assert T <= c.block_size
        tok = s.transformer.wte(x)
        pos = s.transformer.wpe(tc.arange(0, T, dtype=tc.long, device=x.device))
        x = s.transformer.drop(tok + pos)
        for layer in s.transformer.h:
            x = layer(x)
        x = s.transformer.ln_f(x)

        if y0 is not None:
            logits: tc.Tensor = s.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), y0.view(-1), ignore_index=-1
            )
        else:
            logits = s.lm_head(x[:, [-1], :])  # inference: only the last position is needed
            loss = None
        return logits, loss

    def crop_block_size(s, n):
        assert n <= s.conf.block_size
        s.conf.block_size = n
        wpe = s.transformer.wpe
        wpe.weight = nn.Parameter(wpe.weight[:n])
        for layer in s.transformer.h:
            # no-op here: this Attention uses is_causal=True and never registers a mask buffer
            if hasattr(layer.attn, "bias"):
                layer.attn.bias = layer.attn.bias[:, :, :n, :n]

    @classmethod
    @tc.no_grad()
    def from_pretrained(s, type, drop=0.0):
        assert type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"}
        from transformers import GPT2LMHeadModel

        args = {
            "gpt2": dict(n_layer=12, n_head=12, n_embed=768),  # 124M params
            "gpt2-medium": dict(n_layer=24, n_head=16, n_embed=1024),  # 350M params
            "gpt2-large": dict(n_layer=36, n_head=20, n_embed=1280),  # 774M params
            "gpt2-xl": dict(n_layer=48, n_head=25, n_embed=1600),  # 1558M params
        }[type]
        args.update(dict(vocab_size=50257, block_size=1024, bias=True, dropout=drop))

        m = GPT(GPTConfig(**args))
        sd: TS_DICT = m.state_dict()
        ignore = (".attn.bias", ".attn.masked_bias")
        keys = [k for k in sd if not k.endswith(ignore)]

        m2 = GPT2LMHeadModel.from_pretrained(type)
        sd2: TS_DICT = m2.state_dict()
        keys2 = [k for k in sd2 if not k.endswith(ignore)]
        assert len(keys) == len(keys2)

        trans = [
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.c_fc.weight",
            "mlp.c_proj.weight",
        ]
        for k in keys2:
            x = sd2[k].t() if any(k.endswith(w) for w in trans) else sd2[k]
            sd[k].copy_(x)
        return m

    @tc.no_grad()
    def generate(s, x, num, temp=1.0):
        B = s.conf.block_size
        for _ in range(num):
            x = x if x.size(1) <= B else x[:, -B:]  # crop the context to block_size
            logits: tc.Tensor = s(x)[0]
            probs = F.softmax(logits[:, -1, :] / temp, dim=-1)
            x2 = tc.multinomial(probs, num_samples=1)
            x = tc.cat((x, x2), dim=1)
        return x
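
usage sketch (hypothetical, not part of the refactor -- assumes the model code above is saved as model.py and that tiktoken is installed):

import tiktoken
import torch as tc

from model import GPT  # hypothetical module name for the code above

enc = tiktoken.get_encoding("gpt2")
m = GPT.from_pretrained("gpt2").eval()  # load HF GPT-2 weights into the refactored model

x = tc.tensor([enc.encode("Hello, my name is")], dtype=tc.long)  # (1, T) prompt
y = m.generate(x, num=20, temp=0.8)  # append 20 sampled tokens
print(enc.decode(y[0].tolist()))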

trainer (not fully tested -- I don't have a GPU)

import os
from contextlib import nullcontext

import numpy as np
import torch as tc
import torch.nn as nn
import wandb
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP


class LLMTrainer:
    # ======== setup ====================
    backend = "nccl"
    grad_acc_steps = 40

    w_decay = 0.1
    lr = 6e-4
    betas = [0.9, 0.95]
    fused = True
    compile = True

    save_path = "llm_cp"
    log_id = None

    # ======== get_batch =================
    data_path = "llm_data"
    batch_size = 12
    block_size = 1024

    # ======= get_cos_lr ===============
    i_warmup = 2000

    # ========== train =================
    i_eval = 2000
    n_eval = 200
    grad_clip = 1.0
    i_max = 600000

    def setup(s, m: nn.Module):
        cuda_ok = tc.cuda.is_available()
        # only query bf16 support when CUDA is actually available
        bf16_ok = cuda_ok and tc.cuda.is_bf16_supported()
        device = "cuda" if cuda_ok else "cpu"
        dtype = "bfloat16" if bf16_ok else "float16"

        rank = int(os.environ.get("RANK", -1))
        s.is_master = rank in [-1, 0]
        if rank != -1:
            init_process_group(s.backend)

            local_rank = int(os.environ["LOCAL_RANK"])
            device = f"cuda:{local_rank}"
            tc.cuda.set_device(device)

            n_proc = int(os.environ["WORLD_SIZE"])
            assert s.grad_acc_steps % n_proc == 0
            s.grad_acc_steps //= n_proc  # scale down accumulation steps per process

        tc.manual_seed(1338 + rank)
        tc.backends.cuda.matmul.allow_tf32 = True
        tc.backends.cudnn.allow_tf32 = True

        s.ctx = (
            nullcontext()
            if device == "cpu"
            else tc.amp.autocast("cuda", getattr(tc, dtype))
        )
        s.scaler = tc.amp.grad_scaler.GradScaler(device, enabled=dtype == "float16")

        # ============================================

        # move params to the target device before building the (possibly fused) optimizer
        m.to(device)

        g1 = [p for p in m.parameters() if p.dim() >= 2]  # matrices and embeddings: decayed
        g2 = [p for p in m.parameters() if p.dim() < 2]  # biases and LayerNorm gains: not decayed
        params = [
            {"params": g1, "weight_decay": s.w_decay},
            {"params": g2, "weight_decay": 0.0},
        ]
        s.opt = tc.optim.AdamW(params, lr=s.lr, betas=s.betas, fused=s.fused)

        if os.path.exists(s.save_path):
            r = tc.load(s.save_path, map_location=device)
            m.load_state_dict(r["model"])
            s.opt.load_state_dict(r["opt"])
            s.iter, s.best_loss = r["iter"], r["best_loss"]
        else:
            s.iter, s.best_loss = 0, np.inf

        # ============================================

        if rank != -1:
            m = DDP(m, device_ids=[local_rank])
        if s.compile:
            print("compiling")
            m.compile()

        if s.log_id:
            wandb.init(project="LLMTrainer", name=s.log_id)

        s.device = device
        s.model = m

    def get_batch(s, split="train"):
        B, T = s.batch_size, s.block_size
        # keep the memmap lazy; convert only the sampled slices to int64
        data = np.memmap(s.data_path, np.uint16, mode="r")
        a = int(len(data) * 0.9)
        data = data[:a] if split == "train" else data[a:]

        idx = tc.randint(len(data) - T, (B,))
        x = tc.stack([tc.from_numpy(data[i : i + T].astype(np.int64)) for i in idx])
        y = tc.stack([tc.from_numpy(data[i + 1 : i + 1 + T].astype(np.int64)) for i in idx])
        if s.device != "cpu":
            x = x.pin_memory().to(s.device, non_blocking=True)
            y = y.pin_memory().to(s.device, non_blocking=True)
        return x, y

    @tc.no_grad()
    def get_losses(s):
        m = s.model
        m.eval()
        res = {}
        for split in ["train", "val"]:
            losses = tc.zeros(s.n_eval)
            for k in range(s.n_eval):
                x, y = s.get_batch(split)
                with s.ctx:
                    loss: tc.Tensor = m(x, y)[1]
                losses[k] = loss.item()
            res[split] = losses.mean()
        m.train()
        return res

    def get_cos_lr(s, i):
        # linear warmup followed by cosine decay down to 10% of the peak lr
        min_lr = s.lr * 0.1
        if i < s.i_warmup:
            return s.lr * (i + 1) / (s.i_warmup + 1)
        if i > s.i_max:
            return min_lr
        r = (i - s.i_warmup) / (s.i_max - s.i_warmup)
        assert 0 <= r <= 1
        c = 0.5 * (1.0 + np.cos(np.pi * r))
        return min_lr + c * (s.lr - min_lr)

    def train(s):
        x, y = s.get_batch()
        m = s.model
        is_ddp = isinstance(m, DDP)
        raw_m = m.module if is_ddp else m  # unwrap DDP so checkpoints load into a bare GPT

        while s.iter < s.i_max:
            lr = s.get_cos_lr(s.iter)
            for p in s.opt.param_groups:
                p["lr"] = lr

            if s.iter % s.i_eval == 0 and s.is_master:
                losses = s.get_losses()
                print(f"{s.iter} {losses}")
                if s.log_id:
                    wandb.log(
                        {
                            "iter": s.iter,
                            "train/loss": losses["train"],
                            "val/loss": losses["val"],
                            "lr": lr,
                        }
                    )
                if losses["val"] < s.best_loss:
                    s.best_loss = losses["val"]
                    if s.iter:
                        obj = {
                            "model": m.state_dict(),
                            "opt": s.opt.state_dict(),
                            "iter": s.iter,
                            "best_loss": s.best_loss,
                        }
                        tc.save(obj, s.save_path)

            for j in range(s.grad_acc_steps):
                if is_ddp:
                    # only sync gradients across ranks on the last micro-step
                    m.require_backward_grad_sync = j == s.grad_acc_steps - 1
                with s.ctx:
                    logits, loss = m(x, y)
                    loss /= s.grad_acc_steps
                x, y = s.get_batch()
                s.scaler.scale(loss).backward()
            if s.grad_clip:
                s.scaler.unscale_(s.opt)
                nn.utils.clip_grad_norm_(m.parameters(), s.grad_clip)
            s.scaler.step(s.opt)
            s.scaler.update()
            s.opt.zero_grad()

            s.iter += 1

        if is_ddp:
            destroy_process_group()
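
usage sketch (hypothetical wiring -- assumes the two blocks above are saved as model.py and trainer.py, and that llm_data is a flat uint16 token file such as the one produced by nanoGPT's data prep scripts):

# run.py (hypothetical file name)
import torch as tc

from model import GPT, GPTConfig
from trainer import LLMTrainer

t = LLMTrainer()
t.compile = False                  # optional: skip torch.compile for a quick smoke test
t.fused = tc.cuda.is_available()   # fused AdamW on CPU needs a recent PyTorch
m = GPT(GPTConfig(block_size=t.block_size))
t.setup(m)
t.train()

Single GPU / CPU: python run.py. Multi-GPU via DDP: torchrun --nproc_per_node=4 run.py (torchrun sets the RANK, LOCAL_RANK, and WORLD_SIZE env vars that setup reads).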