首页 > 服务端开发 > 聊聊pytorch中的DataLoader
2019
11-15

聊聊pytorch中的DataLoader

实际上pytorch在定义dataloader的时候是需要传入很多参数的,比如,number_workers, pin_memory, 以及shuffle, dataset等,其中sampler参数算是其一

sampler实际上定义了torch.utils.data.dataloader的数据取样方式,什么意思呢?

在自己定义dataset中的__getitem__函数的时候,每一个index,唯一的对应一个样本,sampler实际上就是一系列的index组成的可迭代对象

如下图所示的__iter__函数返回的可迭代对象

下图所示的是randomsampler,即随机的shuffle图片的index,然后取样,关键就一句话,在__iter__中的torch.randperm(n).tolist()

表明产生了一个0到n-1的一个list

 聊聊pytorch中的DataLoader - 第1张  | 逗分享开发经验

 

比如我的数据是128张图片,然后,dataset中的__len__也是128,我的sampler,如果不shuffle的话,其中的index从0-127没毛病

如果将这个顺序打乱,那就是相当于随机取样,和上图一样

也就是说,每个图片都定义了唯一的一个index,取图的时候按照sampler定义的规则来取图,实际上这样就可以做一些有意思的事情了

比如我的batch size是2,我想每一个batch取的图片是第一张和紧挨着的后面的一张图,假设sampler不shuffle的话,那么__iter__返回的可迭代对象应该是

iter([0,1,1,2,2,3,3,4,4,5,.....])没毛病

再比如,我想我的batch是4,隔一张取一下图片,那么我的sampler的函数返回的__iter__应该是iter[0,2,4,6,1,3,5,7,2,4,6,8,.......]

 

但是这又有什么用?比如按照我上面的第一种取图的需求,我完全可以在__getitem__中定义下一个index使得和上面一个读到的图像一样。用处就在这里,

假设是定义一个相同的index,读取相同的图片是非常的占用内存的,比如imagenet,读完放到内存里面,大概是需要一百多个g,按照我刚刚的第一个例子读取,内存就需要2倍,实际上这是不允许的,通过定义sampler,对于一张图片重复采样,比对于一张图片读取多遍显然要划算的多

最后编辑:
作者:搬运工
这个作者貌似有点懒,什么都没有留下。