Introduction

The pretraining of Large Language Models (LLMs) is one of the most striking technical breakthroughs in AI today. From the GPT series to LLaMA and Claude, these models have demonstrated remarkable language understanding and generation abilities. This article walks through the complete pipeline of LLM pretraining, from architectural principles to engineering practice.
Why Pretrain Large Models

Pretraining lets a model learn general-purpose language representations from massive text corpora, and that knowledge transfers to a wide range of downstream tasks. Its value shows up in three ways:

Knowledge acquisition: the model absorbs rich world knowledge and linguistic patterns
Transfer learning: pretrained weights serve as initialization for many downstream tasks
Emergent abilities: at sufficient scale, models exhibit unexpected new capabilities
1. The Transformer Architecture

The Transformer is the foundational architecture of modern large language models, introduced by Google in the 2017 paper "Attention Is All You Need".
1.1 Core Components

Embedding layer

$$ \text{Embedding}(x) = W_e \cdot \text{OneHot}(x) + P $$

where $W_e$ is the embedding matrix and $P$ is the positional encoding.
Positional encoding

The original Transformer uses sinusoidal positional encodings:
$$ PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}) $$ $$ PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}}) $$
Modern LLMs mostly use rotary position embeddings (RoPE) or ALiBi instead.
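As a concrete reference, the sinusoidal table defined by the two formulas above can be built in a few lines; the sequence length and model dimension below are illustrative:

```python
import torch

def sinusoidal_position_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Build the sinusoidal position-encoding table from the formulas above."""
    position = torch.arange(max_len).unsqueeze(1).float()                     # [max_len, 1]
    div_term = torch.pow(10000.0, torch.arange(0, d_model, 2).float() / d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position / div_term)   # even dimensions: sin
    pe[:, 1::2] = torch.cos(position / div_term)   # odd dimensions: cos
    return pe

pe = sinusoidal_position_encoding(128, 64)
```

Each dimension oscillates at a different frequency, so every position gets a unique, smoothly varying code.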
1.2 Encoder-Decoder Architecture

The original Transformer consists of an encoder and a decoder:

Encoder: bidirectional attention, suited to understanding tasks (BERT-style)
Decoder: unidirectional (causal) attention, suited to generation tasks (GPT-style)

Most modern large models use a decoder-only architecture.
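In code, the causality of decoder-style attention is enforced by a triangular mask; a minimal illustration for a 4-token sequence:

```python
import torch

# Causal mask: position t may attend only to positions <= t.
# 1 = visible, 0 = masked out before the softmax.
seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))
```

Row t of the mask has exactly t+1 ones, so token 0 sees only itself while the last token sees the whole prefix.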
2. Self-Attention in Detail

Self-attention is the Transformer's core innovation: when processing each position, the model can attend to every position in the sequence.
2.1 Scaled Dot-Product Attention

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
where:

$Q$ (Query): the query matrix
$K$ (Key): the key matrix
$V$ (Value): the value matrix
$d_k$: the dimension of the key vectors
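This formula translates almost line-for-line into code; a minimal single-head sketch (shapes are illustrative, the full multi-head version follows in 2.3):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Direct implementation of softmax(QK^T / sqrt(d_k)) V."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # [..., seq, seq]
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = torch.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ V, weights

Q = K = V = torch.randn(2, 4, 8)   # [batch, seq, d_k]
out, w = scaled_dot_product_attention(Q, K, V)
```

The $\sqrt{d_k}$ scaling keeps the dot products from growing with dimension, which would otherwise push the softmax into regions with vanishing gradients.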
2.2 Multi-Head Attention

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O $$

where each head is

$$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$
2.3 Code Implementation

```python
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # Project and split into heads: [batch, heads, seq, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then merge heads back together
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)
        return output, attn_weights
```
3. Mainstream LLM Architectures Compared

3.1 The GPT Family

GPT uses a decoder-only architecture with a causal attention mask:
```python
class GPTBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        # Pre-norm residual blocks
        x = x + self.attn(self.ln1(x), mask)[0]
        x = x + self.ffn(self.ln2(x))
        return x
```
3.2 The LLaMA Architecture

LLaMA makes several key changes on top of the GPT recipe:

RMSNorm: replaces LayerNorm
RoPE: rotary position embeddings
SwiGLU: a gated activation for the feed-forward network
Grouped-query attention (GQA): improves inference efficiency
```python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the root-mean-square instead of mean and variance
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
```
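SwiGLU can be sketched in the same spirit. The block below follows the LLaMA-style gated feed-forward, FFN(x) = W2(SiLU(W1 x) · W3 x); the hidden size is an illustrative assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Sketch of a SwiGLU feed-forward block: a SiLU-gated projection
    multiplied elementwise with an 'up' projection, then projected down."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

ffn = SwiGLUFFN(512, 1376)
y = ffn(torch.randn(2, 16, 512))
```

Because SwiGLU uses three weight matrices instead of two, implementations typically shrink the hidden dimension (roughly 2/3 of the usual 4x expansion) to keep the parameter count comparable.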
3.3 Architecture Comparison

| Feature | GPT | LLaMA | PaLM |
| --- | --- | --- | --- |
| Norm placement | Pre-norm | Pre-norm | Pre-norm |
| Norm type | LayerNorm | RMSNorm | LayerNorm |
| Position encoding | Learned | RoPE | RoPE |
| Activation | GELU | SwiGLU | SwiGLU |
| Attention | MHA | GQA | MQA |
4. Pretraining Objectives

4.1 Causal Language Modeling (CLM)

The most common pretraining objective: predict the next token:
$$ \mathcal{L}_{\text{CLM}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{<t}) $$
```python
def clm_loss(logits, targets, ignore_index=-100):
    """
    Causal language modeling loss.
    logits: [batch_size, seq_len, vocab_size]
    targets: [batch_size, seq_len]
    """
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean')
    # Shift so the logits at position t predict the token at t+1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_targets = targets[..., 1:].contiguous()
    loss = loss_fn(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_targets.view(-1),
    )
    return loss
```
4.2 Masked Language Modeling (MLM)

The pretraining objective used by BERT: randomly mask tokens and predict them:
$$ \mathcal{L}_{\text{MLM}} = -\sum_{i \in M} \log P(x_i \mid x_{\setminus M}) $$
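The masking itself is easy to sketch. The routine below follows the usual BERT convention (of the ~15% selected positions: 80% become [MASK], 10% a random token, 10% unchanged); the token ids are illustrative assumptions:

```python
import torch

def mask_tokens(input_ids, mask_token_id, vocab_size, mlm_prob=0.15):
    """Sketch of BERT-style input corruption for MLM pretraining."""
    labels = input_ids.clone()
    prob = torch.full(input_ids.shape, mlm_prob)
    masked = torch.bernoulli(prob).bool()
    labels[~masked] = -100     # only masked positions contribute to the loss

    input_ids = input_ids.clone()
    # 80% of masked positions -> [MASK]
    replace = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked
    input_ids[replace] = mask_token_id
    # Half of the remainder (10% overall) -> random token; the rest stay unchanged
    random_tok = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked & ~replace
    input_ids[random_tok] = torch.randint(vocab_size, input_ids.shape)[random_tok]
    return input_ids, labels

ids = torch.randint(5, 1000, (2, 32))
x, y = mask_tokens(ids, mask_token_id=4, vocab_size=1000)
```

Leaving some selected tokens unchanged forces the model to build representations for every position, not just the visibly masked ones.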
5. Large-Scale Distributed Training

Training large models requires distributed computation to satisfy their memory and compute demands.

5.1 Data Parallelism

The simplest parallelization strategy: every device holds a full copy of the model.
```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

local_rank = setup_distributed()
model = MyModel().cuda()
model = DDP(model, device_ids=[local_rank])
```
5.2 Tensor Parallelism

Split individual tensors across multiple devices:
```python
import torch.nn.functional as F

class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        # Each rank holds a slice of the output dimension
        self.out_features_per_part = out_features // world_size
        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_part, in_features)
        )

    def forward(self, x):
        # Produces this rank's slice; an all-gather recombines the slices
        return F.linear(x, self.weight)
```
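To see why slicing the output dimension is safe, a single-process simulation of 2-way column parallelism shows that the concatenated shard outputs match the unsharded layer exactly; the sizes below are illustrative:

```python
import torch
import torch.nn.functional as F

# Simulate 2-way column parallelism in one process: split the weight's
# output dimension, compute each shard independently, then concatenate.
torch.manual_seed(0)
W = torch.randn(64, 32)              # full weight [out_features, in_features]
x = torch.randn(4, 32)
full = F.linear(x, W)                # reference: unsharded linear

shards = W.chunk(2, dim=0)           # each "device" holds a [32, 32] slice
parts = [F.linear(x, w) for w in shards]
combined = torch.cat(parts, dim=-1)  # plays the role of the all-gather
```

Because each output feature depends on the whole input but only on one row of the weight, no communication is needed until the slices are recombined.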
5.3 Pipeline Parallelism

Split the model's layers across different devices:
```python
from torch.distributed.pipeline.sync import Pipe

model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.ReLU(),
    nn.Linear(4096, 1024),
)
model = Pipe(model, chunks=8)  # split each batch into 8 micro-batches
```
5.4 ZeRO Optimization

DeepSpeed ZeRO reduces memory by sharding optimizer states, gradients, and parameters:
```python
ds_config = {
    "train_batch_size": 512,
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"},
    },
}
```
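A back-of-envelope estimate shows why this sharding matters: mixed-precision Adam keeps roughly 16 bytes of state per parameter (2 B fp16 weight, 2 B fp16 gradient, 12 B fp32 master copy plus momentum and variance). A rough sketch, ignoring activations and communication buffers:

```python
def zero_memory_per_gpu(n_params, n_gpus, stage):
    """Rough per-GPU bytes for mixed-precision Adam under ZeRO.
    stage 1 shards optimizer states, stage 2 adds gradients,
    stage 3 adds the parameters themselves."""
    weight, grad, opt = 2.0, 2.0, 12.0   # bytes per parameter
    if stage >= 1:
        opt /= n_gpus
    if stage >= 2:
        grad /= n_gpus
    if stage >= 3:
        weight /= n_gpus
    return n_params * (weight + grad + opt)

# A 7B-parameter model on 64 GPUs
gb = 1024 ** 3
baseline = zero_memory_per_gpu(7e9, 64, 0) / gb   # ~104 GB per GPU
stage3 = zero_memory_per_gpu(7e9, 64, 3) / gb     # ~1.6 GB per GPU
```

Without sharding, the model states alone exceed any single GPU; stage 3 divides all three components by the device count.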
6. The Data Processing Pipeline
6.1 Data Collection and Cleaning

```python
import re
from typing import List

class DataCleaner:
    def __init__(self):
        self.patterns = [
            (r'<[^>]+>', ''),    # strip HTML tags
            (r'\s+', ' '),       # collapse whitespace
            (r'http\S+', ''),    # drop URLs
        ]

    def clean(self, text: str) -> str:
        for pattern, replacement in self.patterns:
            text = re.sub(pattern, replacement, text)
        return text.strip()

    def deduplicate(self, documents: List[str]) -> List[str]:
        # Cheap near-exact dedup on a document-prefix hash
        seen = set()
        unique = []
        for doc in documents:
            h = hash(doc[:1000])
            if h not in seen:
                seen.add(h)
                unique.append(doc)
        return unique
```
6.2 Training a Tokenizer

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel

tokenizer = Tokenizer(BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = ByteLevel()

trainer = BpeTrainer(
    vocab_size=32000,
    special_tokens=["<unk>", "<s>", "</s>", "<pad>", "<mask>"],
)
tokenizer.train(files=["data/train.txt"], trainer=trainer)
tokenizer.save("tokenizer.json")
```
6.3 Efficient Data Loading

```python
from torch.utils.data import Dataset, DataLoader
import numpy as np

class PretrainingDataset(Dataset):
    def __init__(self, data_path, seq_length=2048):
        self.seq_length = seq_length
        # Memory-map the token file so it is never fully loaded into RAM
        self.data = np.memmap(data_path, dtype=np.uint16, mode='r')

    def __len__(self):
        # Each sample needs seq_length + 1 tokens (inputs plus shifted targets)
        return (len(self.data) - 1) // self.seq_length

    def __getitem__(self, idx):
        start = idx * self.seq_length
        end = start + self.seq_length + 1
        # Cast to int64: embedding lookups need long indices
        chunk = torch.from_numpy(self.data[start:end].astype(np.int64))
        return chunk[:-1], chunk[1:]
```
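The uint16 token file this dataset memory-maps can be produced with a plain raw dump; a toy sketch where the "tokenized documents" are made-up ids:

```python
import numpy as np
import os
import tempfile

# Concatenate tokenized documents into one flat uint16 array and write
# it as raw bytes, the layout np.memmap reads back directly.
token_ids = [[1, 52, 907, 2], [1, 13, 2]]   # illustrative documents
flat = np.concatenate([np.asarray(d, dtype=np.uint16) for d in token_ids])

path = os.path.join(tempfile.mkdtemp(), 'train.bin')
flat.tofile(path)                            # raw uint16 dump, no header

data = np.memmap(path, dtype=np.uint16, mode='r')   # same call as the Dataset
```

uint16 suffices for vocabularies up to 65,535 tokens and halves the file size relative to int32.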
7. Hands-On: Pretraining a Small LLaMA from Scratch

7.1 Model Definition

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, x, seq_len):
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

class MiniLLaMA(nn.Module):
    def __init__(
        self,
        vocab_size=32000,
        dim=512,
        n_layers=8,
        n_heads=8,
        max_seq_len=2048,
        dropout=0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_layers = n_layers

        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # TransformerBlock (pre-norm attention applying RoPE, plus a
        # feed-forward network, along the lines of GPTBlock above) is
        # assumed to be defined elsewhere.
        self.layers = nn.ModuleList([
            TransformerBlock(dim, n_heads, dropout)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.rope = RotaryPositionalEmbedding(dim // n_heads, max_seq_len)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens):
        batch_size, seq_len = tokens.shape
        x = self.tok_embeddings(tokens)
        cos, sin = self.rope(x, seq_len)

        # Additive causal mask: -inf above the diagonal
        mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=tokens.device),
            diagonal=1,
        )
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

def train_model():
    batch_size = 32
    seq_length = 512
    learning_rate = 3e-4
    num_epochs = 10

    model = MiniLLaMA(
        vocab_size=32000,
        dim=512,
        n_layers=8,
        n_heads=8,
    ).cuda()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.95),
        weight_decay=0.1,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs * 1000, eta_min=1e-5
    )

    model.train()
    for epoch in range(num_epochs):
        # train_loader is assumed to yield (inputs, targets) batches from
        # PretrainingDataset above, which already shifts targets by one.
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.cuda(), targets.cuda()

            logits = model(inputs)
            # Targets are pre-shifted by the dataset, so use plain
            # cross-entropy here; clm_loss would shift a second time.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
```
8. Training Monitoring and Tuning

8.1 Key Metrics to Watch

Training loss: should decrease steadily
Learning rate: warmup followed by gradual decay
Gradient norm: spikes can signal instability
GPU memory: watch for OOM risk
```python
from torch.utils.tensorboard import SummaryWriter

class TrainingMonitor:
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)
        self.global_step = 0

    def log_metrics(self, loss, lr, grad_norm, epoch):
        self.writer.add_scalar('Loss/train', loss, self.global_step)
        self.writer.add_scalar('LR', lr, self.global_step)
        self.writer.add_scalar('GradNorm', grad_norm, self.global_step)
        self.global_step += 1
```
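The "warmup followed by gradual decay" pattern from the metrics list above can be sketched with `LambdaLR`: linear warmup into a cosine decay. The step counts below are illustrative assumptions:

```python
import math
import torch

warmup_steps, total_steps = 100, 1000

def lr_lambda(step):
    """Multiplier on the base LR: linear warmup, then cosine decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
```

The multiplier rises from 0 to 1 over the warmup window, peaks exactly at `warmup_steps`, and decays smoothly to 0 by `total_steps`.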
8.2 Common Problems and Fixes

| Problem | Symptom | Fix |
| --- | --- | --- |
| NaN loss | Training crashes | Lower the learning rate; check the data |
| No convergence | Loss oscillates | Tune the learning rate; lengthen warmup |
| Out of memory | OOM errors | Shrink the batch or use gradient checkpointing |
| Slow training | Low throughput | Profile the data-loading pipeline |
9. Summary and Outlook

9.1 Key Takeaways

Architecture: modern LLMs favor decoder-only designs with RoPE
Distributed training: combine data, tensor, and pipeline parallelism
Data: high-quality data is the key to successful pretraining
Engineering: mixed precision, gradient checkpointing, efficient data loading
9.2 Trends

Longer context: from 4K to 100K+ tokens
Multimodal fusion: unified modeling of text, images, and audio
Efficient fine-tuning: parameter-efficient methods such as LoRA and QLoRA
Model compression: quantization, pruning, distillation
LLM pretraining is a fast-moving field; keeping up with the latest papers and open-source projects is how you stay competitive. I hope this article serves as a practical guide to building your own large model!