利用torch.utils.data.Dataset自定义数据加载类
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)

import torch as t from torch.utils import data import os from PIL import Image import numpy as np import torchvision.transforms as T transforms = T.Compose([ T.Resize(224), T.CenterCrop(224), T.ToTensor(), T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) ]) # 继承Dataset类要重写__getitem__ 在深度学习领域,高效的数据加载和预处理是训练模型的关键环节。`torch.utils.data.Dataset` 是 PyTorch 提供的一个接口,允许用户自定义数据集类来加载和处理自己的数据。这个类需要重写 `__getitem__()` 和 `__len__()` 方法,以满足特定数据集的需求。本示例中,我们将探讨如何利用 `torch.utils.data.Dataset` 创建一个自定义的数据加载类,用于区分猫狗图片的数据集。 导入所需的库: ```python import torch as t from torch.utils import data import os from PIL import Image import numpy as np import torchvision.transforms as T ``` `torchvision.transforms` 模块提供了一系列图像预处理操作,如 `Resize`, `CenterCrop`, `ToTensor`, `Normalize` 等。这些操作在训练神经网络模型时非常常见,可以将图片转换为模型需要的格式。例如,在本例中,我们创建了一个转换器 `transforms`: ```python transforms = T.Compose([ T.Resize(224), # 将图片调整为 224x224 的大小 T.CenterCrop(224), # 对图片中心进行裁剪,保持 224x224 的尺寸 T.ToTensor(), # 将 PIL 图片转换为 PyTorch 张量 T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # 归一化张量,减去均值并除以标准差 ]) ``` 接下来,定义自定义数据集类 `CatDog`,它继承自 `data.Dataset`: ```python class CatDog(data.Dataset): def __init__(self, root, transforms=None): imgs = os.listdir(root) # 获取根目录下的所有文件名 self.imgs = [os.path.join(root, img) for img in imgs] # 构建完整的文件路径 self.transforms = transforms # 存储预处理变换 def __getitem__(self, index): # 必须重写的方法,返回索引对应的图片及其标签 img_path = self.imgs[index] label = 1 if 'dog' in img_path else 0 # 假设图片名称包含 'dog' 表示狗,否则表示猫 data = Image.open(img_path) # 打开图片 if self.transforms: # 如果有预处理变换,则应用 data = self.transforms(data) return data, label # 返回处理后的图片和对应的标签 def __len__(self): # 必须重写的方法,返回数据集的长度(图片数量) return len(self.imgs) ``` 在这个类中,`__init__()` 方法初始化数据集,读取指定根目录下的所有文件,并保存它们的完整路径。`__getitem__()` 方法根据索引返回图片和相应的标签。在这个例子中,我们简单地通过检查文件名是否包含 "dog" 来判断图片类别,实际项目中通常需要更精确的标注信息。`__len__()` 方法返回数据集中图片的数量。 为了使用这个数据加载类,你需要实例化 `CatDog` 并传入图片的根目录以及可选的预处理变换: ```python root = '/path/to/your/dataset' # 替换为你的数据集路径 dataset = CatDog(root=root, transforms=transforms) ``` 现在你可以使用这个数据集来训练模型了。通常,我们还会使用 `DataLoader` 类来批量加载数据,这可以进一步提高训练效率: ```python dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True) ``` 这里,`DataLoader` 会按指定的 `batch_size` 批量加载数据,并且可以随机打乱数据顺序,这对于训练过程是有益的。 总结来说,`torch.utils.data.Dataset` 为自定义数据加载提供了便利。通过继承该类并重写 `__getitem__()` 和 `__len__()` 方法,我们可以灵活地处理各种类型的数据,并结合 `DataLoader` 实现高效、批量的数据加载,以适应深度学习模型的训练需求。
























- 粉丝: 6
我的内容管理 展开
我的资源 快来上传第一个资源
我的收益
登录查看自己的收益我的积分 登录查看自己的积分
我的C币 登录后查看C币余额
我的收藏
我的下载
下载帮助


最新资源
- 软件测试简历的自我评价(1).docx
- 互联网信息时代的人工智能应用(1).docx
- 第九章-软件测试(1).ppt
- 弹幕视频网站传播分析(1).docx
- 网络传媒推广系统软件设计文档(1)(1).doc
- 课程设计基于单片机红外防盗报警器的设计(1).doc
- 【推荐下载】宝马工厂里的智能机器人-高度自动化提升质量与效率(1).doc
- excel合并工作簿和工作表的代码(1).doc
- 基于互联网+视域下大学生创新创业教育路径研究(1).docx
- 塞曼效应计算机辅助软件设计论文(1)(1).docx
- 网站前台设计与实现(毕业论文)(1).doc
- 单片机电子称优秀课程设计.doc
- 2023年自考项目管理软件重点(1).docx
- 中职计算机教学实践中存在的问题和对策研究(1).docx
- 基于MATLAB的ASK调制解调实现(1).doc
- 企业信息化常见缩略语.docx



评论2