PyTorch DataLoader changes dict return values

Given a PyTorch dataset that reads a file as follows (the file is named myfile.json, but its content is CSV and is parsed with csv.reader):

import csv

from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader

class MyDataset(IterableDataset):
    def __init__(self, jsonfilename):
        self.filename = jsonfilename

    def __iter__(self):
        with open(self.filename) as fin:
            reader = csv.reader(fin)
            headers = next(reader)
            for line in reader:
                yield dict(zip(headers, line))
                
content = """imagefile,label
train/0/16585.png,0
train/0/56789.png,0"""

with open('myfile.json', 'w') as fout:
    fout.write(content)

ds = MyDataset("myfile.json")

When I loop through the dataset, each returned value is a dict built from one line of the file, e.g.

ds = MyDataset("myfile.json")

for i in ds:
    print(i)

[out]:

{'imagefile': 'train/0/16585.png', 'label': '0'}
{'imagefile': 'train/0/56789.png', 'label': '0'}

But when I wrap the dataset in a DataLoader, each dict value comes back as a list instead of the value itself, e.g.

ds = MyDataset("myfile.json")
x = DataLoader(dataset=ds)

for i in x:
    print(i)

[out]:

{'imagefile': ['train/0/16585.png'], 'label': ['0']}
{'imagefile': ['train/0/56789.png'], 'label': ['0']}

Q (part 1): Why does the DataLoader change the values of the dict to lists?

and also

Q (part 2): How can I make the DataLoader return the dict values themselves, rather than lists of values, when iterating over it? Are there arguments or options in DataLoader for this?

Solution:

The reason is the default collate behaviour in torch.utils.data.DataLoader, which determines how data samples in a batch are merged. By default, the torch.utils.data.default_collate collate function is used, which transforms mappings as:

Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

and strings as:

str -> str (unchanged)

The default batch_size is 1, which is why each list in your output holds a single element. If you set batch_size to 2 in your example, you get:

{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}

as a consequence of these transforms.
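The two rules above can be imitated in a few lines of plain Python. This is only a simplified sketch for illustration: sketch_collate is a hypothetical helper, not PyTorch's implementation, and it handles nothing beyond mappings and strings.

```python
from collections.abc import Mapping


def sketch_collate(batch):
    """Hypothetical, simplified imitation of the mapping and string rules only."""
    first = batch[0]
    if isinstance(first, str):
        # Strings are left unchanged, so the batch stays a list of strings.
        return list(batch)
    if isinstance(first, Mapping):
        # Gather the values that share a key across all samples, then collate them.
        return {key: sketch_collate([sample[key] for sample in batch]) for key in first}
    raise TypeError(f"unsupported sample type in this sketch: {type(first)}")


samples = [
    {"imagefile": "train/0/16585.png", "label": "0"},
    {"imagefile": "train/0/56789.png", "label": "0"},
]

batch_of_one = sketch_collate(samples[:1])  # what batch_size=1 produces per step
batch_of_two = sketch_collate(samples)      # what batch_size=2 produces

print(batch_of_one)
print(batch_of_two)
```

Even a batch holding a single sample goes through the same merge, which is why every value ends up wrapped in a one-element list.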

If you do not need batching, you can get the desired output by disabling automatic batching with batch_size=None. For more information, see the "Loading Batched and Non-Batched Data" section of the torch.utils.data documentation.
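As a minimal sketch of the non-batched setup, the snippet below reuses the dataset from the question and passes batch_size=None to the DataLoader. The temporary directory and the myfile.csv name are assumptions made here so the example is self-contained. They are not part of the original post.

```python
import csv
import os
import tempfile

from torch.utils.data import DataLoader, IterableDataset


class MyDataset(IterableDataset):
    def __init__(self, filename):
        self.filename = filename

    def __iter__(self):
        # Yield one dict per CSV row, keyed by the header line.
        with open(self.filename, newline="") as fin:
            reader = csv.reader(fin)
            headers = next(reader)
            for line in reader:
                yield dict(zip(headers, line))


content = "imagefile,label\ntrain/0/16585.png,0\ntrain/0/56789.png,0\n"

with tempfile.TemporaryDirectory() as tmpdir:
    # Hypothetical file location, used only for this example.
    path = os.path.join(tmpdir, "myfile.csv")
    with open(path, "w") as fout:
        fout.write(content)

    # batch_size=None disables automatic batching, so samples are not merged.
    loader = DataLoader(MyDataset(path), batch_size=None)
    rows = list(loader)

for row in rows:
    print(row)
```

Each item from the loader is then the plain dict yielded by the dataset, with string values rather than one-element lists.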
