Introduction

Pretraining large language models (LLMs) is one of the most striking technical breakthroughs in AI today. From the GPT series to LLaMA and Claude, these models show remarkable language understanding and generation abilities. This article walks through the complete pretraining pipeline, from architectural principles to engineering practice.

Why Pretraining Matters

Pretraining lets a model learn general-purpose language representations from massive text corpora, knowledge that transfers to a wide range of downstream tasks. Its value shows up in three ways:

  1. Knowledge acquisition: the model absorbs rich world knowledge and language patterns
  2. Transfer learning: pretrained weights serve as the initialization for many tasks
  3. Emergent abilities: at sufficient scale, models exhibit unexpected new capabilities

1. A Deep Dive into the Transformer Architecture

The Transformer is the foundational architecture of modern large language models, introduced by Google in the 2017 paper "Attention Is All You Need".

(Figure: the Transformer architecture)

1.1 Core Components

Embedding Layer

$$
\text{Embedding}(x) = W_e \cdot \text{OneHot}(x) + P
$$

where $W_e$ is the embedding matrix and $P$ is the positional encoding.
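In practice the one-hot matrix product is implemented as a table lookup; a minimal sketch showing the equivalence (dimensions are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, d_model = 100, 16
emb = nn.Embedding(vocab_size, d_model)   # emb.weight plays the role of W_e

x = torch.tensor([3, 7, 3])               # token ids
lookup = emb(x)                           # table lookup, shape [3, 16]

# Equivalent explicit one-hot matrix product
one_hot = F.one_hot(x, vocab_size).float()
matmul = one_hot @ emb.weight
assert torch.allclose(lookup, matmul)     # same result, lookup is just faster
```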

Positional Encoding

The original Transformer uses sinusoidal positional encodings:

$$
PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})
$$
$$
PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})
$$

Modern LLMs mostly use rotary position embeddings (RoPE) or ALiBi instead.
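The two sinusoidal formulas above can be implemented directly; a minimal sketch:

```python
import torch

def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    """Sinusoidal positional encoding from the original Transformer."""
    pos = torch.arange(max_len).unsqueeze(1).float()   # [max_len, 1]
    i = torch.arange(0, d_model, 2).float()            # even dimensions 2i
    div = torch.pow(10000.0, i / d_model)              # 10000^(2i / d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos / div)                 # even dims: sin
    pe[:, 1::2] = torch.cos(pos / div)                 # odd dims: cos
    return pe

pe = sinusoidal_pe(128, 64)   # [128, 64]; pe[0] is [sin 0, cos 0, ...] = [0, 1, ...]
```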

1.2 Encoder-Decoder Architecture

The original Transformer consists of two parts, an encoder and a decoder:

  • Encoder: bidirectional attention, suited to understanding tasks (BERT-style)
  • Decoder: unidirectional (causal) attention, suited to generation tasks (GPT-style)

Modern large models mostly adopt a decoder-only architecture.
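The causal attention that defines the decoder style can be demonstrated with an additive mask; a minimal sketch:

```python
import torch

seq_len = 5
# Upper-triangular -inf mask: position t may only attend to positions <= t
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float('-inf')), diagonal=1
)

scores = torch.randn(seq_len, seq_len) + causal_mask
weights = torch.softmax(scores, dim=-1)
# Each row sums to 1, with zero weight on all future positions
```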


2. The Self-Attention Mechanism in Detail

Self-attention is the Transformer's core innovation: when processing each position, the model can attend to every position in the sequence.

(Figure: the self-attention mechanism)

2.1 Scaled Dot-Product Attention

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

where:

  • $Q$ (Query): the query matrix
  • $K$ (Key): the key matrix
  • $V$ (Value): the value matrix
  • $d_k$: the key vector dimension

2.2 Multi-Head Attention

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1, …, head_h)W^O
$$

where each head is computed as:
$$
head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$

2.3 Code Implementation

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # Linear projections, then split into heads
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply causal mask for decoder (mask entries of 0 mark disallowed positions)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        context = torch.matmul(attn_weights, V)

        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # Final linear projection
        output = self.W_o(context)

        return output, attn_weights

3. Comparing Mainstream LLM Architectures

3.1 The GPT Series

GPT uses a decoder-only architecture with a causal attention mask:

class GPTBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        # Pre-norm architecture: normalize before each sublayer, add residual after
        x = x + self.attn(self.ln1(x), mask)[0]
        x = x + self.ffn(self.ln2(x))
        return x

3.2 The LLaMA Architecture

LLaMA makes several key improvements on top of GPT:

  1. RMSNorm: replaces LayerNorm
  2. RoPE: rotary position embeddings
  3. SwiGLU: a gated activation function
  4. Grouped-query attention (GQA): improves inference efficiency
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the root-mean-square instead of mean/variance
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
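SwiGLU from the improvement list above can be sketched similarly. This is a minimal illustration, not LLaMA's exact implementation (LLaMA also rounds the hidden dimension and scales it by roughly 2/3 · 4 · dim); the class name and dimensions are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """SwiGLU FFN: (SiLU(x W1) * x W3) W2, as used in LLaMA-style models."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

    def forward(self, x):
        # SiLU-gated linear unit, then project back down
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

ffn = SwiGLUFeedForward(dim=512, hidden_dim=1376)
out = ffn(torch.randn(2, 16, 512))   # shape preserved: [2, 16, 512]
```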

3.3 Architecture Comparison

| Feature             | GPT       | LLaMA    | PaLM      |
|---------------------|-----------|----------|-----------|
| Norm placement      | Pre-norm  | Pre-norm | Pre-norm  |
| Normalization       | LayerNorm | RMSNorm  | LayerNorm |
| Positional encoding | Learned   | RoPE     | RoPE      |
| Activation          | GELU      | SwiGLU   | SwiGLU    |
| Attention           | MHA       | GQA      | MQA       |

4. Pretraining Objective Design

4.1 Causal Language Modeling (CLM)

The most common pretraining objective: predict the next token.

$$
\mathcal{L}_{CLM} = -\sum_{t=1}^{T} \log P(x_t | x_{<t})
$$

def clm_loss(logits, targets, ignore_index=-100):
    """
    Causal Language Modeling Loss
    logits: [batch_size, seq_len, vocab_size]
    targets: [batch_size, seq_len]  (unshifted; same tokens as the input)
    """
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean')
    # Shift logits and targets for next token prediction
    shift_logits = logits[..., :-1, :].contiguous()
    shift_targets = targets[..., 1:].contiguous()

    loss = loss_fn(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_targets.view(-1)
    )
    return loss

4.2 Masked Language Modeling (MLM)

The pretraining objective used by BERT: randomly mask tokens and predict them.

$$
\mathcal{L}_{MLM} = -\sum_{i \in M} \log P(x_i | x_{\setminus M})
$$
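A minimal sketch of the masking step behind this objective (BERT's full 80/10/10 replacement scheme is omitted; the function name and token ids are illustrative):

```python
import torch

def mask_tokens(input_ids, mask_token_id, mask_prob=0.15, ignore_index=-100):
    """Randomly mask tokens; labels are ignore_index everywhere except masked positions."""
    labels = input_ids.clone()
    mask = torch.rand(input_ids.shape) < mask_prob
    labels[~mask] = ignore_index          # loss is computed only on masked positions
    masked = input_ids.clone()
    masked[mask] = mask_token_id          # replace with the <mask> token id
    return masked, labels

ids = torch.randint(5, 1000, (2, 32))     # dummy token ids (4 reserved for <mask>)
masked_ids, labels = mask_tokens(ids, mask_token_id=4)
```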


5. Large-Scale Distributed Training

Training large models requires distributed computation to meet their memory and compute demands.

(Figure: distributed training strategies)

5.1 Data Parallelism

The simplest parallelism strategy: every device holds a full copy of the model.

# PyTorch DDP example
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

local_rank = setup_distributed()
model = MyModel().cuda()
model = DDP(model, device_ids=[local_rank])

5.2 Tensor Parallelism

Splits individual tensors across multiple devices:

import torch.nn.functional as F

class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        # Each rank holds one column slice of the full weight matrix
        self.out_features_per_part = out_features // world_size
        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_part, in_features)
        )
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        # Returns the local output shard; an all-gather recovers the full output
        return F.linear(x, self.weight)

5.3 Pipeline Parallelism

Splits the model's layers across devices:

from torch.distributed.pipeline.sync import Pipe

# Each pipeline stage must be placed on its own GPU before wrapping with Pipe
# (Pipe also requires torch.distributed.rpc to be initialized)
model = nn.Sequential(
    nn.Linear(1024, 4096).cuda(0),
    nn.ReLU().cuda(0),
    nn.Linear(4096, 1024).cuda(1),
)

# Pipe splits each mini-batch into 8 micro-batches and pipelines them across stages
model = Pipe(model, chunks=8)

5.4 ZeRO Optimization

DeepSpeed ZeRO reduces memory by sharding optimizer states, gradients, and parameters:

# DeepSpeed configuration
ds_config = {
    "train_batch_size": 512,
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        },
        "offload_param": {
            "device": "cpu"
        }
    }
}

6. Data Processing Pipeline

(Figure: the pretraining data pipeline)

6.1 Data Collection and Cleaning

import re
from typing import List

class DataCleaner:
    def __init__(self):
        self.patterns = [
            (r'<[^>]+>', ''),   # Remove HTML tags
            (r'http\S+', ''),   # Remove URLs
            (r'\s+', ' '),      # Normalize whitespace last so leftover gaps collapse
        ]

    def clean(self, text: str) -> str:
        for pattern, replacement in self.patterns:
            text = re.sub(pattern, replacement, text)
        return text.strip()

    def deduplicate(self, documents: List[str]) -> List[str]:
        seen = set()
        unique = []
        for doc in documents:
            h = hash(doc[:1000])  # Hash first 1000 chars as a cheap fingerprint
            if h not in seen:
                seen.add(h)
                unique.append(doc)
        return unique

6.2 Tokenizer Training

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel

# Train a BPE tokenizer
tokenizer = Tokenizer(BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
    vocab_size=32000,
    special_tokens=["<unk>", "<s>", "</s>", "<pad>", "<mask>"]
)

tokenizer.train(files=["data/train.txt"], trainer=trainer)
tokenizer.save("tokenizer.json")

6.3 Efficient Data Loading

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class PretrainingDataset(Dataset):
    def __init__(self, data_path, seq_length=2048):
        self.seq_length = seq_length
        # Memory-mapped file for efficient loading
        self.data = np.memmap(data_path, dtype=np.uint16, mode='r')

    def __len__(self):
        # Reserve one extra token so the last chunk can be shifted by one
        return (len(self.data) - 1) // self.seq_length

    def __getitem__(self, idx):
        start = idx * self.seq_length
        end = start + self.seq_length + 1
        # uint16 is not a standard torch dtype; cast to int64 for embedding lookup
        chunk = torch.from_numpy(self.data[start:end].astype(np.int64))
        # Input is chunk[:-1]; target is the same chunk shifted by one token
        return chunk[:-1], chunk[1:]

7. Hands-On: Pretraining a Small LLaMA from Scratch

7.1 Model Definition

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, x, seq_len):
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

class MiniLLaMA(nn.Module):
    def __init__(
        self,
        vocab_size=32000,
        dim=512,
        n_layers=8,
        n_heads=8,
        max_seq_len=2048,
        dropout=0.1
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_layers = n_layers

        # Token embedding
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

        # Transformer layers; TransformerBlock is assumed to be a pre-norm block
        # (in the spirit of GPTBlock from Section 3.1) whose attention applies the
        # RoPE cos/sin rotation to Q and K
        self.layers = nn.ModuleList([
            TransformerBlock(dim, n_heads, dropout)
            for _ in range(n_layers)
        ])

        # Output head; RMSNorm as defined in Section 3.2
        self.norm = RMSNorm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)

        # RoPE over the per-head dimension
        self.rope = RotaryPositionalEmbedding(dim // n_heads, max_seq_len)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens):
        batch_size, seq_len = tokens.shape

        # Get embeddings
        x = self.tok_embeddings(tokens)

        # Get RoPE angles for this sequence length
        cos, sin = self.rope(x, seq_len)

        # Create additive causal mask (-inf above the diagonal)
        mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=tokens.device),
            diagonal=1
        )

        # Apply transformer layers
        for layer in self.layers:
            x = layer(x, cos, sin, mask)

        # Output projection
        x = self.norm(x)
        logits = self.output(x)

        return logits

# Training script
def train_model():
    # Hyperparameters
    batch_size = 32
    seq_length = 512
    learning_rate = 3e-4
    num_epochs = 10

    # Initialize model
    model = MiniLLaMA(
        vocab_size=32000,
        dim=512,
        n_layers=8,
        n_heads=8
    ).cuda()

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.95),
        weight_decay=0.1
    )

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs * 1000,
        eta_min=1e-5
    )

    # Training loop; train_loader is a DataLoader over PretrainingDataset (Section 6.3)
    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.cuda(), targets.cuda()

            # Forward pass
            logits = model(inputs)
            # The dataset already shifts targets by one token, so compute the
            # loss directly; clm_loss (Section 4.1) would shift a second time
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

8. Training Monitoring and Tuning

8.1 Key Metrics to Monitor

  • Training loss: should decrease steadily
  • Learning rate: gradual decay after warmup
  • Gradient norm: large spikes signal instability
  • GPU memory: watch for OOM risk
from torch.utils.tensorboard import SummaryWriter

class TrainingMonitor:
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)
        self.global_step = 0

    def log_metrics(self, loss, lr, grad_norm, epoch):
        self.writer.add_scalar('Loss/train', loss, self.global_step)
        self.writer.add_scalar('LR', lr, self.global_step)
        self.writer.add_scalar('GradNorm', grad_norm, self.global_step)
        self.writer.add_scalar('Epoch', epoch, self.global_step)
        self.global_step += 1

8.2 Common Problems and Solutions

| Problem        | Symptom          | Solution                                        |
|----------------|------------------|-------------------------------------------------|
| NaN loss       | Training crashes | Lower the learning rate; inspect the data       |
| No convergence | Loss oscillates  | Tune the learning rate; lengthen warmup         |
| Out of memory  | OOM errors       | Shrink the batch or use gradient checkpointing  |
| Slow training  | Low throughput   | Profile the data-loading pipeline               |
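For the out-of-memory case, gradient (activation) checkpointing trades compute for memory by discarding intermediate activations and recomputing them during the backward pass; a minimal sketch with torch.utils.checkpoint (the toy layers stand in for real transformer blocks):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layers = nn.ModuleList([nn.Linear(512, 512) for _ in range(4)])

def forward_with_checkpointing(x):
    for layer in layers:
        # Activations inside `layer` are not stored; they are recomputed on backward
        x = checkpoint(layer, x, use_reentrant=False)
    return x

x = torch.randn(8, 512, requires_grad=True)
out = forward_with_checkpointing(x)
out.sum().backward()   # gradients still flow despite the discarded activations
```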

9. Summary and Outlook

9.1 Key Takeaways

  1. Architecture choice: modern LLMs favor decoder-only architectures with RoPE
  2. Distributed training: combine data, tensor, and pipeline parallelism
  3. Data processing: high-quality data is the key to successful pretraining
  4. Engineering optimizations: mixed precision, gradient checkpointing, efficient data loading
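The mixed-precision point in the list above can be sketched with PyTorch AMP. This is a device-agnostic sketch (it falls back to bfloat16 on CPU, where the GradScaler is disabled); the tiny linear model stands in for a real transformer:

```python
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(512, 512).to(device)     # stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# GradScaler is a no-op when enabled=False (e.g. on CPU)
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

x = torch.randn(8, 512, device=device)
amp_dtype = torch.float16 if device == 'cuda' else torch.bfloat16
with torch.autocast(device_type=device, dtype=amp_dtype):
    loss = model(x).pow(2).mean()          # forward runs in reduced precision

scaler.scale(loss).backward()              # loss scaling avoids fp16 gradient underflow
scaler.step(optimizer)                     # unscales gradients, then steps
scaler.update()
optimizer.zero_grad()
```

Master weights stay in fp32; only the forward/backward math runs in reduced precision.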

9.2 Trends

  • Longer context: from 4K to 100K+ tokens
  • Multimodal fusion: unified modeling of text, images, and audio
  • Efficient fine-tuning: parameter-efficient methods such as LoRA and QLoRA
  • Model compression: quantization, pruning, and distillation

Large-model pretraining is a fast-moving field, and following the latest papers and open-source projects is the key to staying current. I hope this article offers practical guidance for building your own large model!