Reorder axis in TensorFlow Keras layer

I am building a model that applies a random shuffle to the data along the first non-batch axis, applies a series of Conv1D layers, and then applies the inverse of the shuffle. Unfortunately, the Lambda layer that calls tf.gather messes up the batch dimension (None), and I'm not sure why.
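The shuffle-and-unshuffle idea itself can be checked on a concrete tensor, outside any Keras layer. This is a minimal sketch, assuming eager TensorFlow 2.x. The batch size of 8 is arbitrary, and tf.math.invert_permutation is one way to build the inverse order; the original code does not show how the inverse is computed.

```python
import tensorflow as tf

dim = 90
batch = tf.random.normal((8, dim, 4))          # 8 samples, 90 positions, 4 channels

order = tf.random.shuffle(tf.range(dim))       # random permutation of 0..89
inverse = tf.math.invert_permutation(order)    # inverse[order[i]] == i

# Reorder along axis 1 only; the batch axis (axis 0) is left untouched
shuffled = tf.gather(batch, order, axis=1)
restored = tf.gather(shuffled, inverse, axis=1)

print(shuffled.shape)                                    # (8, 90, 4)
print(bool(tf.reduce_all(tf.equal(restored, batch))))    # True
```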

Below is an example of what happens.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

dim = 90
input_img = keras.Input(shape=(dim, 4))

# Get random shuffle order
order = layers.Lambda(lambda x: tf.random.shuffle(tf.range(x)))(dim)

# Apply shuffle
tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))(input_img, order)

model = keras.models.Model(
   inputs=[input_img],
   outputs=tensor,
)

Here the summary is as follows:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 90, 4)]           0         
_________________________________________________________________
lambda_51 (Lambda)           (90, 90, 4)               0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

However, I want the output shape of lambda_51 to be (None, 90, 4).
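The (90, 90, 4) shape can be reproduced eagerly. This is a minimal sketch, assuming (as the shape suggests) that the lambda only received input_img as x, so x[0] and x[1] are individual samples rather than the image and the order. The uniform values in [0, 4) are a hypothetical choice that keeps the cast indices valid for an axis of length 4.

```python
import tensorflow as tf

# Stand-in for one batch of input_img: 2 samples of shape (90, 4).
# Values lie in [0, 4), so casting to int32 yields valid indices 0..3.
x = tf.random.uniform((2, 90, 4), maxval=4.0)

# x[0] and x[1] index the BATCH axis: they are the first and second samples.
params = x[0]                        # shape (90, 4)
indices = tf.cast(x[1], tf.int32)    # shape (90, 4)
wrong = tf.gather(params, indices, axis=1)

# gather output shape = params.shape[:1] + indices.shape + params.shape[2:]
print(wrong.shape)                   # (90, 90, 4)
```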

Solution:

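Wrap input_img and order in a list when you pass them to the Lambda layer. A Keras layer receives its inputs as a single first argument, so with Lambda(...)(input_img, order) the lambda effectively gets only input_img as x, and order is passed along as an extra call argument rather than as part of x. x[0] and x[1] then index the batch axis, picking out the first and second samples, each of shape (90, 4). Gathering along axis 1 of a (90, 4) tensor with a (90, 4) index tensor produces a (90, 90, 4) tensor, which is why the batch dimension disappears.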

With that change, the tensor layer becomes:

tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))([input_img, order])

and the summary becomes:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 90, 4)]           0         
_________________________________________________________________
lambda_3 (Lambda)            (None, 90, 4)             0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
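For completeness, here is a sketch of the whole pipeline from the question (shuffle, Conv1D layers, inverse shuffle) using the list-wrapping fix for both gathers. Assumptions: it targets tf.keras in TensorFlow 2.x as in the question; the two Conv1D layers and their sizes are hypothetical placeholders for "a series of Conv1Ds"; and order is computed directly with tf.random.shuffle. Because dim is a plain Python int, the original order Lambda most likely runs once, eagerly, when the model is built, so in both versions the same permutation is reused for every batch rather than redrawn per call.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

dim = 90
input_img = keras.Input(shape=(dim, 4))

# Fixed random permutation of the 90 positions, drawn once at build time
order = tf.random.shuffle(tf.range(dim))
# Indices that undo the permutation: inverse[order[i]] == i
inverse = tf.math.invert_permutation(order)

def gather_positions(x):
    # x is the list [tensor, indices]; gather along axis 1, keeping the batch axis
    return tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1)

shuffled = layers.Lambda(gather_positions)([input_img, order])

# Hypothetical Conv1D stack; padding="same" keeps the length at 90
h = layers.Conv1D(16, 3, padding="same", activation="relu")(shuffled)
h = layers.Conv1D(4, 3, padding="same")(h)

# Undo the shuffle with the inverse permutation
restored = layers.Lambda(gather_positions)([h, inverse])

model = keras.Model(inputs=input_img, outputs=restored)
print(tuple(model.outputs[0].shape))   # (None, 90, 4)
```

If a fresh permutation per batch is needed instead, the shuffle would have to be drawn inside a layer's call and passed on to the inverse step; the sketch above does not attempt that.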
