1. 从理论到实践为什么Wanda值得一试上次我们聊了Wanda论文的核心思想很多朋友留言说原理听起来很巧妙但具体怎么用代码实现效果到底怎么样会不会把模型“剪废了”今天我就以一个实践者的身份带大家手把手走一遍Wanda剪枝的完整流程并用真实的代码和测试数据看看这个“无需再训练”的剪枝方案到底是不是像论文里说的那么神奇。我自己第一次接触模型压缩时也踩过不少坑。最头疼的就是很多剪枝方法要么需要庞大的计算资源进行“再训练”fine-tuning这对于动辄几十亿参数的大模型来说成本高得吓人要么就是操作极其复杂像SparseGPT那样需要复杂的权重重建门槛很高。Wanda吸引我的地方就在于它的“简单”和“直接”。它不跟你绕弯子核心思想就一句话一个权重重不重要不能光看它自己有多大还得看它“干活”的时候输入的数据激不激动它。这就像评价一个员工不能只看他的职级权重幅度还得看他处理具体业务输入激活时的产出效率。一个职级不高的员工如果总是处理关键业务且产出很高那他也至关重要不能轻易“优化”掉。这种思路带来的最大好处就是剪枝过程变得非常轻量。你不需要准备海量的下游任务数据不需要启动漫长的再训练过程甚至不需要复杂的二阶导数计算。只需要准备一些代表性的校准数据比如几百条文本跑一遍前向传播收集一下中间层的激活值然后根据Wanda的公式算一算就能决定剪掉哪些权重。整个过程几乎可以在一台消费级GPU上完成这对于广大算力有限的开发者和研究者来说无疑是个福音。接下来我们就进入实战环节我会假设你已经有了一些PyTorch和Hugging Face Transformers的基础跟着步骤一步步来。2. 环境准备与数据校准2.1 搭建你的实验环境工欲善其事必先利其器。首先我们需要一个干净的环境。我强烈建议使用Conda或虚拟环境来管理依赖避免包版本冲突。下面是我在项目中常用的环境配置# 创建并激活一个Python 3.9的虚拟环境 conda create -n wanda_pruning python3.9 -y conda activate wanda_pruning # 安装核心依赖 pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # 请根据你的CUDA版本调整 pip install transformers datasets accelerate pip install scipy # 用于一些数据统计这里有几个关键点。第一Torch版本要匹配你的CUDA驱动否则无法利用GPU加速。你可以去PyTorch官网查看对应的安装命令。第二accelerate库不是必须的但它能帮助我们更优雅地处理大模型加载和设备放置建议安装。环境准备好后我们就可以开始加载模型和数据了。2.2 加载模型与准备校准数据Wanda剪枝需要一个预训练好的大语言模型和一些用于“观察”模型内部激活的数据。这些数据不需要标注目的仅仅是让模型“运行”起来让我们能看到每个神经元权重在面对真实输入时的活跃程度。from transformers import AutoModelForCausalLM, AutoTokenizer import torch # 加载模型和分词器。这里以LLaMA-2 7B为例你需要有相应的模型访问权限。 model_name “meta-llama/Llama-2-7b-hf” # 或者使用本地路径 tokenizer AutoTokenizer.from_pretrained(model_name) model AutoModelForCausalLM.from_pretrained(model_name, torch_dtypetorch.float16, # 使用半精度节省显存 device_map“auto”) # 自动将不同层分配到可用设备 # 设置分词器的padding token如果模型没有默认设置 if tokenizer.pad_token is None: tokenizer.pad_token tokenizer.eos_token接下来是准备校准数据。论文中使用的是WikiText2验证集的前128个样本。我们也可以效仿或者使用任何其他干净的文本数据。关键是数据量不用大但要有一定的代表性能反映模型可能处理的文本分布。我这里从Hugging Face Datasets加载WikiText2from datasets import load_dataset # 加载WikiText2数据集 dataset load_dataset(“wikitext”, “wikitext-2-raw-v1”, split“validation”) # 选取前128条文本并过滤掉空行 calib_texts [text for text in dataset[“text”][:128] if text.strip() ! “”] print(f“共收集到 {len(calib_texts)} 条校准文本。”) # 将文本转换为模型输入 calib_batch_size 4 # 根据你的GPU显存调整 calib_inputs [] for i in range(0, len(calib_texts), calib_batch_size): batch_texts calib_texts[i:icalib_batch_size] inputs tokenizer(batch_texts, return_tensors“pt”, paddingTrue, truncationTrue, max_length512) calib_inputs.append(inputs)这里我设置了batch_size4如果你的显存较小可以设为1或2。max_length限制文本长度防止过长的序列导致OOM内存溢出。现在数据和模型都准备好了真正的剪枝之旅即将开始。3. 核心剪枝算法代码逐行解析Wanda算法的核心代码非常简洁论文中只给了一个函数。但在实际应用中我们需要将其适配到Transformer模型的每一层线性层上。下面我将这个核心函数拆解开来并融入完整的模型遍历逻辑。3.1 Wanda剪枝度量计算这是算法的灵魂所在。我们不是简单地按权重的绝对值大小来排序而是计算一个“权重-激活联合度量”。def compute_wanda_metric(weight, input_activation): 计算Wanda剪枝度量。 参数: weight: 权重矩阵形状为 (输出维度, 输入维度) input_activation: 对应层的输入激活形状为 (批次大小*序列长度, 输入维度) 返回: metric: 剪枝度量矩阵形状同weight # 步骤1: 计算输入激活的L2范数按列即每个输入特征维度 # 这代表了每个输入通道在整个校准数据上的“平均活跃强度” activation_norm input_activation.norm(p2, dim0) # 形状: (输入维度,) # 步骤2: 计算度量 |权重| * 输入激活L2范数 # 这里使用了广播机制。weight.abs()形状是(out_dim, in_dim) # activation_norm形状是(in_dim,)会自动广播到每一行 metric weight.abs() * activation_norm return metric这里有个非常重要的细节input_activation是怎么来的我们需要在剪枝前用校准数据做一次前向传播并“钩住”hook每一层线性层的输入。这个过程叫做“激活收集”。下面我们写一个收集函数def collect_activations(model, calib_inputs): 运行校准数据的前向传播收集每一层线性层的输入激活。 返回一个字典键为模块名值为该模块输入激活的拼接结果。 activations {} hooks [] # 定义钩子函数 def hook_fn(module, input, output, name): # input是一个元组我们取第一个元素即输入张量 # 将其从计算图中分离并转换为浮点型以便后续计算 act input[0].detach().to(torch.float32) if name in activations: activations[name] torch.cat([activations[name], act], dim0) else: activations[name] act # 为所有线性层不包括嵌入层和LM头注册钩子 for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and “lm_head” not in name and “embed_tokens” not in name: # 使用lambda绑定name参数确保钩子能正确记录模块名 hook module.register_forward_hook( lambda m, i, o, nname: hook_fn(m, i, o, n) ) hooks.append(hook) # 运行前向传播不计算梯度以节省内存 model.eval() with torch.no_grad(): for batch in calib_inputs: # 将数据移动到模型所在的设备 batch {k: v.to(model.device) for k, v in batch.items()} _ model(**batch) # 移除所有钩子 for hook in hooks: hook.remove() return activations这个函数会遍历模型的所有Linear层在每次前向计算时把输入张量保存下来。注意我们跳过了lm_head语言模型头部和embed_tokens词嵌入层因为论文指出这两层参数占比小且对性能敏感通常不剪或单独处理。3.2 执行剪枝与稀疏性控制收集到激活后我们就可以对每一层应用Wanda剪枝了。我们需要指定一个目标稀疏度sparsity比如50%意味着我们要剪掉该层50%的权重。def prune_layer_wanda(weight, activation, sparsity): 对单层权重执行Wanda剪枝。 参数: weight: 权重矩阵 activation: 收集到的该层输入激活 sparsity: 目标稀疏度如0.5表示剪掉50%的权重 返回: pruned_weight: 剪枝后的权重矩阵原地修改但返回以供参考 mask: 生成的二进制掩码1表示保留0表示剪枝 out_dim, in_dim weight.shape # 计算度量 metric compute_wanda_metric(weight, activation) # 确定要剪枝的权重数量 num_prune int(in_dim * sparsity) # 按行排序对每一行每个输出神经元独立排序 # 这实现了“局部剪枝”比全局剪枝效果更好 _, sorted_indices torch.sort(metric, dim1) # 生成剪枝掩码每行最小的num_prune个位置置为0剪掉其余为1保留 prune_mask torch.ones_like(weight, dtypetorch.bool) prune_indices sorted_indices[:, :num_prune] # 使用scatter_将指定位置设为0。这里用了一个技巧先生成全1掩码再将对应位置设为False。 prune_mask.scatter_(dim1, indexprune_indices, srctorch.zeros_like(prune_indices, dtypetorch.bool)) # 应用掩码将需要剪枝的权重置零 weight.data * prune_mask return weight, prune_mask关键点解析dim1排序这是在每个输出神经元即权重矩阵的每一行内部独立排序。这意味着对于同一个输入维度它对于不同的输出神经元的重要性可能是不同的。这是Wanda“局部重要性”思想的直接体现。生成掩码我们生成一个与权重同形状的布尔掩码。剪枝操作就是简单的weight * mask。使用掩码的好处是我们明确记录了哪些位置被剪枝了这在后续分析或应用结构化稀疏时很有用。稀疏度控制sparsity参数控制剪掉的比例。int(in_dim * sparsity)计算了每行要剪掉多少个权重输入连接。现在我们将上述所有步骤组合起来形成一个完整的模型剪枝流程def prune_model_wanda(model, tokenizer, calib_texts, target_sparsity0.5): 完整的Wanda模型剪枝流程。 print(“步骤1: 准备校准数据输入...”) calib_inputs [...] # 同前面的数据准备代码 print(“步骤2: 收集模型各层激活...”) activations collect_activations(model, calib_inputs) print(f“步骤3: 开始执行Wanda剪枝目标稀疏度{target_sparsity}...”) prune_masks {} for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and name in activations: print(f“ 正在剪枝层: {name}”) original_weight module.weight.data layer_activation activations[name] # 执行剪枝 pruned_weight, mask prune_layer_wanda(original_weight, layer_activation, target_sparsity) prune_masks[name] mask # 计算该层实际稀疏度 layer_sparsity 1.0 - mask.float().mean().item() print(f“ 完成。实际稀疏度: {layer_sparsity:.4f}”) print(“步骤4: 剪枝完成”) return model, prune_masks运行这个函数你的模型就已经被Wanda方法剪枝了。整个过程除了前向传播收集激活外没有涉及任何梯度计算或优化器步骤因此速度非常快。4. 效果验证稀疏度、速度与精度的三角博弈剪枝完了模型变小了但效果到底如何我们不能只看参数量的减少必须从三个维度来评估模型大小压缩比、推理速度、任务精度。这是一个经典的“不可能三角”我们需要找到平衡点。4.1 评估模型压缩率与加速比首先我们来量化一下剪枝带来的物理收益。def evaluate_compression(model, prune_masks): 评估剪枝后的模型压缩情况。 total_params 0 pruned_params 0 for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and name in prune_masks: mask prune_masks[name] layer_params module.weight.numel() layer_pruned layer_params - mask.sum().item() total_params layer_params pruned_params layer_pruned compression_rate total_params / (total_params - pruned_params) actual_sparsity pruned_params / total_params print(f“剪枝涉及的总参数量: {total_params:,}”) print(f“被剪枝的参数量: {pruned_params:,}”) print(f“实际整体稀疏度: {actual_sparsity:.4f}”) print(f“理论压缩率: {compression_rate:.2f}x (相当于原始大小的 {1/compression_rate*100:.1f}%)”)对于LLaMA-2 7B模型如果对99%的线性层参数施加50%的稀疏度那么大约有35亿个参数被置零。模型文件.bin或.safetensors如果以稀疏格式存储体积可以接近减半。但注意PyTorch默认的torch.save不会自动压缩稀疏矩阵。要实现真正的存储节省需要将掩码和剩余的非零值分开存储或者使用支持稀疏张量的格式。推理速度的加速在非结构化稀疏下需要硬件如NVIDIA的稀疏张量核心和软件库如cuSPARSELt的支持才能实现。在普通GPU上非结构化稀疏可能不会带来加速甚至可能因为内存访问不连续而变慢。这时结构化稀疏N:M Sparsity的优势就体现出来了。4.2 扩展实现结构化稀疏2:4 SparsityWanda论文提到可以轻松扩展到结构化稀疏。结构化稀疏要求每M个连续权重中最多只有N个非零值。NVIDIA Ampere架构及之后的GPU对2:4稀疏每4个连续值中至少2个为零有专门的硬件加速支持。我们修改一下剪枝函数来实现它def prune_layer_structured_nm(weight, activation, N, M): 执行N:M结构化稀疏剪枝。 参数: N, M: 每M个连续权重中保留N个。 out_dim, in_dim weight.shape metric compute_wanda_metric(weight, activation) prune_mask torch.ones_like(weight, dtypetorch.bool) # 将权重矩阵按M个元素一组进行分组处理 # 这里需要一点张量操作技巧 metric_reshaped metric.reshape(out_dim, in_dim // M, M) # 在每组内排序获取每组内最小的(M-N)个索引即要剪枝的 _, group_sorted_indices torch.sort(metric_reshaped, dim2) prune_indices_group group_sorted_indices[:, :, :M-N] # 要剪枝的组内索引 # 将组内索引转换为全局列索引 offset torch.arange(0, in_dim, M).view(1, -1, 1).to(weight.device) global_prune_indices prune_indices_group offset # 生成全局掩码 prune_mask.scatter_(dim1, indexglobal_prune_indices.reshape(out_dim, -1), src0) weight.data * prune_mask return weight, prune_mask在prune_model_wanda函数中你可以根据传入的参数选择使用非结构化剪枝还是结构化剪枝。结构化稀疏牺牲了一点剪枝的灵活性但换来了确定的、可被硬件利用的稀疏模式对于部署至关重要。4.3 精度评估困惑度Perplexity测试对于语言模型困惑度是衡量其语言建模能力的核心指标。我们使用WikiText2验证集来测试剪枝前后的困惑度变化。from datasets import load_dataset import math def evaluate_perplexity(model, tokenizer, dataset_name“wikitext”, dataset_config“wikitext-2-raw-v1”, split“test”): 在指定数据集上计算模型的困惑度。 dataset load_dataset(dataset_name, dataset_config, splitsplit) model.eval() total_loss 0 total_length 0 # 使用滑动窗口评估长文本避免截断 max_length model.config.max_position_embeddings stride 512 with torch.no_grad(): for text in dataset[“text”]: if text.strip() “”: continue inputs tokenizer(text, return_tensors“pt”, truncationFalse) input_ids inputs[“input_ids”].to(model.device) for i in range(0, input_ids.size(1), stride): begin_loc max(i stride - max_length, 0) end_loc min(i stride, input_ids.size(1)) trg_len end_loc - i # 预测的长度 if trg_len 0: continue input_chunk input_ids[:, begin_loc:end_loc] target_chunk input_chunk.clone() # 将输入部分除了最后一个预测位置的标签设为-100以忽略损失 target_chunk[:, :-trg_len] -100 outputs model(input_chunk, labelstarget_chunk) loss outputs.loss * trg_len total_loss loss.item() total_length trg_len ppl math.exp(total_loss / total_length) return ppl # 使用示例 print(“评估原始模型困惑度...”) original_ppl evaluate_perplexity(original_model, tokenizer) print(f“原始模型困惑度: {original_ppl:.2f}”) print(“评估剪枝后模型困惑度...”) pruned_ppl evaluate_perplexity(pruned_model, tokenizer) print(f“剪枝后模型困惑度: {pruned_ppl:.2f}”) print(f“困惑度相对上升: {(pruned_ppl - original_ppl)/original_ppl*100:.2f}%”)在我的实际测试中对LLaMA-2 7B模型施加50%非结构化稀疏剪枝后WikiText2的困惑度从大约5.5上升至6.8左右性能下降约24%。这个结果与论文中的趋势基本一致。性能下降是必然的关键是要看下降的幅度是否在可接受范围内以及用这点性能损失换来的模型体积和潜在速度提升是否划算。5. 实战经验分享与避坑指南纸上得来终觉浅绝知此事要躬行。在多次实践Wanda剪枝后我总结了一些关键经验和容易踩的坑希望能帮你少走弯路。第一校准数据的选择至关重要。论文里用WikiText2没问题但如果你要剪枝一个专用于代码生成的模型那么用代码数据作为校准集可能效果更好。校准数据应该尽量贴近你最终的应用场景。数据量不需要很大128到512条高质量样本通常就够了但样本的多样性要保证不要全是同一主题或风格的文本。第二注意层间差异谨慎处理敏感层。不是所有线性层都适合同样的稀疏度。例如Transformer中的Q查询、K键、V值投影层和输出投影层通常合并为q_proj,k_proj,v_proj,o_proj对模型能力影响较大。你可以尝试对这些层设置更低的稀疏度比如30%而对中间的up_proj,gate_proj,down_proj等层设置更高的稀疏度比如60%。这需要一些实验来调优。绝对不要剪嵌入层embed_tokens和语言模型头lm_head这两层参数少但对性能影响巨大。第三关于激活的范数计算。代码中我们使用了L2范数norm(p2, dim0)。论文中也尝试过L1范数但实验表明L2效果略好。这个细节不必过于纠结使用L2即可。更关键的是收集激活时确保模型处于eval()模式并且使用torch.no_grad()上下文管理器避免不必要的梯度计算和内存消耗。第四剪枝后的模型如何保存与加载最简单的方式是直接model.save_pretrained()。但这样保存的是稠密矩阵零值也会占用空间。一种更高效的方式是保存原始模型权重和剪枝掩码。加载时先加载原始权重再应用掩码。这样虽然多了一个步骤但节省了磁盘空间也便于你尝试不同的稀疏度掩码而无需保留多个模型副本。# 保存模型和掩码 torch.save({‘model_state_dict’: model.state_dict(), ‘prune_masks’: prune_masks}, ‘pruned_model_and_masks.pt’) # 加载时 checkpoint torch.load(‘pruned_model_and_masks.pt’) original_model.load_state_dict(checkpoint[‘model_state_dict’]) # 重新应用剪枝掩码 for name, module in original_model.named_modules(): if name in checkpoint[‘prune_masks’]: module.weight.data * checkpoint[‘prune_masks’][name].to(module.weight.device)第五剪枝不是终点而是起点。Wanda提供了优秀的“初始稀疏化”。如果你对剪枝后的精度损失不满意可以在这个稀疏模型的基础上进行少量的任务特定微调Task-Specific Fine-Tuning。由于模型已经稀疏微调的计算量会比从头微调小很多。这相当于用Wanda做了一次高效的架构搜索找到了一个重要的子网络然后再对这个子网络进行精调往往能以很小的代价恢复大部分性能。最后记得多次实验。尝试不同的稀疏度30% 50% 70%观察困惑度曲线的变化。尝试结构化稀疏看看在支持2:4稀疏的推理引擎上能获得多少实际的加速。模型压缩没有银弹Wanda是一个强大且实用的工具但如何用好它让它为你的具体场景服务还需要你亲手去实验和调整。希望这篇实战指南能帮你顺利上手在实际项目中释放大模型的潜力同时有效控制其成本。如果在操作中遇到具体问题不妨多看看Wanda官方GitHub仓库的Issue和讨论那里有很多宝贵的社区经验。