当前位置: 首页 > news >正文

Softmax回归模型

这段代码是一个完整的 Softmax回归模型 实现,用于解决 Fashion-MNIST数据集的图像分类问题。简单来说,它的作用是:让计算机通过学习大量衣服、鞋子等服饰图片,学会识别新的服饰图片属于哪一类(比如T恤、裤子、运动鞋等,共10个类别)。

具体来说,代码做了这几件事:

  1. 准备数据:加载Fashion-MNIST数据集(包含6万张训练图和1万张测试图),并把图片转换成模型能处理的格式(张量),再分成批量供模型训练。
  2. 定义模型:用Softmax回归模型(一种简单的神经网络),先把28×28的图片“摊平”成784个数字,再通过一个全连接层输出10个结果(对应10个类别的“分数”)。
  3. 设置训练工具:用“交叉熵损失函数”衡量模型预测的错误程度,用“随机梯度下降(SGD)”优化器调整模型参数,让模型慢慢学会正确分类。
  4. 训练模型:重复10轮训练,每轮都用训练集数据让模型“学习”:
    • 先让模型预测图片类别,计算错误(损失);
    • 再根据错误调整模型参数;
    • 最后用测试集检验模型的分类能力(准确率)。
  5. 输出结果:每轮训练后打印损失值和准确率,最终模型在测试集上的准确率约85%,说明它能较好地识别服饰类别。
    简单讲,这是一个“教计算机认衣服”的程序,用的是最基础的深度学习方法(Softmax回归),适合入门理解神经网络的训练过程。

1.导入需要的工具库

点击查看代码
import torch  # PyTorch核心库:处理张量(类似数组但能在GPU上运行)和深度学习基础功能
from torch import nn  # 神经网络模块:提供各种层(如全连接层)、损失函数等
from torch.utils.data import DataLoader  # 数据加载器:帮我们批量处理数据,方便训练
from torchvision import datasets, transforms  # 计算机视觉工具:提供现成数据集和数据转换工具
from d2l import torch as d2l  # D2L工具库:提供一些辅助函数(如计算准确率的工具)
2.数据加载与预处理
点击查看代码
# 数据转换规则:把图像转成Tensor格式(PyTorch能处理的格式),同时自动把像素值从0-255变成0-1
transform = transforms.Compose([transforms.ToTensor()])
batch_size = 256  # 每次训练时一次性喂给模型256张图片(批量大小)# 加载训练数据集(Fashion-MNIST,包含6万张衣服、鞋子等图片)
train_dataset = datasets.FashionMNIST(root='./data',        # 数据存在当前文件夹的data文件夹里train=True,           # 这是训练集(用来训练模型的)download=True,        # 如果本地没有数据,就自动下载(约30MB)transform=transform   # 用上面定义的规则处理图片
)# 加载测试数据集(1万张图片,用来检验模型好坏)
test_dataset = datasets.FashionMNIST(root='./data',train=False,          # 这是测试集(不参与训练,只用来评估)download=True,transform=transform
)# 创建训练数据加载器(把数据分成批次,打乱顺序)
train_iter = DataLoader(train_dataset,batch_size=batch_size,  # 每次256张shuffle=True,           # 训练时打乱顺序,让模型学的更全面num_workers=0           # 单进程加载(Windows系统用多进程会出错,所以设为0)
)# 创建测试数据加载器(不需要打乱顺序)
test_iter = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,          # 测试时不需要打乱,按顺序来就行num_workers=0
)
3.模型定义与参数初始化
点击查看代码
# 定义模型:用Sequential把层按顺序拼起来
net = nn.Sequential(nn.Flatten(),          # 展平层:把28×28的图片(二维)变成784个数字的一维数组nn.Linear(784, 10)     # 全连接层:把784个输入转换成10个输出(对应10个类别)
)# 定义参数初始化函数(给模型的"权重"赋值初始值)
def init_weights(m):# 如果是全连接层(nn.Linear),就初始化它的权重if type(m) == nn.Linear:# 权重用均值0、标准差0.01的正态分布随机数初始化nn.init.normal_(m.weight, std=0.01)# 偏置(类似y=ax+b里的b)默认初始化为0,不用额外设置# 把上面的初始化规则应用到模型的所有层
net.apply(init_weights)
4.损失函数与优化器
点击查看代码
# 定义损失函数:CrossEntropyLoss(自带Softmax功能)
# 作用:比较模型预测结果和真实标签的差距,输出"损失值"
# reduction='none':返回每个样本的损失,不自动求平均
loss = nn.CrossEntropyLoss(reduction='none')# 定义优化器:SGD(随机梯度下降,最常用的优化方法)
# 作用:根据损失调整模型的参数(权重和偏置)
trainer = torch.optim.SGD(net.parameters(),  # 需要优化的参数(模型里的权重和偏置)lr=0.1             # 学习率:控制参数调整的幅度(0.1是比较经典的值)
)
5. 模型评估函数(计算准确率)
点击查看代码
def evaluate_accuracy(net, data_iter):# 如果模型是PyTorch的标准模型,就切换到"评估模式"# (有些层如Dropout在训练和测试时行为不同,这里确保是测试模式)if isinstance(net, torch.nn.Module):net.eval()  # 切换到评估模式# 累加器:记录2个数据——正确预测的数量、总样本数量metric = d2l.Accumulator(2)# 关闭梯度计算(评估时不需要训练,节省内存)with torch.no_grad():# 遍历数据集中的每一批数据for X, y in data_iter:# 计算当前批次中正确预测的数量,累加到metric# d2l.accuracy(net(X), y):模型预测结果和真实标签对比,返回正确数# y.numel():当前批次的总样本数(比如256)metric.add(d2l.accuracy(net(X), y), y.numel())# 准确率 = 正确预测数 / 总样本数return metric[0] / metric[2]
6.模型训练循环
点击查看代码
num_epochs = 10  # 训练10轮(把整个训练集重复用10次来训练)# 遍历每一轮训练
for epoch in range(num_epochs):# 累加器:记录3个数据——总损失、正确预测数、总样本数metric = d2l.Accumulator(3)net.train()  # 切换到训练模式(和评估模式对应,规范写法)# 遍历训练集中的每一批数据for X, y in train_iter:trainer.zero_grad()  # 梯度清零:每次计算前把上一轮的梯度清空y_hat = net(X)       # 前向传播:把输入X喂给模型,得到预测结果y_hatl = loss(y_hat, y)   # 计算损失:比较预测y_hat和真实标签y的差距l.mean().backward()  # 反向传播:计算每个参数的梯度(损失对参数的影响)trainer.step()       # 优化器更新参数:根据梯度调整权重和偏置# 关闭梯度计算,累加当前批次的指标with torch.no_grad():metric.add(l.sum(),          # 累加当前批次的总损失d2l.accuracy(y_hat, y),  # 累加当前批次的正确预测数y.numel()         # 累加当前批次的总样本数)# 计算当前轮次的关键指标train_loss = metric[0] / metric[2]  # 平均训练损失 = 总损失 / 总样本数train_acc = metric[1] / metric[2]   # 训练集准确率 = 正确数 / 总样本数test_acc = evaluate_accuracy(net, test_iter)  # 测试集准确率# 打印训练日志(保留3位小数,方便观察)print(f"epoch {epoch + 1:2d} | loss: {train_loss:.3f} | train_acc: {train_acc:.3f} | test_acc: {test_acc:.3f}")
7.预期输出
点击查看代码
epoch  1 | loss: 0.785 | train_acc: 0.746 | test_acc: 0.796
...
epoch 10 | loss: 0.447 | train_acc: 0.849 | test_acc: 0.851
http://icebutterfly214.com/news/2429/

相关文章:

  • 2025年上海衣帽间定制机构权威推荐榜单:衣帽间设计/衣帽间十大品牌/衣帽间装修源头公司精选
  • 2025年市场上小程序开发公司Top10权威推荐
  • JVM内存启动问题
  • 报纸阅读神器:支持多日期多版面自由切换,本地保存更方便
  • 2025 年不锈钢无缝管源头厂家最新推荐榜:重质守信企业盘点,覆盖多材质多行业适配与高性价比选购参考
  • windows下安装Nginx,并配置成服务
  • 什么是MII
  • 2025年广东会议室话筒设备服务商权威推荐榜单:红外线会议话筒/会议麦克风扬声器/会议麦克风音响源头公司精选
  • 2025年口碑好的铜芯电缆公司排行榜:鑫佰亿线缆领跑行业
  • 2025/10/27
  • 从案例看网站建设价值:卓越迈创如何适配企业多元化需求,深圳外贸网站建设公司,深圳企业网站建设公司推荐
  • 2025年阻燃输送带生产厂家权威推荐榜单:尼龙输送带/三叶输送带/输送带源头厂家精选
  • Transformers
  • 如何使用 vxe-table 展开行实现展开子表父子表格
  • 2025年成品岗亭供货厂家权威推荐榜单:成品门卫亭/小区保安亭/执法岗亭源头厂家精选
  • 2025年超导电缆制造厂权威推荐榜单:铜线电缆/感温电缆/国标电缆源头厂家精选
  • 2025 年铝塑板厂家最新推荐榜,从技术研发到市场服务多维度考量,企业综合实力与产品竞争力深度剖析网纹/磨砂/大理石/木纹/幻彩铝塑板公司推荐
  • ArkTS语言(三)
  • Legacy模式虚拟机,grub文件丢失如何处理
  • 魔域电脑版下载安装教程:重返经典魔幻世界的完整攻略(含新手入门+登录异常修复)
  • 什么是跨网域资源共享(CROS)
  • 2025最新AI智能体学习路线图
  • kettle调度系统-kettle嵌入集成方式调度,稳如磐石,一分钟完成任务调度配置
  • 用Go语言从零开始开发一个Prometheus Exporter
  • Linux服务器感染病毒,如何处理?
  • 2025年贵州推拿正骨培训机构权威推荐榜单:小儿按摩培训/小儿推拿培训/穴位敷贴培训源头机构精选
  • 2025年装饰装修公司推荐
  • 2025年数控对头钻批发厂家权威推荐:数控龙门镗铣床/数控双面镗/数控双面镗铣床源头厂家精选
  • 2025年10月兰花油品牌综合评估榜:市场主流产品真实效果横向测评
  • 2025年河北注册公司系统权威推荐:衡水公司注册公司/河北企业注册优化/衡水公司注销方案服务平台精选