网站首页 > 技术文章 正文
【AI 和机器学习】PyTorch BASIC 基础知识:节2
—— 数据集和数据加载器(Datasets & DataLoaders)
前言
—— 哪个更适合初学者?
想要学习并掌握 AI,最直接的办法就是自己动手进行实操。有一些流行的来源可供练习 AI 技能,例如:
- Kaggle:一个托管各种数据集和机器学习竞赛的平台。
- UCI 机器学习存储库:用于机器学习研究的数据集集合。
- TensorFlow 教程:TensorFlow 团队提供的教程和示例。
- PyTorch 教程:PyTorch 团队提供的教程和示例。
其中PyTorch和TensorFlow的AI教程资源非常丰富。但对于初学者来说,哪个更合适,可能还得取决于您的特定目标(研究与生产)以及您的偏好等:
- PyTorch 因其简单、易读和易于调试而通常被认为更适合初学者。PyTorch 的动态特性使新手可以学习概念而不会被复杂的语法所困扰。
- TensorFlow 随着 TensorFlow 2.x 和 Keras 的推出变得更加适合初学者,但它仍可能对初学者构成挑战。
本文先选择PyTorch来和大家一起学习,学习它的一些基础内容。其中所有素材均取自其教程。对于每一节内容,我们都将先给出摘要,然后把译文稍作整理后附在后面,供参考。
目录
【续前文】
本节摘要
本节讨论了 PyTorch 中数据集和数据加载器的使用,强调了将数据集处理与模型训练分离,以便提高可读性和模块化的重要性。PyTorch 提供了两个关键组件:“torch.utils.data.Dataset”用于存储样本及其标签,以及“torch.utils.data.DataLoader”用于方便访问这些样本。
本节解释了如何使用 TorchVision 加载 Fashion-MNIST 数据集(包含 60000 张训练图像和 10000 张测试灰度图像的集合),详细说明了 “root”、“train” 和transformations等参数。本文包含加载数据集和使用 Matplotlib 可视化样本的代码示例。
此外,还介绍了通过实现“__init__”、“__len__”和“__getitem__”方法来创建自定义数据集类,允许用户从指定目录加载图像,并从 CSV 文件加载其标签。
最后,本节重点介绍了 DataLoader 在模型训练期间的批处理、随机排序和加速数据检索功能,从而可以在机器学习流中更加轻松地有效地管理数据集。
本节正文
用于处理数据样本的代码,可能会变得混乱而且难以维护;理想情况下,我们希望数据集代码与我们的模型训练代码解耦,以此提高可读性和模块化。PyTorch 提供了两个数据原语: torch.utils.data.DataLoader 和 torch.utils.data.Dataset,它们允许您使用预加载的数据集以及您自己的数据。Dataset 存储样本及其相应的标签,DataLoader 在 Dataset 周围包装了一个可迭代对象,以便于访问样本。
PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集子类化 torch.utils.data.Dataset 并实现特定于特定数据的功能。它们可以用于对模型进行原型设计和基准测试。您可以在此处找到它们:图像数据集、文本数据集和音频数据集
加载数据集
以下是如何从 TorchVision 加载 Fashion-MNIST 数据集的示例。Fashion-MNIST 是 Zalando 的文章图像数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含一个 28×28 灰度图像和一个来自 10 个类之一的关联标签。
我们使用以下参数加载 FashionMNIST 数据集:
- root 存储训练/测试数据的路径
- train 指定训练或测试数据集
- download=True 如果数据在 root 中不可用,则从 Internet 下载数据。
- transform 和 target_transform 指定特征和标签转换
输出:
迭代和可视化数据集
我们可以像列表一样手动索引 DataDataset:training_data[index]。我们使用 matplotlib 来可视化训练数据中的一些样本。
(踝靴, 衬衫, 包, 踝靴, 长裤, 凉鞋, 外套, 凉鞋, 套头衫)
为您的文件创建自定义数据集
自定义 Dataset 类必须实现三个函数:__init__、__len__ 和 __getitem__。来看一下这个实现;FashionMNIST 图像存储在目录img_dir中,其标签单独存储在 CSV 文件annotations_file中。
在接下来的部分中,我们将分解每个函数中发生的情况。
__init__
__init__ 函数在实例化 Dataset 对象时运行一次。我们初始化包含图像、annotations 文件和两个transform的目录(下一节将更详细地介绍)。
labels.csv 文件如下所示:
__len__
__len__ 函数返回数据集中的样本数。
例:
__getitem__
__getitem__ 函数加载并返回给定索引 idx 处的数据集中的样本。根据索引,它识别图像在磁盘上的位置,使用 read_image 将其转换为张量,从 self.img_labels 中的 csv 数据中检索相应的标签,对其调用transform函数(如果适用),并在元组中返回张量图像和相应的标签。
准备用于DataLoader 进行训练的数据
Dataset 检索我们数据集的特征,并一次标记一个样本。在训练模型时,我们通常希望以 “小批量” 的形式传递样本,在每个 epoch 重新洗牌数据以减少模型过度拟合,并使用 Python 的multiprocessing来加快数据检索速度。
DataLoader 是一个可迭代对象,它通过一个简单的 API 为我们抽象了这种复杂性。
通过DataLoader 迭代
我们已将该数据集加载到 DataLoader 中,并可以根据需要迭代数据集。下面的每次迭代都会返回一批 train_features 和 train_labels(分别包含 batch_size=64 个特征和标签)。因为我们指定了 shuffle=True,所以在我们迭代所有批次后,数据会被随机排序(要对数据加载顺序进行更精细的控制,请查看 Samplers)。
输出:
延伸阅读:torch.utils.data API 接口
【未完待续】
农历甲辰十月廿五
2024.11.25
【部分图片来源网络,侵删】
- 上一篇: 一次性把JVM讲清楚,别再被面试官问住了
- 下一篇: 使用Python实现水质预测
猜你喜欢
- 2024-12-08 使用Python实现智能医疗影像分析与诊断
- 2024-12-08 Python中的可视化:使用Seaborn绘制常用图表
- 2024-12-08 GPT-4 + Canvas + o1 preview — 数据分析的量子飞跃!
- 2024-12-08 「jupyter」Google Colab使用外部数据的几种方法
- 2024-12-08 论文观点竟没数据支持?别犯傻!用ChatGPT分析论文数据
- 2024-12-08 谷歌狂卷小模型,20亿参数Gemma 2赶超GPT-3.5
- 2024-12-08 论文观点没数据?ChatGPT帮你整理论文数据,快速提高论文质量!
- 2024-12-08 使用Python实现水质预测
- 最近发表
- 标签列表
-
- cmd/c (90)
- c++中::是什么意思 (84)
- 标签用于 (71)
- 主键只能有一个吗 (77)
- c#console.writeline不显示 (95)
- pythoncase语句 (88)
- es6includes (74)
- sqlset (76)
- apt-getinstall-y (100)
- node_modules怎么生成 (87)
- chromepost (71)
- flexdirection (73)
- c++int转char (80)
- mysqlany_value (79)
- static函数和普通函数 (84)
- el-date-picker开始日期早于结束日期 (76)
- js判断是否是json字符串 (75)
- c语言min函数头文件 (77)
- asynccallback (87)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- java (73)
- js数组插入 (83)
- mac安装java (72)
- 无效的列索引 (74)