PyTorch Coding Round
PyTorch Questions
LLM-Specific
GPT-2 Code (Important to Understand)
import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    """
    Multi-Head Causal Self-Attention:
    - Splits embedding into multiple heads
    - Computes scaled dot-product attention
    - Applies a causal mask so tokens cannot attend to future positions

    Args:
        embed_size: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        dropout: dropout probability used on the attention weights and on the
            projected output (residual dropout).
    """

    def __init__(self, embed_size, num_heads, dropout=0.1):
        super().__init__()
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        # Transform from embed -> queries, keys, values
        self.q_proj = nn.Linear(embed_size, embed_size)
        self.k_proj = nn.Linear(embed_size, embed_size)
        self.v_proj = nn.Linear(embed_size, embed_size)

        # Output projection back to embed dimension
        self.out_proj = nn.Linear(embed_size, embed_size)

        # Dropout on the attention weights.
        self.dropout = nn.Dropout(dropout)
        # Residual dropout on the projected output, mirroring the trailing
        # Dropout in FeedForward so both sub-layers are regularized alike
        # (GPT-2 applies dropout at both places). No-op in eval mode.
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: (batch_size, seq_len, embed_size)
        Returns: (batch_size, seq_len, embed_size)
        """
        bsz, seq_len, _ = x.size()

        # Compute Q, K, V: each (bsz, seq_len, embed_size)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape to (bsz, seq_len, num_heads, head_dim), then transpose to
        # (bsz, num_heads, seq_len, head_dim) so each head attends independently.
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scale queries by 1/sqrt(head_dim) so score magnitudes don't grow
        # with the head dimension.
        q = q / (self.head_dim ** 0.5)

        # Attention scores: (bsz, num_heads, seq_len, seq_len)
        att_scores = torch.matmul(q, k.transpose(-2, -1))

        # Causal mask: True where key position <= query position (lower
        # triangle including the diagonal). The (seq_len, seq_len) mask
        # broadcasts against the last two dims of att_scores, so no explicit
        # expand/unsqueeze is needed.
        causal_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=x.device
        ).tril()
        att_scores = att_scores.masked_fill(~causal_mask, float('-inf'))

        # Every row keeps at least its diagonal entry, so softmax never sees
        # an all -inf row.
        att_weights = torch.softmax(att_scores, dim=-1)
        att_weights = self.dropout(att_weights)

        # Weighted sum of values: (bsz, num_heads, seq_len, head_dim)
        out = torch.matmul(att_weights, v)

        # Recombine heads: back to (bsz, seq_len, embed_size)
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_size)

        # Final linear projection + residual dropout
        return self.resid_dropout(self.out_proj(out))


class FeedForward(nn.Module):
    """
    Position-wise Feed Forward layer:
    - Linear (embed_size) -> (expansion_factor * embed_size)
    - GELU activation
    - Linear (expansion_factor * embed_size) -> (embed_size)
    - Dropout
    """

    def __init__(self, embed_size, expansion_factor=4, dropout=0.1):
        super().__init__()
        inner_dim = expansion_factor * embed_size
        self.net = nn.Sequential(
            nn.Linear(embed_size, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, embed_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """x: (..., embed_size) -> same shape."""
        return self.net(x)


class GPTBlock(nn.Module):
    """
    Single pre-LayerNorm GPT-style transformer block:
    - LayerNorm -> Causal Self-Attention -> residual add
    - LayerNorm -> FeedForward -> residual add
    """

    def __init__(self, embed_size, num_heads, expansion_factor=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_size)
        self.attn = CausalSelfAttention(embed_size, num_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, expansion_factor, dropout)

    def forward(self, x):
        """x: (batch_size, seq_len, embed_size) -> same shape."""
        # Causal Self-Attention sub-layer (LayerNorm applied before, not after)
        x = x + self.attn(self.ln1(x))
        # Feed Forward sub-layer
        x = x + self.ff(self.ln2(x))
        return x


class GPTModel(nn.Module):
    """
    Decoder-Only Transformer, GPT-style.

    Typical GPT-3 scale settings (for reference, not shown in code default):
    - vocab_size ~ 50k
    - embed_size ~ 12,288
    - num_heads ~ 96
    - n_layers ~ 96
    - block_size (context length) ~ 2048

    NOTE: GPT-2 additionally ties self.head.weight to self.token_emb.weight;
    this implementation keeps them as separate parameters.
    """

    def __init__(
        self,
        vocab_size,
        block_size,         # maximum sequence length
        embed_size=768,     # smaller default for demonstration
        num_heads=12,       # smaller default
        num_layers=12,      # smaller default
        expansion_factor=4,
        dropout=0.1,
        pad_id=0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embed_size = embed_size
        # Stored for callers only; forward() applies no padding mask.
        self.pad_id = pad_id

        # Token + Position embeddings (learned absolute positions)
        self.token_emb = nn.Embedding(vocab_size, embed_size)
        self.pos_emb = nn.Embedding(block_size, embed_size)
        # Dropout on the summed embeddings (GPT-2's embedding dropout).
        self.drop = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            GPTBlock(embed_size, num_heads, expansion_factor, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(embed_size)

        # Output head (projection to vocab)
        self.head = nn.Linear(embed_size, vocab_size, bias=False)

    def forward(self, idx):
        """
        idx: LongTensor of shape (batch_size, sequence_length) with token IDs
        Returns:
            logits => (batch_size, sequence_length, vocab_size)
        Raises:
            AssertionError if sequence_length > block_size.
        """
        bsz, seq_len = idx.size()
        assert seq_len <= self.block_size, "Sequence too long!"

        # Positions 0..seq_len-1, shape (1, seq_len) to broadcast over batch
        pos = torch.arange(0, seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)

        token_embeddings = self.token_emb(idx)    # (bsz, seq_len, embed_size)
        position_embeddings = self.pos_emb(pos)   # (1, seq_len, embed_size)
        x = self.drop(token_embeddings + position_embeddings)  # (bsz, seq_len, embed_size)

        # Pass through each GPT block
        for block in self.blocks:
            x = block(x)

        # Final layer norm + linear head
        x = self.ln_f(x)
        logits = self.head(x)  # (bsz, seq_len, vocab_size)
        return logits


if __name__ == "__main__":
    # Example usage
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparams for demonstration (much smaller than GPT-3 scale)
    vocab_size = 50257  # typical GPT-2/3 Byte-Pair Encoding
    block_size = 128    # max context length
    embed_size = 768
    num_heads = 12
    num_layers = 6
    batch_size = 2
    seq_len = 64

    model = GPTModel(
        vocab_size=vocab_size,
        block_size=block_size,
        embed_size=embed_size,
        num_heads=num_heads,
        num_layers=num_layers,
        expansion_factor=4,
        dropout=0.1,
    ).to(device)

    # Create random input tokens
    x = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long, device=device)

    # Forward pass
    logits = model(x)
    print("Logits shape:", logits.shape)  # (2, 64, 50257)
DDP Code (Distributed Data Parallelism)
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler


def get_dummy_data(batch_size=8, input_dim=10, num_samples=64, seed=0):
    """Build a small random binary-classification dataset.

    Every DDP rank calls this independently, and DistributedSampler assumes
    all ranks index the *same* dataset. A fixed-seed local generator makes
    the data identical across ranks, so each rank gets a disjoint shard of
    one dataset. Global RNG state is not touched.

    Args:
        batch_size: unused; kept for backward compatibility (the DataLoader
            decides the batch size).
        input_dim: number of features per sample.
        num_samples: number of samples.
        seed: seed for the local generator.

    Returns:
        TensorDataset of (data: float (num_samples, input_dim),
        labels: long (num_samples,) with values in {0, 1}).
    """
    gen = torch.Generator().manual_seed(seed)
    data = torch.randn(num_samples, input_dim, generator=gen)
    labels = torch.randint(0, 2, (num_samples,), generator=gen)
    return TensorDataset(data, labels)


class SimpleNet(nn.Module):
    """Two-layer MLP: input_dim -> hidden_dim -> ReLU -> num_classes logits."""

    def __init__(self, input_dim=10, hidden_dim=16, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        return self.net(x)


def ddp_worker(rank, world_size):
    """
    Train SimpleNet with DistributedDataParallel in one process.

    rank: which process is this? (0 .. world_size-1; also the GPU index
          when CUDA is available)
    world_size: total number of processes
    """
    print(f"[Process {rank}] Initializing...")

    use_cuda = torch.cuda.is_available()
    # NCCL only supports CUDA tensors; Gloo is PyTorch's CPU backend.
    backend = "nccl" if use_cuda else "gloo"

    # 1) Initialize the process group
    dist.init_process_group(
        backend=backend,
        init_method='tcp://localhost:12355',  # or some other URL
        world_size=world_size,
        rank=rank,
    )
    try:
        # 2) Set device for this process and 3) wrap the model in DDP.
        # DDP broadcasts rank 0's parameters at construction, so all ranks
        # start from identical weights.
        if use_cuda:
            device = torch.device(f"cuda:{rank}")
            torch.cuda.set_device(device)  # pin this process to its GPU
            model = SimpleNet().to(device)
            ddp_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            device = torch.device("cpu")
            model = SimpleNet().to(device)
            # For CPU modules device_ids must be left as None.
            ddp_model = DDP(model)

        # 4) Prepare data with distributed sampler (each rank gets a disjoint shard)
        dataset = get_dummy_data(num_samples=64)
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
        dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

        # 5) Standard training setup
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        # 6) Train; DDP all-reduces gradients during loss.backward()
        ddp_model.train()
        for epoch in range(2):
            sampler.set_epoch(epoch)  # reshuffle differently each epoch
            for batch_data, batch_labels in dataloader:
                batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)

                optimizer.zero_grad()
                outputs = ddp_model(batch_data)
                loss = criterion(outputs, batch_labels)
                loss.backward()
                optimizer.step()

                if rank == 0:  # Usually only rank=0 prints
                    print(f"[Rank {rank}] Epoch {epoch}, Loss: {loss.item():.4f}")
    finally:
        # 7) Clean up, even if training raised
        dist.destroy_process_group()


def run_ddp_example(world_size):
    """ Launch N processes, each running ddp_worker """
    mp.spawn(ddp_worker, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    # One process per GPU (capped at 2 for this demo); on a CPU-only machine
    # fall back to 2 Gloo processes.
    if torch.cuda.is_available():
        world_size = min(2, torch.cuda.device_count())
    else:
        world_size = 2
    run_ddp_example(world_size)
After you are done with the questions above:
For an interview focused on creative problem-solving and clever use of PyTorch, expect questions that test your ability to:
1. Optimize Model Training & Memory Usage
• How would you efficiently train a large model on limited GPU memory?
• How do mixed precision training and gradient checkpointing work?
• Implement a custom torch.autograd.Function to save memory.
2. Custom Implementations & PyTorch Internals
• Implement a custom activation function with PyTorch.
• How does torch.nn.Module work internally?
• Explain how PyTorch’s autograd computes gradients.
3. Efficient Tensor Operations
• Optimize a given PyTorch operation to minimize GPU memory and maximize speed.
• Implement a function that computes a rolling window mean using efficient tensor operations.
• Why should we prefer torch.einsum over explicit loops?
4. Parallelism & Multi-GPU Training
• Implement a simple data parallel training loop.
• What is the difference between torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel?
• How do you handle synchronization issues in multi-GPU training?
5. Custom Loss Functions & Gradients
• Implement a custom loss function that requires second-order gradients.
• How do you stop gradients from flowing through part of the computation graph?
6. Debugging & Profiling Performance Issues
• How would you debug a PyTorch model that is training extremely slowly?
• Use torch.profiler to identify bottlenecks in a model’s training loop.
7. Reinforcement Learning / Optimization-Specific Questions
• Implement a basic reinforcement learning policy network in PyTorch.
• How would you use PyTorch for differentiable optimization tasks?
Concepts
MUST WATCH
If You Have Time