[TOC]

PyTorch Dataset and DataLoader

check the video

collate_fn(batch)

    @staticmethod
    def collate_fn(batch):
        # The official default_collate implementation is a useful reference:
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

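collate_fn determines how the list of individual samples fetched from the Dataset is merged into one batch. Here each sample is an (image, label) tuple: zip(*batch) regroups the samples into a tuple of images and a tuple of labels, which are then stacked into tensors.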
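
To see the function in context, here is a minimal sketch of passing it to a DataLoader. `ToyImageDataset`, its tensor shapes, and its labels are made up for illustration and are not from the original notes; only the `collate_fn` body mirrors the snippet above.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyImageDataset(Dataset):
    """Hypothetical dataset that returns (image tensor, int label) pairs."""

    def __init__(self, num_samples=10):
        # Fake 3x4x4 "images" filled with the sample index, and labels 0/1.
        self.images = [torch.full((3, 4, 4), float(i)) for i in range(num_samples)]
        self.labels = [i % 2 for i in range(num_samples)]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

    @staticmethod
    def collate_fn(batch):
        # batch is a list of (image, label) tuples, one per sample.
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)  # shape (B, 3, 4, 4)
        labels = torch.as_tensor(labels)     # shape (B,)
        return images, labels


loader = DataLoader(
    ToyImageDataset(),
    batch_size=4,
    collate_fn=ToyImageDataset.collate_fn,
)
images, labels = next(iter(loader))
print(images.shape, labels)
```

For same-shaped tensors and integer labels like these, the default collate function would produce an equivalent batch; a custom collate_fn matters most when samples need special handling, such as variable-sized inputs.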

a = [(1, 2), (3, 4)]
tuple(zip(*a))
Out[4]: ((1, 3), (2, 4))

ConcatDataset (list(Dataset))

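ConcatDataset is useful when the data is stored in several files: wrap each file in its own Dataset and pass the list of datasets to ConcatDataset.

ConcatDataset does not copy the underlying data. It records the cumulative lengths of the datasets and maps each global index to the right dataset and the local index within it.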
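
A small sketch of the index mapping, using two in-memory `TensorDataset` objects as stand-ins for per-file datasets (the values are arbitrary and chosen only for illustration):

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Two stand-in datasets of different lengths (3 and 5 samples).
part_a = TensorDataset(torch.arange(0, 3))
part_b = TensorDataset(torch.arange(10, 15))

combined = ConcatDataset([part_a, part_b])

total = len(combined)                 # sum of the individual lengths
last_of_a = combined[2][0].item()     # global index 2 -> part_a[2]
first_of_b = combined[3][0].item()    # global index 3 -> part_b[0]
print(total, last_of_a, first_of_b, combined.cumulative_sizes)
```

The combined object can be passed to a DataLoader like any other map-style Dataset, so shuffling and batching work across file boundaries.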

IterableDataset

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of datasets is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time.
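
As a rough illustration, the sketch below streams items from a plain Python list standing in for a log or database cursor. `LineStream` and its sample strings are hypothetical names used only for this example.

```python
from torch.utils.data import DataLoader, IterableDataset


class LineStream(IterableDataset):
    """Hypothetical stream: yields the length of each incoming line."""

    def __init__(self, lines):
        # In practice this could be a file handle, socket, or DB cursor.
        self.lines = lines

    def __iter__(self):
        # No __len__ or __getitem__: samples are produced sequentially.
        for line in self.lines:
            yield len(line)


stream = LineStream(["a", "bb", "ccc", "dddd", "eeeee"])
loader = DataLoader(stream, batch_size=2)  # single process, num_workers=0

batches = [batch.tolist() for batch in loader]
print(batches)
```

Note that this sketch assumes single-process loading. With `num_workers > 0`, each worker gets its own copy of the dataset, so the stream has to be split between workers to avoid duplicated samples.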

Dataset for pickle

from torch.utils.data import Dataset
from torch.utils.data import ConcatDataset
import pickle

class SingleFileDataSet(Dataset):

    def __init__(self, file_pathname, transform=None):

        self.file_pathname = file_pathname
        self.transform = transform  # needed for RGB image data with values in 0-255
        # load the pickled data from disk
        with open(file_pathname, 'rb') as f:
            self.episode_arr = pickle.load(f)