Predict all test batches in Keras / TensorFlow

I am trying to predict on all of my test batches in Keras / TensorFlow so that I can plot a confusion matrix.
The current BATCH_SIZE is 32.

My test dataset is created from a larger dataset with the following code:

test_dataset = big_dataset.skip(train_size).take(test_size)
test_dataset = test_dataset.shuffle(test_size).map(augment).batch(BATCH_SIZE)

After model.compile() and model.fit(), I get the predictions and the correct labels with this code:

points, labels = list(test_dataset)[0]
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
points = points.numpy()

This method only predicts on one batch, which gives 32 predictions.

Is there a way to predict on all test batches in Keras / TensorFlow?

Thanks in advance!

Solution:

You can pass the entire dataset to model.predict. The docs describe its input argument as follows:

Input samples. It could be:
A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
A tf.data dataset.
A generator or keras.utils.Sequence instance.
A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given in the Unpacking behavior for iterator-like inputs section of Model.fit.

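points = test_dataset.map(lambda x, y: x)  # keep only the inputs
preds = model.predict(points)  # runs over every batch in the dataset
preds = tf.math.argmax(preds, -1)
# Caution: test_dataset is shuffled and reshuffles on each iteration by default,
# so labels read in a separate pass will not line up with preds.
labels = test_dataset.map(lambda x, y: y)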

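Or with NumPy, reading the dataset once so that points and labels stay aligned:

batches = list(test_dataset)  # single pass over the shuffled dataset
points = np.concatenate([x.numpy() for x, y in batches])
labels = np.concatenate([y.numpy() for x, y in batches])
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)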
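
Since the goal is a confusion matrix, the labels and predictions have to come from the same pass over the dataset. The following is a minimal sketch of one way to do that, not the only one. The helper names predict_all_batches and confusion_matrix are hypothetical, and it assumes integer class labels (not one-hot) and a predict function that returns per-class scores for one batch.

```python
import numpy as np


def predict_all_batches(batches, predict_fn):
    """Iterate over (points, labels) batches once and collect results.

    Hypothetical helper: `batches` is any iterable of (points, labels)
    pairs and `predict_fn` maps one batch of points to per-class scores.
    """
    all_labels, all_preds = [], []
    for points, labels in batches:
        scores = np.asarray(predict_fn(points))
        all_preds.append(np.argmax(scores, axis=-1))  # predicted class ids
        all_labels.append(np.asarray(labels))  # assumes integer labels
    return np.concatenate(all_labels), np.concatenate(all_preds)


def confusion_matrix(labels, preds, num_classes):
    """Count (true label, predicted label) pairs; rows are true labels."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (labels, preds), 1)  # unbuffered add handles repeats
    return cm
```

With the objects from the question this could be called as labels, preds = predict_all_batches(test_dataset, model.predict_on_batch). TensorFlow also provides tf.math.confusion_matrix(labels, preds), which could replace the NumPy helper.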