import dataclasses
import functools
import os

import datasets
import tokenizers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import tqdm
from torch import Tensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_state_dict,
    set_state_dict,
)
from torch.distributed.fsdp import (
    CPUOffloadPolicy,
    FSDPModule,
    MixedPrecisionPolicy,
    fully_shard,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data.distributed import DistributedSampler


# Build the model
@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
    vocab_size: int = 50000  # Size of the tokenizer vocabulary
    max_position_embeddings: int = 2048  # Maximum sequence length
    hidden_size: int = 768  # Dimension of hidden layers
    intermediate_size: int = 4*768  # Dimension of the MLP's hidden layer
    num_hidden_layers: int = 12  # Number of transformer layers
    num_attention_heads: int = 12  # Number of attention heads
    num_key_value_heads: int = 3  # Number of key-value heads for GQA


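# A quick sanity sketch (illustrative, not part of training): the attention module
# below assumes hidden_size divides evenly by num_attention_heads, and GQA assumes the
# query heads divide evenly by the key-value heads. Uncomment to check a custom config.
# cfg = LlamaConfig()
# head_dim = cfg.hidden_size // cfg.num_attention_heads
# assert head_dim * cfg.num_attention_heads == cfg.hidden_size
# assert cfg.num_attention_heads % cfg.num_key_value_heads == 0
# print(f"head_dim={head_dim}, query heads per KV head: {cfg.num_attention_heads // cfg.num_key_value_heads}")

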
class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""

    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        """Initialize the RotaryPositionEncoding module.

        Args:
            dim: The hidden dimension of the input tensor to which RoPE is applied
            max_position_embeddings: The maximum sequence length of the input tensor
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # compute a matrix of n\theta_i
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(max_position_embeddings)
        sinusoid_inp = torch.outer(position, inv_freq)
        # save cosine and sine matrices as buffers, not parameters
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x: Tensor) -> Tensor:
        """Apply RoPE to tensor x.

        Args:
            x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)

        Returns:
            Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
        """
        batch_size, seq_len, num_heads, head_dim = x.shape
        device = x.device
        dtype = x.dtype
        # reshape the cosine and sine matrices to 4D tensors with the same dtype as x
        cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        # apply RoPE to x
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        output = (x * cos) + (rotated * sin)
        return output


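# Optional sanity sketch: RoPE is shape-preserving, so a (batch, seq, heads, head_dim)
# tensor should come back with the same shape. Uncomment to check on CPU; the sizes
# here are made up for illustration.
# rope_test = RotaryPositionEncoding(dim=64, max_position_embeddings=128)
# sample = torch.randn(2, 16, 12, 64)
# assert rope_test(sample).shape == sample.shape

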
class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q

        # hidden_size must be divisible by num_heads
        assert (self.head_dim * self.num_heads) == self.hidden_size

        # Linear layers for Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def reset_parameters(self):
        self.q_proj.reset_parameters()
        self.k_proj.reset_parameters()
        self.v_proj.reset_parameters()
        self.o_proj.reset_parameters()

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        bs, seq_len, dim = hidden_states.size()

        # Project inputs to Q, K, V
        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

        # Apply rotary position embeddings
        query_states = rope(query_states)
        key_states = rope(key_states)

        # Transpose tensors from BSHD to BHSD layout for scaled_dot_product_attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Use PyTorch's optimized attention implementation
        # setting is_causal=True is incompatible with passing an explicit attention mask
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=0.0,
            enable_gqa=True,
        )

        # Transpose the output from BHSD back to BSHD, reshape to 3D, then project
        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output


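# Note: the enable_gqa flag used above is only available in recent PyTorch releases
# (2.5+, to the best of my knowledge). A hedged fallback for older versions is to
# repeat the KV heads to match the query heads inside LlamaAttention.forward, e.g.:
#   n_rep = self.num_heads // self.num_kv_heads
#   key_states = key_states.repeat_interleave(n_rep, dim=1)    # expand along the head axis (BHSD)
#   value_states = value_states.repeat_interleave(n_rep, dim=1)
#   attn_output = F.scaled_dot_product_attention(
#       query_states, key_states, value_states, attn_mask=attn_mask, dropout_p=0.0)

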
class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Two parallel projections for SwiGLU
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SwiGLU activation function
        # Project back to hidden size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def reset_parameters(self):
        self.gate_proj.reset_parameters()
        self.up_proj.reset_parameters()
        self.down_proj.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU activation: multiply gate and up-projected inputs
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def reset_parameters(self):
        self.input_layernorm.reset_parameters()
        self.self_attn.reset_parameters()
        self.post_attention_layernorm.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        # First residual block: Self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(hidden_states, rope=rope, attn_mask=attn_mask)
        hidden_states = attn_outputs + residual

        # Second residual block: MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states) + residual
        return hidden_states


class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def reset_parameters(self):
        self.embed_tokens.reset_parameters()
        for layer in self.layers:
            layer.reset_parameters()
        self.norm.reset_parameters()

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        # Convert input token IDs to embeddings
        hidden_states = self.embed_tokens(input_ids)
        # Process through all transformer layers, then the final norm layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
        hidden_states = self.norm(hidden_states)
        # Return the final hidden states
        return hidden_states


class LlamaForPretraining(nn.Module):
    """Llama model with a language-modeling head for pretraining."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def reset_parameters(self):
        self.base_model.reset_parameters()
        self.lm_head.reset_parameters()

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        hidden_states = self.base_model(input_ids, attn_mask)
        return self.lm_head(hidden_states)


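# Optional smoke test (single process, illustrative values only): run a tiny config
# through the pretraining model and check the logits shape. attn_mask=None skips
# masking, which is fine for a pure shape check.
# tiny = LlamaForPretraining(LlamaConfig(vocab_size=1000, max_position_embeddings=64,
#                                        hidden_size=64, intermediate_size=256,
#                                        num_hidden_layers=2, num_attention_heads=4,
#                                        num_key_value_heads=2))
# ids = torch.randint(0, 1000, (2, 16))
# assert tiny(ids, attn_mask=None).shape == (2, 16, 1000)

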
def create_causal_mask(batch: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a causal mask for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        dtype: Data type of the mask

    Returns:
        Causal mask of shape (seq_len, seq_len)
    """
    batch_size, seq_len = batch.shape
    mask = torch.full((seq_len, seq_len), float("-inf"), device=batch.device, dtype=dtype) \
               .triu(diagonal=1)
    return mask


def create_padding_mask(batch: Tensor, padding_token_id: int, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a padding mask for a batch of sequences for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        padding_token_id: ID of the padding token
        dtype: Data type of the mask

    Returns:
        Padding mask of shape (batch_size, 1, seq_len, seq_len)
    """
    padded = torch.zeros_like(batch, device=batch.device, dtype=dtype) \
                  .masked_fill(batch == padding_token_id, float("-inf"))
    mask = padded[:, :, None] + padded[:, None, :]
    return mask[:, None, :, :]


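# A hedged usage sketch for the two mask helpers: the causal mask broadcasts over the
# batch, and adding the padding mask blocks attention to and from [PAD] positions.
# The token ids below are made up for illustration, with 0 as the padding id.
# toy = torch.tensor([[5, 6, 7, 0, 0]])
# combined = create_causal_mask(toy) + create_padding_mask(toy, padding_token_id=0)
# print(combined.shape)    # torch.Size([1, 1, 5, 5]); additive mask of 0 and -inf

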
# Dataset that yields padded token sequences of fixed length
class PretrainingDataset(torch.utils.data.Dataset):
    def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer,
                 seq_length: int):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.bot = tokenizer.token_to_id("[BOT]")
        self.eot = tokenizer.token_to_id("[EOT]")
        self.pad = tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        """Get a sequence of token ids from the dataset. [BOT] and [EOT] tokens
        are added, and the sequence is clipped and padded to the sequence length.
        """
        seq = self.dataset[index]["text"]
        tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
        # pad to the target sequence length
        toklen = len(tokens)
        if toklen < self.seq_length+1:
            pad_length = self.seq_length+1 - toklen
            tokens += [self.pad] * pad_length
        # return the input sequence and the target sequence shifted by one token
        x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64)
        y = torch.tensor(tokens[1:self.seq_length+1], dtype=torch.int64)
        return x, y


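# Optional sketch (usable once the tokenizer and dataset are loaded further down):
# inspect one item. x and y are the same token sequence offset by one position, both
# clipped/padded to seq_length, which is what the next-token cross-entropy loss expects.
# preview = PretrainingDataset(dataset, tokenizer, seq_length=32)
# x0, y0 = preview[0]
# print(x0.shape, y0.shape)    # torch.Size([32]) torch.Size([32])

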
def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
    dist.barrier()
    model_state, optimizer_state = get_state_dict(
        model, optimizer, options=StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload),
    )
    load(
        {"model": model_state, "optimizer": optimizer_state},
        checkpoint_id="checkpoint-dist",
    )
    set_state_dict(
        model, optimizer,
        model_state_dict=model_state, optim_state_dict=optimizer_state,
        options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, cpu_offload=cpu_offload),
    )
    scheduler.load_state_dict(
        torch.load("checkpoint-dist/lrscheduler.pt", map_location=device),
    )
    dist.barrier()


def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
    dist.barrier()
    model_state, optimizer_state = get_state_dict(
        model, optimizer, options=StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload),
    )
    save(
        {"model": model_state, "optimizer": optimizer_state},
        checkpoint_id="checkpoint-dist",
    )
    if dist.get_rank() == 0:
        torch.save(scheduler.state_dict(), "checkpoint-dist/lrscheduler.pt")
    dist.barrier()


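# A hedged sketch for exporting one consolidated model checkpoint (e.g. for inference
# outside FSDP): full_state_dict=True gathers unsharded tensors, and only rank 0 writes
# them with torch.save. The filename and the empty optimizer list are assumptions here.
# def export_full_model(model: nn.Module) -> None:
#     model_state, _ = get_state_dict(
#         model, [], options=StateDictOptions(full_state_dict=True, cpu_offload=True),
#     )
#     if dist.get_rank() == 0:
#         torch.save(model_state, "llama_full_state.pt")

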
# Load the tokenizer and dataset
tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

# Initialize the distributed environment
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {device}")

# Create the pretraining model on the meta device, on all ranks
with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)

# Shard the meta-device model with FSDP2; every layer and the root module must be sharded
cpu_offload = False
fsdp_kwargs = {
    # optional: use mixed precision training
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    ),
    # optional: CPU offloading
    "offload_policy": CPUOffloadPolicy() if cpu_offload else None,
    # optional: discard all-gathered parameters after the forward pass even on root modules
    # "reshard_after_forward": True,
}
for layer in model.base_model.layers:
    fully_shard(layer, **fsdp_kwargs)
fully_shard(model.base_model, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
model.to_empty(device="cpu" if cpu_offload else device)
model.reset_parameters()
assert isinstance(model, FSDPModule), f"Expected FSDPModule, got {type(model)}"

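# Optional sketch: after fully_shard, each parameter is a DTensor sharded across ranks.
# Uncomment to print the first parameter's type and placements on every rank.
# for name, param in model.named_parameters():
#     print(f"rank {rank}: {name} -> {type(param).__name__}, {getattr(param, 'placements', None)}")
#     break
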
# Set explicit prefetching on the sharded modules
# More prefetching uses more memory, but allows more overlap of computation and communication
num_prefetch = 1
if num_prefetch > 1:
    modules = list(model.base_model.layers)
    for i, module in enumerate(modules):
        if i == len(modules) - 1:
            break
        module.set_modules_to_forward_prefetch(modules[i+1:i+num_prefetch+1])
    for i, module in enumerate(modules):
        if i == 0:
            continue
        module.set_modules_to_backward_prefetch(modules[max(0, i-num_prefetch):i])

# Optional: Apply gradient checkpointing on a distributed model (all ranks)
#wrap_policy = functools.partial(
#Ā Ā Ā Ā transformer_auto_wrap_policy,
#Ā Ā Ā Ā transformer_layer_cls={LlamaDecoderLayer, nn.Embedding},
#)
#apply_activation_checkpointing(
#Ā Ā Ā Ā model,
#Ā Ā Ā Ā checkpoint_wrapper_fn=checkpoint_wrapper,
#Ā Ā Ā Ā auto_wrap_policy=wrap_policy,
#)
Ā
# Training parameters
epochs = 3
learning_rate = 1e-3
batch_size = 64 // world_size  # per-rank batch size; the global batch size is 64
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")
model.train()

# DataLoader, optimizer, scheduler, and loss function
# A DistributedSampler is needed to shard the dataset across ranks
dataset = PretrainingDataset(dataset, tokenizer, seq_length)
sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
dataloader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=batch_size,
    pin_memory=True,  # optional
    shuffle=False,
    num_workers=2,
    prefetch_factor=2,
)
num_training_steps = len(dataloader) * epochs

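# Note: with shuffle=False the sampler order is fixed, so no per-epoch reseeding is
# needed. If shuffling were enabled, each epoch should start by calling
#     sampler.set_epoch(epoch)
# inside the epoch loop below so that the shuffling order changes between epochs.
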
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.1,
)
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps - num_warmup_steps,
    eta_min=0,
)
scheduler = lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[num_warmup_steps],
)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)

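# Optional sketch: preview the first few learning-rate values of the warmup + cosine
# schedule with a throwaway optimizer/scheduler pair, so that the real optimizer and
# scheduler above are left untouched. Values are illustrative only.
# _opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=learning_rate)
# _warm = lr_scheduler.LinearLR(_opt, start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps)
# _cos = lr_scheduler.CosineAnnealingLR(_opt, T_max=num_training_steps - num_warmup_steps, eta_min=0)
# _sched = lr_scheduler.SequentialLR(_opt, schedulers=[_warm, _cos], milestones=[num_warmup_steps])
# for _ in range(3):
#     _opt.step()
#     _sched.step()
#     print(_sched.get_last_lr())
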
# Optional: Compile the model and loss function
#model = torch.compile(model)
#loss_fn = torch.compile(loss_fn)

# If the checkpoint-dist directory exists, load the checkpoint into the model and optimizer
if os.path.exists("checkpoint-dist"):
    load_checkpoint(model, optimizer, scheduler)

# Start training
for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch_id, batch in enumerate(pbar):
        if batch_id % 1000 == 0:
            save_checkpoint(model, optimizer, scheduler)
        # Explicitly unshard (all-gather parameters) before sending any data to the model
        model.unshard()
        # Get batched data and move it from CPU to GPU
        input_ids, target_ids = batch
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        # Create the attention mask: causal mask + padding mask
        attn_mask = create_causal_mask(input_ids) + \
                    create_padding_mask(input_ids, PAD_TOKEN_ID)
        # Extract output from the model
        logits = model(input_ids, attn_mask)
        # Compute loss: cross-entropy between logits and targets, ignoring padding tokens
        loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        # Backward pass, then gradient clipping by L2 norm to 1.0
        # The optimizer and gradient clipping work on DTensors
        optimizer.zero_grad(set_to_none=False if cpu_offload else True)
        loss.backward()
        # The clipping all-reduce fails when CPU offloading is used
        if not cpu_offload:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item())
        pbar.update(1)
    pbar.close()

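# Note: loss.item() above is the local, per-rank loss. A hedged sketch for logging a
# loss averaged over all ranks (adds one all-reduce per logged step, to be placed
# inside the training loop):
# loss_avg = loss.detach().clone()
# dist.all_reduce(loss_avg, op=dist.ReduceOp.AVG)
# if rank == 0:
#     print(f"global average loss: {loss_avg.item():.4f}")
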
# Save the model
save_checkpoint(model, optimizer, scheduler)

# Clean up the distributed environment
dist.destroy_process_group()