PyTorch DataLoader changes dict return values

Given a PyTorch IterableDataset that reads a CSV-formatted file (saved under the name myfile.json) like this:

import csv

# DataLoader2 was unused here and is not available in recent PyTorch releases
from torch.utils.data import DataLoader, IterableDataset

class MyDataset(IterableDataset):
    def __init__(self, filename):
        self.filename = filename

    def __iter__(self):
        # The file holds comma-separated rows; the first row is the header
        with open(self.filename, newline='') as fin:
            reader = csv.reader(fin)
            headers = next(reader)
            for line in reader:
                # One dict per row, e.g. {'imagefile': ..., 'label': ...}
                yield dict(zip(headers, line))

content = """imagefile,label
train/0/16585.png,0
train/0/56789.png,0"""

with open('myfile.json', 'w') as fout:
    fout.write(content)

ds = MyDataset("myfile.json")

When I loop through the dataset, each item it yields is a dict built from one row of the file, e.g.:

ds = MyDataset("myfile.json")

for i in ds:
    print(i)

[out]:

{'imagefile': 'train/0/16585.png', 'label': '0'}
{'imagefile': 'train/0/56789.png', 'label': '0'}

But when I wrap the dataset in a DataLoader, each value in the dict comes back wrapped in a list instead of as the value itself, e.g.:

ds = MyDataset("myfile.json")
x = DataLoader(dataset=ds)

for i in x:
    print(i)

[out]:

{'imagefile': ['train/0/16585.png'], 'label': ['0']}
{'imagefile': ['train/0/56789.png'], 'label': ['0']}

Q (part 1): Why does the DataLoader change each value of the dict into a list?

and also

Q (part 2): How can I make the DataLoader return the plain values of the dict, rather than lists of values, when iterating over it? Is there an argument or option in DataLoader that does this?

Solution:

The reason is the default collate behaviour in torch.utils.data.DataLoader, which determines how data samples in a batch are merged. By default, the torch.utils.data.default_collate collate function is used, which transforms mappings as:

Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

and strings as:

str -> str (unchanged)

Note that if you set batch_size to 2 in your example, you get:

{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}

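as a consequence of these transforms: each key stays in the dict, but its value becomes the list of that key's values across the batch, and the strings inside are left untouched. Since DataLoader's batch_size defaults to 1, every batch holds a single sample, which is why each value in your output is a one-element list.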
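The two rules above can be reproduced in a few lines of plain Python. This is a minimal sketch, not PyTorch's implementation: `toy_collate` is a hypothetical helper that only handles dicts and strings, whereas the real default_collate also handles tensors, numbers, sequences and more.

```python
from collections.abc import Mapping


def toy_collate(batch):
    """Hypothetical, simplified stand-in for default_collate (dicts and strings only)."""
    elem = batch[0]
    if isinstance(elem, Mapping):
        # Mapping[K, V_i] -> Mapping[K, collate([V_1, V_2, ...])]
        return {key: toy_collate([sample[key] for sample in batch]) for key in elem}
    if isinstance(elem, str):
        # Strings are not converted: the batch of strings is returned as a list
        return batch
    raise TypeError(f"toy_collate does not handle {type(elem).__name__}")


rows = [
    {'imagefile': 'train/0/16585.png', 'label': '0'},
    {'imagefile': 'train/0/56789.png', 'label': '0'},
]

# batch_size=1 (the DataLoader default): the batch is a one-element list of dicts
print(toy_collate(rows[:1]))
# {'imagefile': ['train/0/16585.png'], 'label': ['0']}

# batch_size=2: the values of both samples are gathered under each key
print(toy_collate(rows))
# {'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}
```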

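Assuming you do not need batching, you can get your desired output by disabling automatic batching with batch_size=None. In that mode the default collate_fn only converts NumPy arrays into tensors and leaves everything else untouched, so each dict of strings comes back as-is. More information is in the "Loading Batched and Non-Batched Data" section of the torch.utils.data documentation.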
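A minimal sketch of this fix, assuming an in-memory `RowsDataset` (a hypothetical stand-in for MyDataset so the example does not need the file). The second loader is an extra option not covered above: if you do want batches, passing your own collate_fn (here an identity function) returns each batch as a plain list of dicts instead of a dict of lists.

```python
from torch.utils.data import DataLoader, IterableDataset


class RowsDataset(IterableDataset):
    """Hypothetical stand-in for MyDataset: yields in-memory rows instead of reading a file."""

    def __init__(self, rows):
        self.rows = rows

    def __iter__(self):
        yield from self.rows


rows = [
    {'imagefile': 'train/0/16585.png', 'label': '0'},
    {'imagefile': 'train/0/56789.png', 'label': '0'},
]
ds = RowsDataset(rows)

# batch_size=None disables automatic batching: each dict is yielded unchanged
unbatched = list(DataLoader(ds, batch_size=None))
print(unbatched)

# Alternative (assumption about what you may want): keep batching, but let an
# identity collate_fn return each batch as a plain list of dicts
batched = list(DataLoader(ds, batch_size=2, collate_fn=lambda batch: batch))
print(batched)
```

With the default batch_size=1 you could also pass collate_fn=lambda batch: batch[0] to get single dicts, but batch_size=None states the intent (no batching) more directly.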
