How to show more images than the batch size value?

I have the following code:

train_ds = tf.keras.utils.image_dataset_from_directory(
  '/media/Tesi',
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(360, 360),
  batch_size=18)

class_names = train_ds.class_names

val_ds = tf.keras.utils.image_dataset_from_directory(
  '/media/Tesi',
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(360, 360),
  batch_size=18)

num_classes = len(class_names)

Then I create a model and compute the class probabilities (predictions). To show the images in val_ds with their predicted classes, I use:

plt.figure(figsize=(20, 20))
for images, _ in val_ds.take(1):
    for i in range(18):
        ax = plt.subplot(6, 6, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[np.argmax(predictions[i])])
        plt.axis("off")

This way I always show the first 18 images of val_ds. How can I show, for example, the images from index 18 to 36?
Thanks.
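The plot titles in the loop above come from np.argmax(predictions[i]), i.e. the class with the highest predicted probability for image i. The same lookup works for any range of positions, as long as the row index in predictions matches the position of the image in val_ds. Below is a minimal numpy sketch. It assumes predictions is the (num_images, num_classes) array returned by model.predict(val_ds) on an unshuffled val_ds; the question does not say how it was computed. predicted_names and the sample numbers are hypothetical.

```python
import numpy as np


def predicted_names(predictions, class_names, start, stop):
    """Return the most probable class name for each image at positions start..stop-1.

    Hypothetical helper. Assumes predictions is the (num_images, num_classes)
    array from model.predict(val_ds) on an unshuffled val_ds, so row k
    belongs to image k.
    """
    # argmax along axis 1 picks the column (class) with the highest score per row.
    best = np.argmax(predictions[start:stop], axis=1)
    return [class_names[k] for k in best]


# Hypothetical example: 4 images, 3 classes.
class_names = ["cat", "dog", "fox"]
predictions = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
])
print(predicted_names(predictions, class_names, 1, 4))  # ['dog', 'fox', 'cat']
```

With batch_size = 18, the i-th image of the second batch matches row predictions[18 + i], which is the same as predicted_names(predictions, class_names, 18, 36)[i].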

Solution:

You can use tf.data.Dataset.skip and tf.data.Dataset.take:

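import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt

# Download the TensorFlow flowers example dataset (a stand-in for '/media/Tesi')
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 18

# shuffle=False keeps the order of val_ds fixed across iterations
val_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(360, 360),
  batch_size=batch_size,
  shuffle=False)

plt.figure(figsize=(20, 20))
# skip(1) drops the first batch (images 0-17); take(1) yields the next one (images 18-35)
for images, _ in val_ds.skip(1).take(1):
  # len(images) also handles a last batch smaller than batch_size
  for i in range(len(images)):
    ax = plt.subplot(6, 6, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.axis("off")
plt.show()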

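In this example, skip(1) drops the first batch (images 0 to 17) and take(1) returns the next batch (images 18 to 35). More generally, val_ds.skip(n).take(1) gives batch n, i.e. images n * batch_size to (n + 1) * batch_size - 1. If you also show the predicted class as the title, offset the prediction index the same way, e.g. class_names[np.argmax(predictions[n * batch_size + i])].

You need shuffle=False so that the order of val_ds is fixed. With the default shuffle=True the order can change every time you iterate over the dataset, so skip(1).take(1) would not reliably return images 18 to 35, and the images would not line up with predictions computed in a separate pass. If you also build train_ds from the same folder with validation_split, pass the same shuffle value to both calls: the file list is shuffled (using seed) before the split is made, so mixing shuffle=True for training and shuffle=False for validation can make the two subsets overlap.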
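If the range you want does not start on a batch boundary (say images 10 to 27), one option is to unbatch the dataset first, so that skip and take count individual images instead of batches. This is a minimal sketch. It assumes val_ds yields (images, labels) batches in a fixed order (shuffle=False). take_image_range and the synthetic fake_ds are hypothetical names used only for illustration.

```python
import numpy as np
import tensorflow as tf


def take_image_range(ds, start, stop):
    """Return the images and labels at positions start..stop-1 of a batched dataset.

    Hypothetical helper. Assumes ds yields (images, labels) batches in a fixed
    order (shuffle=False) and that at least one image lies in the range.
    """
    # unbatch() splits every (batch_images, batch_labels) pair into single
    # examples, so skip() and take() now count images instead of batches.
    subset = ds.unbatch().skip(start).take(stop - start)
    images, labels = [], []
    for image, label in subset:
        images.append(image.numpy())
        labels.append(label.numpy())
    return np.stack(images), np.array(labels)


# Hypothetical stand-in for val_ds: 50 tiny 4x4 RGB "images" whose pixel
# values equal their position, batched by 18 as in the question.
fake_images = np.arange(50, dtype=np.float32).reshape(50, 1, 1, 1) * np.ones((1, 4, 4, 3), dtype=np.float32)
fake_labels = np.arange(50) % 5
fake_ds = tf.data.Dataset.from_tensor_slices((fake_images, fake_labels)).batch(18)

images, labels = take_image_range(fake_ds, 18, 36)
print(images.shape)                             # (18, 4, 4, 3)
print(images[0, 0, 0, 0], images[-1, 0, 0, 0])  # 18.0 35.0
```

With the real dataset you would call take_image_range(val_ds, 18, 36), then plot images[i].astype("uint8") with the title class_names[np.argmax(predictions[18 + i])]. Note that skip still reads and decodes the skipped images, so large offsets cost some extra time.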
