# Stop Using MNIST! Get Started with PyTorch Text Classification on the AG_NEWS Dataset (Full Code Included)

张开发
2026/4/15 18:28:07 · 15 min read


While most deep-learning tutorials still use MNIST handwritten-digit recognition as their "Hello World" example, it is worth asking: it is 2023, so why not start learning with data that is closer to real scenarios? Text classification is a foundational task in natural language processing, with wide applications in news categorization, sentiment analysis, and spam filtering. The AG_NEWS news dataset is an ideal stepping stone for moving from image processing to text processing.

## 1. Why AG_NEWS Rather Than MNIST

As a classic dataset, MNIST certainly has teaching value, but it also has clear limitations:

- **Too simple**: 28x28 grayscale images, 10 classes; modern models easily reach 99% accuracy.
- **Detached from reality**: handwritten-digit recognition is rarely needed in real applications any more.
- **Hard to transfer**: the tricks you learn are difficult to apply to other computer-vision tasks.

By contrast, AG_NEWS offers the following advantages:

| Feature | AG_NEWS | MNIST |
| --- | --- | --- |
| Data type | Real news text | Handwritten digit images |
| Classes | 4 (World, Sports, Business, Sci/Tech) | 10 (digits 0-9) |
| Sample size | 120k train / 7.6k test | 60k train / 10k test |
| Use cases | News classification, topic detection | Handwritten-digit recognition |
| Learning value | Text preprocessing, embedding representations | Basic CV pipeline |

> **Tip**: AG_NEWS contains articles drawn from more than 2,000 news sources and is one of the gold-standard datasets for text-classification research.

## 2. Quickly Building an AG_NEWS Text-Classification Pipeline

### 2.1 Loading and Exploring the Data

First make sure the required libraries are installed:

```bash
pip install torch torchtext
```

Load the dataset and look at a few samples:

```python
from torchtext.datasets import AG_NEWS

train_iter, test_iter = AG_NEWS(split=('train', 'test'))
class_names = ['World', 'Sports', 'Business', 'Sci/Tech']

# Look at the first 5 samples (labels run from 1 to 4)
for i, (label, text) in zip(range(5), train_iter):
    print(f"{class_names[label - 1]}: {text[:100]}...")
```

Typical output:

```text
Business: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling...
Sports: Tigers Win World Series The Detroit Tigers won the World Series for the first time in 72 years...
Sci/Tech: Microsoft Releases New Zune Player Microsoft has released a new version of its Zune media player...
```

### 2.2 The Text-Preprocessing Pipeline

The key step in text classification is building an effective feature representation:

1. **Tokenization**: split sentences into words or subword units.
2. **Vocabulary building**: create a word-to-index mapping.
3. **Numericalization**: convert text into sequences of integers.
4. **Batching**: unify sequence lengths so the model can process them together.

Example implementation:

```python
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1  # shift labels 1-4 down to 0-3
```

### 2.3 Building the Data Loader

For efficient training, we need to create a DataLoader:

```python
import torch
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
    # Pad every sequence in the batch to the length of the longest one
    return torch.tensor(label_list), torch.nn.utils.rnn.pad_sequence(
        text_list, batch_first=True,
        padding_value=0)  # note: index 0 is '<unk>' in this vocab

train_loader = DataLoader(list(train_iter), batch_size=8,
                          shuffle=True, collate_fn=collate_batch)
```
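The four preprocessing steps above can also be sketched without torchtext. The toy version below (all names such as `toy_vocab` and `pad_batch` are my own, not part of the article's code) mimics what `build_vocab_from_iterator` and `pad_sequence` do under the hood, under the assumption of whitespace tokenization:

```python
# A dependency-free sketch of the preprocessing pipeline.
# All names here are illustrative, not from the article's code.

def tokenize(text):
    # Rough stand-in for torchtext's 'basic_english' tokenizer
    return text.lower().split()

corpus = ["Tigers win the series", "Markets fall again"]

# 1. Build a vocabulary, reserving index 0 for unknown tokens
toy_vocab = {'<unk>': 0}
for sentence in corpus:
    for tok in tokenize(sentence):
        toy_vocab.setdefault(tok, len(toy_vocab))

# 2. Numericalize: map tokens to indices, unknowns to 0
def numericalize(text):
    return [toy_vocab.get(tok, 0) for tok in tokenize(text)]

# 3. Pad a batch to equal length, like pad_sequence(batch_first=True)
def pad_batch(seqs, pad_value=0):
    width = max(len(s) for s in seqs)
    return [s + [pad_value] * (width - len(s)) for s in seqs]

batch = pad_batch([numericalize(s) for s in corpus])
print(batch)  # [[1, 2, 3, 4], [5, 6, 7, 0]]
```

The second sequence is one token shorter, so it is padded with a trailing 0; this is exactly why `collate_batch` above can stack variable-length news articles into one tensor.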
## 3. From a Baseline Model to Practical Techniques

### 3.1 A Basic Text-Classification Model

A simple embedding + fully connected architecture:

```python
import torch.nn as nn

class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, text):
        embedded = self.embedding(text)  # [batch, seq_len, embed_dim]
        pooled = embedded.mean(dim=1)    # average pooling over the sequence
        return self.fc(pooled)
```

Setting up the training loop:

```python
def train(model, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        for labels, texts in train_loader:
            optimizer.zero_grad()
            outputs = model(texts)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
```

### 3.2 Practical Tricks for Better Performance

**Trick 1: tune the embedding dimension**

```python
# A smaller embedding dimension suits a simple task
model = TextClassifier(len(vocab), embed_dim=64, num_class=4)
```

**Trick 2: try a different pooling method**

```python
# Replace mean pooling with max pooling
pooled = embedded.max(dim=1)[0]
```

**Trick 3: add regularization**

```python
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
```

## 4. Evaluation and Result Analysis

Evaluate the model on the test set:

```python
def evaluate(model):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for labels, texts in DataLoader(list(test_iter), batch_size=8,
                                        collate_fn=collate_batch):
            outputs = model(texts)
            predicted = outputs.argmax(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

accuracy = evaluate(model)
print(f"Test Accuracy: {accuracy:.2%}")
```

Typical results:

| Model config | Test accuracy | Training time / epoch |
| --- | --- | --- |
| embed_dim=32 | 75-78% | ~1 min |
| embed_dim=64 | 78-82% | ~1.5 min |
| embed_dim=128 | 80-83% | ~2 min |

A common problem in real projects is class imbalance. Check the class distribution of AG_NEWS:

```python
from collections import Counter

label_counts = Counter()
for label, _ in train_iter:
    label_counts[label] += 1
print(label_counts)
```

Typical output:

```text
Counter({3: 30000, 4: 30000, 2: 30000, 1: 30000})
```

All four classes contain exactly 30,000 training samples, so AG_NEWS itself is perfectly balanced and plain accuracy is a fair metric here.
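The accuracy computed in `evaluate` is just argmax-match counting. As a sanity check, here is the same logic in plain Python on hypothetical logits (the numbers below are invented for illustration, not real model outputs):

```python
def accuracy(logits, labels):
    # For each row, predict the class with the highest score,
    # then count how many predictions match the labels.
    preds = [row.index(max(row)) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)

# Hypothetical 4-class logits for three samples
logits = [
    [0.1, 0.2, 0.9, 0.0],  # predicted class 2
    [2.0, 0.1, 0.3, 0.1],  # predicted class 0
    [0.0, 0.5, 0.1, 0.4],  # predicted class 1
]
labels = [2, 0, 3]  # the last prediction is wrong

print(accuracy(logits, labels))  # 2 of 3 correct -> 0.666...
```

This mirrors `outputs.argmax(1)` followed by `(predicted == labels).sum()` in the torch version, just without tensors.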

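As a final aside on trick 2 from section 3.2: the difference between mean and max pooling is easy to see on a tiny hand-written "embedding" matrix. This pure-Python sketch (values invented for illustration) computes what `embedded.mean(dim=1)` and `embedded.max(dim=1)[0]` would return for one sequence:

```python
# One sequence of 3 token embeddings, each of dimension 2 (made-up values)
embedded = [[1.0, 4.0],
            [3.0, 0.0],
            [2.0, 2.0]]

# Mean pooling over the sequence dimension: average each embedding column
mean_pooled = [sum(col) / len(col) for col in zip(*embedded)]

# Max pooling over the sequence dimension: take each column's maximum
max_pooled = [max(col) for col in zip(*embedded)]

print(mean_pooled)  # [2.0, 2.0]
print(max_pooled)   # [3.0, 4.0]
```

Mean pooling blends every token equally, while max pooling keeps only the strongest activation per dimension, which is why the two variants can behave differently on long, padded news articles.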