Reorder axis in TensorFlow Keras layer

I am building a model that randomly shuffles the data along the first non-batch axis, applies a series of Conv1D layers, and then applies the inverse of the shuffle. Unfortunately, the Lambda layer that wraps tf.gather does not keep the batch dimension as None, and I'm not sure why.

Below is an example of what happens.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

dim = 90
input_img = keras.Input(shape=(dim, 4))

# Get random shuffle order
order = layers.Lambda(lambda x: tf.random.shuffle(tf.range(x)))(dim)

# Apply shuffle
tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))(input_img, order)

model = keras.models.Model(
   inputs=[input_img],
   outputs=tensor,
)

The model summary is:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 90, 4)]           0         
_________________________________________________________________
lambda_51 (Lambda)           (90, 90, 4)               0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

Whereas I want the output shape of lambda_51 to be (None, 90, 4).
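The (90, 90, 4) shape can be reproduced outside Keras, assuming that x inside the lambda is only input_img (that is, order never reaches the function). Then x[0] and x[1] each drop the batch axis and have shape (90, 4), and a gather along axis=1 returns params.shape[:1] + indices.shape + params.shape[2:]. The sketch below uses NumPy's np.take as a stand-in for tf.gather, since it follows the same output-shape rule for a single gather axis; the concrete batch size of 2 is a hypothetical value replacing None.

```python
import numpy as np

dim = 90
batch = 2  # hypothetical concrete batch size standing in for None

# Stand-in for input_img with a concrete batch dimension (all zeros)
input_img = np.zeros((batch, dim, 4), dtype=np.float32)

# If the lambda sees only input_img, x[0] and x[1] index the batch axis
params = input_img[0]                     # shape (90, 4)
indices = input_img[1].astype(np.int32)   # shape (90, 4); all zeros, so valid indices

# np.take follows the same shape rule as tf.gather(params, indices, axis=1):
# params.shape[:1] + indices.shape + params.shape[2:]
out = np.take(params, indices, axis=1)
print(out.shape)  # (90, 90, 4)
```

This matches the (90, 90, 4) reported for lambda_51: the leading 90 is the sequence axis of a single sample, not a batch axis.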

Solution:

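Wrap input_img and order in a list when you pass them to the tensor layer. A Keras layer expects all of its inputs in the first argument of the call. In your version, order is passed as a second positional argument, which Keras forwards to the layer's call method as the next parameter (for Lambda, that is mask), so it never reaches your function. As a result, x is just input_img, and x[0] and x[1] index into the batch axis instead of selecting the two inputs.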

The tensor layer then becomes:

tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))([input_img, order])

and the model summary becomes:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 90, 4)]           0         
_________________________________________________________________
lambda_3 (Lambda)            (None, 90, 4)             0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
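The original goal also needs the inverse of the shuffle after the Conv1D stack. One way to get it, assuming order is a permutation of range(dim), is to gather again with the argsort of order; in TensorFlow that would be tf.gather(t, tf.argsort(order), axis=1). The sketch below uses NumPy (np.random.default_rng, np.take, np.argsort) as a stand-in so it runs without TensorFlow, and checks that the batch axis is untouched and that the round trip restores the input. The batch size and random seed are hypothetical.

```python
import numpy as np

dim = 90
batch = 3  # hypothetical batch size

rng = np.random.default_rng(0)  # hypothetical seed, for reproducibility
x = rng.normal(size=(batch, dim, 4)).astype(np.float32)

# Random permutation of positions along the first non-batch axis
order = rng.permutation(dim)

# Shuffle, analogous to tf.gather(x, order, axis=1); the batch axis is preserved
shuffled = np.take(x, order, axis=1)

# argsort of a permutation is its inverse: order[inverse[j]] == j
inverse = np.argsort(order)

# Gathering with the inverse undoes the shuffle
restored = np.take(shuffled, inverse, axis=1)

print(shuffled.shape)               # (3, 90, 4)
print(np.array_equal(restored, x))  # True
```

In a Keras model, the Conv1D layers would sit between the two gathers; since both gathers act on axis=1 and receive the full input tensor, the batch axis should stay None.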