How to retrieve Dataset object from DataLoader?

I have a PyTorch DataLoader and want to retrieve the Dataset object that the loader wraps. If this is possible, how? Or does a Dataset object only exist for the built-in datasets that ship with torchvision?

The end goal is to plug data that I currently have as a DataLoader into existing code written for a Dataset (e.g. CIFAR10).

The original code looks like this:

from torchvision import datasets
from torch.utils.data import Dataset

def get_dataset(dataset: str, split: str) -> Dataset:
    if dataset == "CIFAR10":
        return _cifar10(split)
    raise ValueError(f"unknown dataset: {dataset}")

def _cifar10(split: str) -> Dataset:
    # train=True selects the training split; any other split gets the test set
    return datasets.CIFAR10("./dataset_cache", train=(split == "train"), download=True)

dataset = get_dataset("CIFAR10", "train")
# Code written against a Dataset relies on len() and integer indexing
for i in range(len(dataset)):
    ...
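The loop above relies only on the map-style Dataset protocol: __len__ for len(dataset) and __getitem__ for dataset[i]. As a minimal sketch (the class SquaresDataset is hypothetical, not from the question), any custom Dataset that implements those two methods can be dropped into the same loop:

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is (float tensor i, label i**2)."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        # Number of samples; this is what len(dataset) returns
        return self.n

    def __getitem__(self, idx: int):
        # One (input, label) pair; this is what dataset[idx] returns
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return torch.tensor(float(idx)), idx ** 2


dataset = SquaresDataset(5)
for i in range(len(dataset)):
    x, y = dataset[i]
    print(i, x.item(), y)
```

A DataLoader, by contrast, is an iterable over batches and does not implement __getitem__, which is why index-based code cannot consume it directly.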

I tried loading my own dataset the same way, as a single batch containing every image:

import os

import torch
from torchvision import datasets

def get_dataset(dataset: str, split: str):
    if dataset == "CIFAR10":
        return _cifar10(split)
    elif dataset == "mydataset":
        return _mydataset(split)

def _mydataset(split: str):
    # dataset_directory is defined elsewhere; it contains one folder per split,
    # each holding one subfolder per class
    split_dir = os.path.join(dataset_directory, split)
    total_num_images = 0
    for class_dir in os.listdir(split_dir):
        total_num_images += len(os.listdir(os.path.join(split_dir, class_dir)))
    if split == "train":
        # Note: this returns a DataLoader (the whole split as one batch), not a Dataset
        mydataset = torch.utils.data.DataLoader(
            datasets.ImageFolder(os.path.join(dataset_directory, "train")),
            batch_size=total_num_images)
        return mydataset

dataset = get_dataset("mydataset", "train")
for i in range(len(dataset)):
    ...

But indexing into the result raises TypeError: 'DataLoader' object is not subscriptable.
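A small reproduction helps show what goes wrong. This sketch uses a TensorDataset of random numbers as a stand-in for the ImageFolder (an assumption, so it runs without any image files): len() works on a DataLoader but counts batches rather than samples, and subscripting fails because DataLoader defines no __getitem__.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten fake samples with 3 features each, plus integer labels
features = torch.randn(10, 3)
labels = torch.arange(10)
loader = DataLoader(TensorDataset(features, labels), batch_size=4)

# len() counts batches: ceil(10 / 4) == 3, not the 10 samples
n_batches = len(loader)
print(n_batches)  # 3

# Indexing the loader itself is what triggers the error from the question
try:
    loader[0]
    error_message = None
except TypeError as err:
    error_message = str(err)
    print(error_message)
```

So in the question's loop, range(len(dataset)) silently iterates over batches (just one, given batch_size=total_num_images), and dataset[i] then raises the TypeError.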

Solution:

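Access the dataset attribute of a torch.utils.data.DataLoader to get the underlying torch.utils.data.Dataset object it wraps. The DataLoader simply stores the Dataset passed to its constructor, as can be seen in the DataLoader source code.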