How can I retrieve the first N items from a TensorFlow batch dataset, and not an iterator that reevaluates to different items?

I would like to retrieve the first N items from a BatchDataset. I have tried a number of different ways to do this, but they all return different items each time they are re-evaluated. What I want is N actual items, not an iterator that keeps producing new items.

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
# Build the training portion of the dataset from the "Images" directory.
# Neither shuffle= nor batch_size= is passed, so the library defaults apply.
ds = tf.keras.utils.image_dataset_from_directory(
    "Images", 
    validation_split=0.2,
    seed=123,
    subset="training")

# Attempt to retrieve 9 items
# NOTE(review): ds is a batched dataset, so take(9) presumably selects
# 9 batches rather than 9 individual images -- confirm against batch_size.
test_ds = ds.take(9)

# Plot the 9 items and their labels
# NOTE(review): class_names is not defined anywhere in this snippet;
# presumably it is meant to be ds.class_names -- confirm.
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
  # Every pass of the outer loop redraws the same 3x3 grid using the first
  # nine entries of the current `images`, so only the last pass stays visible.
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

#
# AGAIN, plot the 9 items and their labels
# NOTE: This will show 9 different images, and my expectation is 
# that it should show the same images as above.
# 
# This second pass iterates test_ds again from the start; it is a verbatim
# repeat of the plotting loop above.
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

>Solution :

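image_dataset_from_directory shuffles the data by default (shuffle=True), and that shuffle is re-applied every time you iterate over the resulting tf.data.Dataset, so each pass yields the items in a different order. You could set shuffle to False to get deterministic results. Note also that the dataset is batched (the default batch_size is 32), so ds.take(9) takes 9 batches rather than 9 images; the example below uses batch_size=1 so that each element is a single image: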

import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

# shuffle=False keeps the element order fixed, so every pass over the dataset
# yields the same items in the same order.
# batch_size=1 makes each dataset element hold exactly one image, so that
# take(9) below really means "9 images" and not "9 batches".
ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(64, 64),
  batch_size=1,
  shuffle=False)

# Attempt to retrieve 9 items
test_ds = ds.take(9)

# Use the class names the loader found instead of a hard-coded placeholder
# list, so that the plot titles match the actual labels. The attribute lives
# on `ds` (the dataset returned by the loader), so it is read from there and
# not from `test_ds`.
class_names = ds.class_names


def plot_first_images(dataset, class_names, rows=3, cols=3):
  """Plot the first rows * cols images of `dataset` in a grid with titles.

  Args:
    dataset: dataset yielding (images, labels) pairs with a batch size of 1.
    class_names: sequence mapping an integer label to its display name.
    rows: number of grid rows.
    cols: number of grid columns.
  """
  plt.figure(figsize=(4, 4))
  # take() caps the number of elements so the subplot index never exceeds
  # the grid size, even if a longer dataset is passed in.
  for i, (images, labels) in enumerate(dataset.take(rows * cols)):
    plt.subplot(rows, cols, i + 1)
    # Index 0 selects the single image/label of the size-1 batch.
    plt.imshow(images[0, ...].numpy().astype("uint8"))
    plt.title(class_names[labels.numpy()[0]])
    plt.axis("off")


# Plot the 9 items and their labels
plot_first_images(test_ds, class_names)

# Plot them AGAIN: with shuffle=False this shows the same 9 images as above.
plot_first_images(test_ds, class_names)

enter image description here
enter image description here

If you are interested in other data samples, you can combine the methods tf.data.Dataset.skip and tf.data.Dataset.take. For example, ds.skip(9).take(9) returns the next nine elements after the first nine.

Leave a Reply