设为首页收藏本站
网站公告 | 这是第一条公告
     

 找回密码
 立即注册
缓存时间21 现在时间21 缓存数据 不是每个人都愿意承担你的负能量,也不是每个人都愿听你的焦灼不安。只有真正爱你的人,才会对你的处境感同身受。

不是每个人都愿意承担你的负能量,也不是每个人都愿听你的焦灼不安。只有真正爱你的人,才会对你的处境感同身受。

查看: 810|回复: 0

python中关于CIFAR10数据集的使用

[复制链接]

  离线 

TA的专栏

  • 打卡等级:热心大叔
  • 打卡总天数:197
  • 打卡月天数:0
  • 打卡总奖励:3636
  • 最近打卡:2023-08-27 07:16:33
等级头衔

等級:晓枫资讯-上等兵

在线时间
33 小时

积分成就
威望
0
贡献
297
主题
390
精华
0
金钱
4830
积分
707
注册时间
2022-12-26
最后登录
2023-8-27

发表于 2023-2-2 13:19:01 | 显示全部楼层 |阅读模式
关于CIFAR10数据集的使用

主要解决了如何把数据集与transforms结合在一起的问题。


CIFAR10的官方解释
  1. torchvision.datasets.CIFAR10(
  2. root: str,
  3. train: bool = True,
  4. transform: Optional[Callable] = None,
  5. target_transform: Optional[Callable] = None,
  6. download: bool = False)
复制代码

注释:

  • root (string)存在 cifar-10-batches-py 目录的数据集的根目录,如果下载设置为 True,则将保存到该目录。
  • train (bool, optional)如果为True,则从训练集创建数据集, 如果为False,从测试集创建数据集。
  • transform (callable, optional)它接受一个 PIL 图像并返回一个转换后的版本。 例如,transforms.RandomCrop/transforms.ToTensor
  • target_transform (callable, optional) 接收目标并对其进行转换的函数/转换。
  • download (bool, optional)如果为 true,则从 Internet 下载数据集并将其放在根目录中。 如果数据集已经下载,则不会再次下载。

实战操作

1.CIAFR10数据集的下载

代码如下:

  1. import torchvision   #导入torchvision这个类

  2. train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True,
  3. download= True)  #从训练集创建数据集
  4. test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False,
  5. download=True)    #从测试集创建数据集
复制代码

root = "./dataset",将下载的数据集保存在这个文件夹下;download= True,从 Internet 下载数据集并将其放在根目录中,这里就是在相对路径中,创建dataset文件夹,将数据集保存在dataset中。

2.查看下载的CIAFR10数据集

运行程序,开始下载数据集。下载成功后,可以进行一些查看。代码如下:

接着输入:

  1. print(train_set[0])  #查看train_set训练集中的第一个数据
  2. print(train_set.classes)   #查看train_set训练集中有多少个类别

  3. img, target = train_set[0]
  4. print(img)
  5. print(target)
  6. print(train_set.classes[target])
  7. img.show()  #显示图片
复制代码

输出结果:

(<PIL.Image.Image image mode=RGB size=32x32 at 0x161E924B8D0>, 6)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship',
'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x161E924B710>
6
frog

注释:可以看见,train_set数据集中有10个类别,train_set中第0个元素的target是6,也就是说,这个元素是属于第7个类别frog的。

3.数据转换

因为这些图片类型都是PIL Image,如果要供给pytorch使用的话,需要将数据全都转化成tensor类型。

完整代码如下:

  1. import torchvision   #导入torchvision这个类
  2. from torch.utils.tensorboard import SummaryWriter

  3. from torchvision import transforms
  4. dataset_transforms = transforms.ToTensor()

  5. # dataset_transforms = torchvision.transforms.Compose([
  6. #     torchvision.transforms.ToTensor()
  7. # ])    第3  4 行代码可以用compose直接写
  8. train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, transform=dataset_transforms, download= True) #训练集
  9. test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transforms, download=True)   #测试集

  10. writer = SummaryWriter("logs")

  11. # print(train_set[0])  #查看train_set训练集中的第一个数据
  12. # print(train_set.classes)   #查看train_set训练集中有多少个类别

  13. # img, target = train_set[0]
  14. # print(img)
  15. # print(target)
  16. # print(train_set.classes[target])
  17. # img.show()
  18. for i in range(20):
  19.     img, target = train_set[i]
  20.     writer.add_image("cifar10_test2", img, i)

  21. writer.close()
复制代码

小结:CIFAR10数据集内存很小,只有100多m,下载方便。对我们学习数据集非常友好,练习的时候,我们可以使用SummaryWriter来将数据写入tensorboard中。


CIFAR-10 数据集简介

复现代码的过程中,简单了解了作者使用的数据集CIFAR-10 dataset ,简单记录一下。

CIFAR-10数据集是8000万微小图片的标签子集,它的收集者是:Alex Krizhevsky, Vinod Nair, Geoffrey Hinton。

2023020114501745.jpg

数据集由6万张32*32的彩色图片组成,一共有10个类别。每个类别6000张图片。其中有5万张训练图片及1万张测试图片。

数据集被划分为5个训练块和1个测试块,每个块1万张图片。

测试块包含了1000张从每个类别中随机选择的图片。训练块包含随机的剩余图像,但某些训练块可能对于一个类别的包含多于其他类别,训练块包含来自各个类别的5000张图片。

这些类是完全互斥的,及在一个类别中出现的图片不会出现在其它类中。


数据集版本

作者提供了3个版本的数据集:python version; Matlab version; binary version。

可根据自己的需求选择。

数据集下载地址:下载链接


数据集布置

以python version进行介绍,Matlab version与之相同。

下载后获得文件 data_batch_1, data_batch_2,…, data_batch_5。测试块相同。这些文件中的每一个都是用cPickle生成的python pickled对象。

具体使用方法:

  1. def unpickle(file):
  2.     import pickle
  3.     with open(file, 'rb') as fo:
  4.         dict = pickle.load(fo, encoding='bytes')
  5.     return dict
复制代码

返回字典类,每个块的文件包含一个字典类,包含以下元素:

  • data: 一个100003072的numpy数组(unit8)每个行存储3232的彩色图片,3072=1024*3,分别是red, green, blue。存储方式以行为主。
  • labels:使用0-9进行索引。

数据集包含的另一个文件batches.meta同样包含python字典,用于加载label_names。如:label_names[0] == “airplane”, label_names[1] == “automobile”



晓枫资讯-科技资讯社区-免责声明
免责声明:以上内容为本网站转自其它媒体,相关信息仅为传递更多信息之目的,不代表本网观点,亦不代表本网站赞同其观点或证实其内容的真实性。
      1、注册用户在本社区发表、转载的任何作品仅代表其个人观点,不代表本社区认同其观点。
      2、管理员及版主有权在不事先通知或不经作者准许的情况下删除其在本社区所发表的文章。
      3、本社区的文章部分内容可能来源于网络,仅供大家学习与参考,如有侵权,举报反馈:点击这里给我发消息进行删除处理。
      4、本社区一切资源不代表本站立场,并不代表本站赞同其观点和对其真实性负责。
      5、以上声明内容的最终解释权归《晓枫资讯-科技资讯社区》所有。
http://bbs.yzwlo.com 晓枫资讯--游戏IT新闻资讯~~~
严禁发布广告,淫秽、色情、赌博、暴力、凶杀、恐怖、间谍及其他违反国家法律法规的内容。!晓枫资讯-社区
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

手机版|晓枫资讯--科技资讯社区 本站已运行

CopyRight © 2022-2025 晓枫资讯--科技资讯社区 ( BBS.yzwlo.com ) . All Rights Reserved .

晓枫资讯--科技资讯社区

本站内容由用户自主分享和转载自互联网,转载目的在于传递更多信息,并不代表本网赞同其观点和对其真实性负责。

如有侵权、违反国家法律政策行为,请联系我们,我们会第一时间及时清除和处理! 举报反馈邮箱:点击这里给我发消息

Powered by Discuz! X3.5

快速回复 返回顶部 返回列表