Learn transformer with makemore and torch

deep learning
python
Author

Xiaochuan Yang

Published

February 16, 2024

[last modified on 2024-03-22]

makemore

Karpathy’s makemore is an end-to-end Python application that takes in a text file, then generates new text similar to what’s given. The project makemore is part of his Neural Networks: Zero to Hero lecture series, which I recommend to all.

The example dataset in the repo is a large collection of baby names (about 30k names), and the application trains a transformer model to learn the mechanism of naming things, then samples from the learned model.

The purpose of this post is threefold:
- gain familiarity with the transformer
- summarise essential torch functionalities and the workflow for building neural net applications
- use minimal tools, e.g. argparse, to produce a command line interface for the app

disclaimer

all the code chunks in this post (including docstrings) are taken from the makemore repository

torch basics

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

Simplistically, a neural net application is built on two major components: a dataset and a model. The former shapes the behaviour of the latter through the optimisation of some objective function. torch makes both fairly easy and straightforward to build.

Dataset

The Dataset class is a container, much like a sequence. To define a custom Dataset class, one implements __init__, __len__ and __getitem__, together with the transformations relevant to the particular use case of the application. Example:

class CharDataset(Dataset):

    def __init__(self, words:list[str], chars:str, max_word_length:int):
        self.words = words
        self.chars = chars
        self.max_word_length = max_word_length
        self.stoi = {ch:i+1 for i,ch in enumerate(chars)}
        self.itos = {i:s for s,i in self.stoi.items()} # inverse mapping

    def __len__(self):
        return len(self.words)

    def contains(self, word):
        return word in self.words

    def get_vocab_size(self):
        return len(self.chars) + 1 # all the possible characters and special 0 token

    def get_output_length(self):
        return self.max_word_length + 1 # <START> token followed by words

    def encode(self, word):
        ix = torch.tensor([self.stoi[w] for w in word], dtype=torch.long)
        return ix

    def decode(self, ix):
        word = ''.join(self.itos[i] for i in ix)
        return word

    def __getitem__(self, idx):
        word = self.words[idx]
        ix = self.encode(word)
        x = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        y = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        x[1:1+len(ix)] = ix
        y[:len(ix)] = ix
        y[len(ix)+1:] = -1 # index -1 will mask the loss at the inactive locations
        return x, y

The custom transformation in this example is the mapping from characters to integers, aka tokenisation. There is one special token 0 representing both the start and the end of a name. We can index into a CharDataset because it implements __getitem__, and we can loop over it because Python's fallback iteration protocol calls __getitem__ with 0, 1, 2, … until self.words[idx] raises IndexError.

import string

examples = CharDataset(['emma', 'richard'], string.ascii_letters, 10)
for e in examples:
    print(e)
(tensor([ 0,  5, 13, 13,  1,  0,  0,  0,  0,  0,  0]), tensor([ 5, 13, 13,  1,  0, -1, -1, -1, -1, -1, -1]))
(tensor([ 0, 18,  9,  3,  8,  1, 18,  4,  0,  0,  0]), tensor([18,  9,  3,  8,  1, 18,  4,  0, -1, -1, -1]))
print(examples[0])
(tensor([ 0,  5, 13, 13,  1,  0,  0,  0,  0,  0,  0]), tensor([ 5, 13, 13,  1,  0, -1, -1, -1, -1, -1, -1]))

Let’s break down the output tuple x,y.

x tensor

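It starts with the 0 token, then each input character gets mapped to an integer from 1 to len(chars) (1 to 26 for lowercase names; up to 52 in the toy example above, which uses string.ascii_letters), with trailing zeros if the name is shorter than max_word_length (set to the length of the longest name in the dataset).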

y tensor

By definition, y is x shifted by one token to the left, modulo an extra subtlety with the trailing -1. What is going on here? In language modelling the learning task is next token prediction: for idx < len(word), given the context x[:idx+1], the goal is to predict x[idx+1], which by construction is exactly y[idx]; at idx=0 the context is just the start token. At idx=len(word) the whole name has been seen and the target y[idx] is 0, the end-of-name token. For idx > len(word), x[idx] is padding and there is nothing left to predict (the name has finished), so by convention y[idx]=-1.
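Here is a minimal check that makes the shift concrete. It rebuilds (x, y) for 'emma' with a small hypothetical helper, make_xy, which copies the body of CharDataset.__getitem__, and then asserts the relationship described above.

```python
import torch

def make_xy(ix, max_word_length):
    # hypothetical helper: the same construction as CharDataset.__getitem__
    x = torch.zeros(max_word_length + 1, dtype=torch.long)
    y = torch.zeros(max_word_length + 1, dtype=torch.long)
    x[1:1 + len(ix)] = ix
    y[:len(ix)] = ix
    y[len(ix) + 1:] = -1
    return x, y

ix = torch.tensor([5, 13, 13, 1])   # 'emma' encoded with a -> 1, ..., z -> 26
n = len(ix)
x, y = make_xy(ix, max_word_length=10)

# positions 0..n: the target is the next input token (x[n+1] == 0 is the end token)
assert torch.equal(y[:n + 1], x[1:n + 2])
# positions after the end token are masked out with -1
assert (y[n + 1:] == -1).all()
print(x.tolist())
print(y.tolist())
```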

Our ultimate goal is to build and train a neural net that learns, from the ~30k (x, y) tuples, a good way of sampling the next token given some context. Concretely, we feed a 0 token to the model and let it sample the next token t1, concatenate t1 to the context to get [0, t1], sample the next token given [0, t1], and so on until a 0 token is sampled, which means we have arrived at the end of a name. Repeat this to produce as many names as we want. We’ll get back to inference later.
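The sampling procedure just described can be sketched as a plain loop. This is only a sketch under stated assumptions. A random bigram table (bigram_logits) and a helper next_token_probs stand in for a trained model. The values vocab_size and max_len are hypothetical.

```python
import torch

torch.manual_seed(0)
vocab_size = 27   # assumption: 26 letters plus the special 0 token
max_len = 15      # hypothetical cap on the name length

# hypothetical stand-in for a trained model: row i of this random table holds
# the logits of the next token given that the last token was i
bigram_logits = torch.randn(vocab_size, vocab_size)

def next_token_probs(context):
    return torch.softmax(bigram_logits[context[-1]], dim=-1)

context = [0]  # start with the special 0 token
for _ in range(max_len):
    t = torch.multinomial(next_token_probs(context), num_samples=1).item()
    if t == 0:  # a sampled 0 marks the end of the name
        break
    context.append(t)

name_tokens = context[1:]  # drop the leading start token
print(name_tokens)
```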

DataLoader

For efficiency’s sake, it is beneficial to stack multiple examples together into a mini-batch and process them all at once. The DataLoader class helps us with this. Example:

class InfiniteDataLoader:
    """
    this is really hacky and I'm not proud of it, but there doesn't seem to be
    a better way in PyTorch to just create an infinite dataloader?
    """

    def __init__(self, dataset, **kwargs):
        train_sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=int(1e10))
        self.train_loader = DataLoader(dataset, sampler=train_sampler, **kwargs)
        self.data_iter = iter(self.train_loader)

    def next(self):
        try:
            batch = next(self.data_iter)
        except StopIteration: # this will technically only happen after 1e10 samples... (i.e. basically never)
            self.data_iter = iter(self.train_loader)
            batch = next(self.data_iter)
        return batch
dataset = CharDataset(['emma', 'richard', 'ben', 'steve'],string.ascii_letters, 10)
batch_loader = InfiniteDataLoader(dataset, batch_size=2)
for _ in range(2):
    X,Y = batch_loader.next()
    print(f'{X=}') # B,T = batch_size, max_word_length+1
    print('-'*60)
    print(f'{Y=}') # (B,T)
    print('*'*60)
X=tensor([[ 0, 18,  9,  3,  8,  1, 18,  4,  0,  0,  0],
        [ 0, 19, 20,  5, 22,  5,  0,  0,  0,  0,  0]])
------------------------------------------------------------
Y=tensor([[18,  9,  3,  8,  1, 18,  4,  0, -1, -1, -1],
        [19, 20,  5, 22,  5,  0, -1, -1, -1, -1, -1]])
************************************************************
X=tensor([[ 0,  2,  5, 14,  0,  0,  0,  0,  0,  0,  0],
        [ 0, 19, 20,  5, 22,  5,  0,  0,  0,  0,  0]])
------------------------------------------------------------
Y=tensor([[ 2,  5, 14,  0, -1, -1, -1, -1, -1, -1, -1],
        [19, 20,  5, 22,  5,  0, -1, -1, -1, -1, -1]])
************************************************************
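InfiniteDataLoader draws samples with replacement through RandomSampler. A common alternative is a generator that re-iterates a regular shuffled DataLoader, epoch after epoch. The sketch below assumes an endless stream of shuffled epochs is acceptable. The helper infinite_batches and the toy TensorDataset are hypothetical.

```python
import itertools
import torch
from torch.utils.data import DataLoader, TensorDataset

def infinite_batches(dataset, **kwargs):
    """Yield batches forever by looping over a regular DataLoader, one epoch at a time."""
    loader = DataLoader(dataset, **kwargs)
    while True:
        for batch in loader:  # one full (reshuffled) pass over the data
            yield batch

# toy dataset of 4 examples, each a pair of length-3 tensors
xs = torch.arange(12).view(4, 3)
ys = -xs
batches = infinite_batches(TensorDataset(xs, ys), batch_size=2, shuffle=True)

# 5 batches of 2 examples is more than one epoch, yet the generator never runs dry
drawn = list(itertools.islice(batches, 5))
for X, Y in drawn:
    print(tuple(X.shape), tuple(Y.shape))
```

Unlike InfiniteDataLoader, this version sees every example exactly once per epoch before reshuffling.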

Tensor ops and View

Tensor is the most fundamental data structure for neural nets.

A tensor is a collection of numbers indexed by tuples of non-negative integers. Above, we’ve seen that a batch X is a 2d tensor with shape (B,T), so we can index X[b,t] for b in range(B) and t in range(T).

torch provides optimised tensor operations and an automatic differentiation engine. Rather than digging into low-level optimisations (parallel programming, e.g. CUDA kernels), we take them for granted in this post and are only concerned with what we can build with Tensor.
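Here is a minimal illustration of the automatic differentiation engine. Mark a tensor with requires_grad=True, compute a scalar from it, call backward(), and read the gradient from .grad. The toy function below is arbitrary, chosen so its gradient is easy to check by hand.

```python
import torch

# a leaf tensor that autograd should track
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x = torch.tensor([4.0, 5.0, 6.0])

loss = (w * x).sum() + (w ** 2).sum()  # loss = sum(w*x) + sum(w^2)
loss.backward()                        # fills w.grad with d loss / d w

# analytically, d loss / d w = x + 2w
print(w.grad.tolist())                 # [6.0, 9.0, 12.0]
assert torch.allclose(w.grad, x + 2 * w.detach())
```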

Most torch ops create a new tensor as output, e.g. torch.cat and torch.stack.

In contrast, a tensor view shares the underlying data with its base tensor, which avoids a wasteful data copy.
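You can see the difference between a view and a copy directly. Writing through a view changes the base tensor, while torch.cat allocates fresh memory. Comparing data_ptr() values is just one convenient check of shared storage.

```python
import torch

base = torch.arange(6)
v = base.view(2, 3)            # a view: new shape, same underlying data
v[0, 0] = 100                  # writing through the view...
print(base[0].item())          # ...changes the base tensor: 100

c = torch.cat([base, base])    # cat allocates a new tensor
c[0] = -1
print(base[0].item())          # base is untouched by edits to c: still 100

assert v.data_ptr() == base.data_ptr()   # same memory
assert c.data_ptr() != base.data_ptr()   # separate memory
```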

Here are a few common view ops.

  • basic slicing and indexing (following numpy)
  • view
  • split
  • unsqueeze
  • expand

For a full list of view ops, refer to the Tensor Views page of the PyTorch docs.

torch.arange(24).view(3,4,2)
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11],
         [12, 13],
         [14, 15]],

        [[16, 17],
         [18, 19],
         [20, 21],
         [22, 23]]])
assert torch.equal(torch.arange(60).view(3,4,5),torch.arange(60).view(3,4,-1))
torch.tril(torch.ones(4,4))
tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])
torch.randn(3,4).split(2,dim=1)  # a tuple of 4/2 = 2 tensors, each of shape (3,2)
(tensor([[-0.1093,  0.0426],
         [ 2.5843,  0.2994],
         [-0.2158, -1.7210]]),
 tensor([[-1.2973,  0.2321],
         [ 1.1551,  0.2394],
         [ 0.4124,  0.1518]]))
att = torch.arange(16).view(4,4)
att.masked_fill(torch.tril(torch.ones(4,4))==0, -99)
tensor([[  0, -99, -99, -99],
        [  4,   5, -99, -99],
        [  8,   9,  10, -99],
        [ 12,  13,  14,  15]])
assert torch.randn(24).view(2,3,4).transpose(1,2).shape == (2,4,3)
assert torch.randn(5).unsqueeze(1).shape == (5,1)
assert torch.cat([torch.ones(3,3), torch.arange(9).view(3,3)], dim=1).shape == (3,3+3)
assert torch.stack([torch.ones(3,3), torch.arange(9).view(3,3)], dim=1).shape == (3,2,3)

Module

nn.Module is the base class for all neural nets, which, mathematically, are just functions taking Tensors as input and computing another Tensor as output. Just as we can compose functions, we can compose Modules to build complicated architectures.

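torch offers built-in modules as LEGO pieces. Examples:

  • nn.Linear: an affine map y = x W^T + b with randomly initialised weight and bias; takes the input dim and output dim as arguments
  • nn.LayerNorm: standardises a tensor over its trailing dimension(s) (normalized_shape, passed in as argument), then multiplies by a learnable weight and adds a learnable bias (initialised elementwise to 1 and 0)
  • nn.ReLU: a parameter-less elementwise non-linear function
  • nn.Embedding: a lookup table of shape (num_embeddings, embedding_dim) with randomly initialised rows; an input tensor of integer indices of shape (B,T) yields an output of shape (B,T,embedding_dim), e.g. (B,T,n_embd)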

some container classes:

  • nn.ModuleDict: pass in a dictionary of name:instance pairs
  • nn.ModuleList: pass in a list of instances
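The sketch below wires a few of these modules together to check their shapes. The sizes B, T, C and vocab_size are arbitrary assumptions. It also shows that the container classes register their children, so every sub-module's parameters are reachable through .parameters().

```python
import torch
import torch.nn as nn

B, T, C = 2, 5, 8            # arbitrary sizes
vocab_size = 27              # assumption: 26 letters plus the special 0 token

emb = nn.Embedding(vocab_size, C)    # lookup table of shape (vocab_size, C)
lin = nn.Linear(C, 4 * C)            # weight (4C, C), bias (4C,)
act = nn.ReLU()
ln = nn.LayerNorm(C)                 # normalises over the last dim of size C

idx = torch.randint(0, vocab_size, (B, T))
h = emb(idx)                 # (B, T) integer indices -> (B, T, C) vectors
z = act(lin(h))              # (B, T, C) -> (B, T, 4C)
normed = ln(h)               # same shape as h, each C-vector standardised

print(emb.weight.shape, h.shape, z.shape)
assert torch.allclose(normed.mean(-1), torch.zeros(B, T), atol=1e-5)

# containers register their children, so their parameters are visible to the parent
layers = nn.ModuleList([nn.Linear(C, C) for _ in range(3)])
parts = nn.ModuleDict(dict(emb=emb, lin=lin))
n_params = sum(p.numel() for p in layers.parameters())
print(n_params, list(parts.keys()))  # 216 ['emb', 'lin'], since 3 * (8*8 + 8) = 216
```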

A custom Module must implement the forward method. Example:

class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
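As a sanity check, NewGELU should agree with PyTorch's built-in tanh approximation of GELU. This sketch assumes PyTorch 1.12 or later, where F.gelu accepts approximate='tanh'. The input grid is arbitrary.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NewGELU(nn.Module):
    # copied from the makemore code above
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.linspace(-4, 4, steps=101)
ours = NewGELU()(x)

# assumption: PyTorch >= 1.12, where F.gelu supports approximate='tanh'
assert torch.allclose(ours, F.gelu(x, approximate='tanh'), atol=1e-6)

# the tanh approximation stays close to the exact, erf-based GELU
max_gap = (ours - F.gelu(x)).abs().max().item()
print(max_gap)
```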

build Transformer

Transformer is a custom module built from a stack of transformer blocks, each of which combines a multi-head causal self-attention layer with an MLP. What we have here is a decoder-only transformer, i.e. a GPT.

more variants

Let’s build from the layer level all the way up to the transformer. Here is our model config. Notice that we need a dataset first, from which we infer some of the config parameters for the transformer. For example, vocab_size is the number of possible tokens in the dataset, a value that may vary from dataset to dataset; likewise, block_size follows from the length of the longest name.

from dataclasses import dataclass

@dataclass
class ModelConfig:
    block_size: int = None # length of the input sequences of integers
    vocab_size: int = None # the input integers are in range [0 .. vocab_size -1]
    # parameters below control the sizes of each model slightly differently
    n_layer: int = 4
    n_embd: int = 64
    n_embd2: int = 64
    n_head: int = 4

self attention layer

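The input of the attention layer is a 3d tensor of shape (B,T,C). The forward pass goes through a few steps:

  • a linear map C -> 3C, split into q, k, v
  • view C as (nh, C/nh), i.e. nh heads of size hs = C/nh, moving the head dim next to the batch dim
  • compute the attention matrix from q and k (scaled, causally masked, softmaxed) and use it to take weighted averages of v
  • re-assemble all heads into (B,T,C), then project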
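These steps can be sketched as one standalone function. Doing so also lets us check the point of the causal mask: output position t depends only on inputs at positions 0..t. The sizes and the random weight W_qkv (no bias) are hypothetical stand-ins for a configured c_attn layer. The output projection is omitted, and the class that follows is the version actually used.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, nh = 1, 6, 8, 2           # hypothetical sizes; C must be divisible by nh
hs = C // nh
W_qkv = torch.randn(C, 3 * C)      # stand-in for the c_attn weights (no bias)

def causal_attention(x):
    B, T, C = x.size()
    q, k, v = (x @ W_qkv).split(C, dim=2)                  # each (B, T, C)
    q = q.view(B, T, nh, hs).transpose(1, 2)               # (B, nh, T, hs)
    k = k.view(B, T, nh, hs).transpose(1, 2)
    v = v.view(B, T, nh, hs).transpose(1, 2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(hs)        # (B, nh, T, T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    att = att.masked_fill(~mask, float('-inf'))            # block attention to the future
    att = F.softmax(att, dim=-1)
    y = att @ v                                            # (B, nh, T, hs)
    return y.transpose(1, 2).contiguous().view(B, T, C)    # heads side by side

x = torch.randn(B, T, C)
y = causal_attention(x)

# causality check: changing the last token leaves all earlier outputs unchanged
x2 = x.clone()
x2[:, -1] = torch.randn(C)
y2 = causal_attention(x2)
assert torch.allclose(y[:, :-1], y2[:, :-1])
assert not torch.allclose(y[:, -1], y2[:, -1])
```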
class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)
        return y

This is pretty self-explanatory. Don’t forget to call super().__init__() in __init__.

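A bit of caution about the line y = y.transpose(1, 2).contiguous().view(B, T, C): transpose returns a non-contiguous view of the base tensor, and view generally cannot reshape such a tensor (it raises a RuntimeError), so .contiguous() first copies the data into a contiguous layout. The order of operations also matters, since viewing the base tensor directly does not give the same result as transposing first: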

x = torch.randn(2,3,2)
torch.ne(x.transpose(0,1).contiguous().view(3,4), x.view(3,4))
tensor([[False, False,  True,  True],
        [ True,  True,  True,  True],
        [ True,  True, False, False]])
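To see why .contiguous() is needed at all, compare view and reshape on a transposed tensor. The shapes mirror the example above. reshape is a standard alternative that copies only when it must.

```python
import torch

x = torch.randn(2, 3, 2)
t = x.transpose(0, 1)            # a view: same data, strides swapped, nothing copied
print(x.is_contiguous(), t.is_contiguous())   # True False

try:
    t.view(3, 4)                 # view needs a memory layout compatible with the new shape
    view_failed = False
except RuntimeError:
    view_failed = True
print("view on the transposed tensor failed:", view_failed)   # True

a = t.contiguous().view(3, 4)    # copy into a contiguous layout first, then view
b = t.reshape(3, 4)              # reshape copies only when a view is impossible
assert torch.equal(a, b)
```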

block

A block stacks the attention layer and an MLP, each preceded by layer normalization and wrapped in a residual connection to facilitate training.

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
            act     = NewGELU(),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x))) # MLP forward

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x

the entire thing

Token embedding plus positional embedding, then a stack of blocks, a final layer normalization, and finally a linear projection (lm_head) to logits over the vocabulary.

class Transformer(nn.Module):
    """ Transformer Language Model, exactly as seen in GPT-2 """

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # report number of parameters (note we don't count the decoder parameters in lm_head)
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print("number of parameters: %.2fM" % (n_params/1e6,))

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

Note: F.cross_entropy caveats

  • it expects ground truth class indices of shape (D,), here D = B*T after the .view(-1) flattening
  • it expects predictions of shape (D,C), where C is the number of classes
  • predictions are raw logits (one F.softmax away from probabilities), since F.cross_entropy applies log-softmax internally
  • ignore_index=-1 matches the -1 in the definition of y in CharDataset
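Here is a small check of these caveats. With random logits and a target tensor containing -1 entries, F.cross_entropy with ignore_index=-1 equals the mean negative log-probability of the correct class over the non-ignored positions only. The sizes and targets are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 4, 5                          # hypothetical batch, time and vocab sizes
logits = torch.randn(B, T, V)              # raw scores, no softmax applied
targets = torch.tensor([[1, 3, 0, -1],
                        [2, 0, -1, -1]])   # -1 marks positions to ignore

loss = F.cross_entropy(logits.view(-1, V), targets.view(-1), ignore_index=-1)

# the same quantity by hand: mean of -log p(correct class) over non-ignored positions
logp = F.log_softmax(logits.view(-1, V), dim=-1)   # (B*T, V)
flat = targets.view(-1)                            # (B*T,)
keep = flat != -1                                  # 5 of the 8 positions count
rows = torch.arange(flat.numel())[keep]
manual = -logp[rows, flat[keep]].mean()

print(loss.item(), manual.item())
assert torch.allclose(loss, manual)
```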

evaluation

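The evaluation code is standard. Note:

  • a DataLoader is iterable, so it can be looped over directly
  • every batch holds two tensors X and Y, both of shape (B, T)

As for the benefit of entering inference mode when the code has already switched on model.eval(): the two do different things. model.eval() only changes modules whose behaviour differs between training and evaluation, such as dropout and batch norm (this Transformer has neither, so here it changes nothing). torch.inference_mode() turns off gradient tracking, so no autograd graph is built during the forward pass, which saves memory and compute.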

@torch.inference_mode()
def evaluate(model, dataset, batch_size=50, max_batches=None):
    model.eval()
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0)
    losses = []
    for i, batch in enumerate(loader):
        batch = [t.to(args.device) for t in batch]
        X, Y = batch
        logits, loss = model(X, Y)
        losses.append(loss.item())
        if max_batches is not None and i >= max_batches:
            break
    mean_loss = torch.tensor(losses).mean().item()
    model.train() # reset model back to training mode
    return mean_loss
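Here is a minimal way to see the difference between model.eval() and torch.inference_mode(). After eval(), a layer's output still tracks gradients, while inside inference_mode it does not. The nn.Linear layer here is an arbitrary stand-in for the model.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)   # arbitrary stand-in for the model
x = torch.randn(4, 3)

layer.eval()              # affects only modules like dropout/batch norm; Linear has neither
out_eval = layer(x)
print(out_eval.requires_grad)   # True: autograd still records the graph

with torch.inference_mode():    # disable gradient tracking for everything inside
    out_inf = layer(x)
print(out_inf.requires_grad)    # False: no graph is kept, saving memory and compute

# the numbers are identical either way
assert torch.allclose(out_eval.detach(), out_inf)
```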

inference

Given some context, we generate with the trained model (no gradient updates) one token at a time: concatenate the generated token with the context, pass the combined context back into the model to get the next token, and so on.

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.
    """
    block_size = model.get_block_size()
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        # forward the model to get the logits for the index in the sequence
        logits, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
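The temperature and top-k steps can be tried in isolation on a hand-picked row of logits. The helper next_token_probs below is hypothetical. It repeats the scaling, cropping and softmax lines of generate, without the model.

```python
import torch
import torch.nn.functional as F

def next_token_probs(logits, temperature=1.0, top_k=None):
    # hypothetical helper repeating generate's steps: scale, optionally crop, softmax
    logits = logits / temperature                 # a new tensor, so editing it in place is safe
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)          # v[:, [-1]] is the k-th largest logit per row
        logits[logits < v[:, [-1]]] = -float('Inf')
    return F.softmax(logits, dim=-1)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0]])   # hypothetical logits, shape (1, vocab)
p = next_token_probs(logits)
p_cold = next_token_probs(logits, temperature=0.5)    # below 1: a sharper distribution
p_top2 = next_token_probs(logits, top_k=2)

print(p_top2)  # only the two largest logits keep non-zero probability
```

Lowering the temperature moves probability mass toward the most likely token, while top-k sets everything outside the k best tokens to exactly zero.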

tie it all up

import os
import sys
import time
import math
import argparse
from torch.utils.tensorboard import SummaryWriter

argparse

The argparse module in the Python standard library centres on the ArgumentParser class, with its add_argument() and parse_args() methods.

When a flag is not given on the command line, the parser uses its default value. Alternatively, one can use action to store a specific value when the flag is present, e.g. action='store_true' stores True (and defaults to False). If a flag has neither and is omitted, its value is None.

Calling parse_args() returns a Namespace object which stores the arguments as attributes. Dashes in flag names become underscores, so --input-file is accessed as args.input_file.

Recall that the Python built-in vars(obj) returns obj.__dict__, so vars(args) gives a plain dict of all the arguments and their values.

# parse command line args
parser = argparse.ArgumentParser(description="Make More")
# system/input/output
parser.add_argument('--input-file', '-i', type=str, default='names.txt', help="input file with things one per line")
parser.add_argument('--work-dir', '-o', type=str, default='out', help="output working directory")
parser.add_argument('--resume', action='store_true', help="when this flag is used, we will resume optimization from existing model in the workdir")
parser.add_argument('--sample-only', action='store_true', help="just sample from the model and quit, don't train")
parser.add_argument('--num-workers', '-n', type=int, default=4, help="number of data workers for both train/test")
parser.add_argument('--max-steps', type=int, default=-1, help="max number of optimization steps to run for, or -1 for infinite.")
parser.add_argument('--device', type=str, default='cpu', help="device to use for compute, examples: cpu|cuda|cuda:2|mps")
parser.add_argument('--seed', type=int, default=3407, help="seed")
# sampling
parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
# model
parser.add_argument('--type', type=str, default='transformer', help="model class type to use, bigram|mlp|rnn|gru|bow|transformer")
parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")
parser.add_argument('--n-embd2', type=int, default=64, help="number of feature channels elsewhere in the model")
# optimization
parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization")
parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate")
parser.add_argument('--weight-decay', '-w', type=float, default=0.01, help="weight decay")

args = parser.parse_args()
print(vars(args))
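To try the parser without a real command line, parse_args also accepts an explicit list of strings. The sketch below re-creates three of the flags above. The file name words.txt is just a placeholder.

```python
import argparse

parser = argparse.ArgumentParser(description="Make More")
parser.add_argument('--input-file', '-i', type=str, default='names.txt')
parser.add_argument('--resume', action='store_true')
parser.add_argument('--top-k', type=int, default=-1)

# parse_args accepts an explicit list, handy for trying a parser out in a notebook
args = parser.parse_args(['-i', 'words.txt', '--top-k', '5'])
print(vars(args))  # {'input_file': 'words.txt', 'resume': False, 'top_k': 5}
```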

inits

We initialise things in this order:

  • init a dataset and let it determine part of model config
  • init model with config
  • init optimiser and dataloader
# system inits
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
os.makedirs(args.work_dir, exist_ok=True)
writer = SummaryWriter(log_dir=args.work_dir)

# init datasets
train_dataset, test_dataset = create_datasets(args.input_file)
vocab_size = train_dataset.get_vocab_size()
block_size = train_dataset.get_output_length()
print(f"dataset determined that: {vocab_size=}, {block_size=}")

# init model
config = ModelConfig(vocab_size=vocab_size, block_size=block_size,
                   n_layer=args.n_layer, n_head=args.n_head,
                   n_embd=args.n_embd, n_embd2=args.n_embd2)
if args.type == 'transformer':
    model = Transformer(config)
elif args.type == 'bigram':
    model = Bigram(config)
elif args.type == 'mlp':
    model = MLP(config)
elif args.type == 'rnn':
    model = RNN(config, cell_type='rnn')
elif args.type == 'gru':
    model = RNN(config, cell_type='gru')
elif args.type == 'bow':
    model = BoW(config)
else:
    raise ValueError(f'model type {args.type} is not recognized')
model.to(args.device)
print(f"model #params: {sum(p.numel() for p in model.parameters())}")
if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming
    print("resuming from existing model in the workdir")
    model.load_state_dict(torch.load(os.path.join(args.work_dir, 'model.pt')))
if args.sample_only:
    print_samples(num=50)
    sys.exit()

# init optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(0.9, 0.99), eps=1e-8)

# init dataloader
batch_loader = InfiniteDataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=args.num_workers)
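Resuming relies on a save/load round trip of model.state_dict(). A state_dict holds only tensors, so the architecture has to be rebuilt first (as the script does from config) before calling load_state_dict. The small nn.Sequential model and the temporary directory below are stand-ins. Passing map_location='cpu' to torch.load lets a checkpoint saved on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile
import torch
import torch.nn as nn

def build():
    # stand-in architecture; must be identical when saving and loading
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

model = build()
with tempfile.TemporaryDirectory() as work_dir:
    out_path = os.path.join(work_dir, 'model.pt')
    torch.save(model.state_dict(), out_path)          # saves only the parameter tensors

    fresh = build()                                   # rebuild, then load the weights
    fresh.load_state_dict(torch.load(out_path, map_location='cpu'))

x = torch.randn(3, 4)
assert torch.equal(model(x), fresh(x))
print(list(model.state_dict().keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
```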

training loop

It runs forever (unless --max-steps is set) and saves model.state_dict() whenever a lower test loss is achieved.

# training loop
best_loss = None
step = 0
while True:

    t0 = time.time()

    # get the next batch, ship to device, and unpack it to input and target
    batch = batch_loader.next()
    batch = [t.to(args.device) for t in batch]
    X, Y = batch

    # feed into the model
    logits, loss = model(X, Y)

    # calculate the gradient, update the weights
    model.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # wait for all CUDA work on the GPU to finish then calculate iteration time taken
    if args.device.startswith('cuda'):
        torch.cuda.synchronize()
    t1 = time.time()

    # logging
    if step % 10 == 0:
        print(f"step {step} | loss {loss.item():.4f} | step time {(t1-t0)*1000:.2f}ms")

    # evaluate the model
    if step > 0 and step % 500 == 0:
        train_loss = evaluate(model, train_dataset, batch_size=100, max_batches=10)
        test_loss  = evaluate(model, test_dataset,  batch_size=100, max_batches=10)
        writer.add_scalar("Loss/train", train_loss, step)
        writer.add_scalar("Loss/test", test_loss, step)
        writer.flush()
        print(f"step {step} train loss: {train_loss} test loss: {test_loss}")
        # save the model to disk if it has improved
        if best_loss is None or test_loss < best_loss:
            out_path = os.path.join(args.work_dir, "model.pt")
            print(f"test loss {test_loss} is the best so far, saving model to {out_path}")
            torch.save(model.state_dict(), out_path)
            best_loss = test_loss

    # sample from the model
    if step > 0 and step % 200 == 0:
        print_samples(num=10)

    step += 1
    # termination conditions
    if args.max_steps >= 0 and step >= args.max_steps:
        break
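Stripped of logging, evaluation and checkpointing, the loop above boils down to zero_grad, backward and step. This sketch runs that core on a hypothetical toy regression problem with nn.Linear and F.mse_loss, where the loss should fall steadily.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# hypothetical toy problem: recover a random linear map from noisy samples
X = torch.randn(256, 4)
true_W = torch.randn(4, 1)
Y = X @ true_W + 0.01 * torch.randn(256, 1)

model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=0.01)

losses = []
for step in range(300):
    idx = torch.randint(0, len(X), (32,))          # sample a mini-batch with replacement
    loss = F.mse_loss(model(X[idx]), Y[idx])
    model.zero_grad(set_to_none=True)              # clear old gradients
    loss.backward()                                # compute new ones
    optimizer.step()                               # update the weights
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```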