
Matrix multiplication in a TensorFlow model

I want to use matrix multiplication inside a TF model. My model is a neural network with input shape (1, 9). I want to multiply this vector by itself, i.e. compute the matrix product of the transposed input vector and the input vector, which has shape (9, 9).
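Written out for a single sample, the requested product is the outer product of the input vector with itself. In the second line, B is a symbol introduced here only for the batch size (the None dimension in Keras), and the product is meant to be taken separately for each sample b:

```latex
x = \begin{pmatrix} x_1 & x_2 & \cdots & x_9 \end{pmatrix} \in \mathbb{R}^{1 \times 9},
\qquad
x^{\top} x \in \mathbb{R}^{9 \times 9},
\qquad
\left(x^{\top} x\right)_{ij} = x_i \, x_j,
\\
X \in \mathbb{R}^{B \times 1 \times 9},
\qquad
Y_b = X_b^{\top} X_b \quad (b = 1, \dots, B),
\qquad
Y \in \mathbb{R}^{B \times 9 \times 9}.
```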

Code example:

import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(1,9))
outputs = tf.keras.layers.Dense(1, activation='linear')(tf.transpose(inputs) @ inputs)

model = tf.keras.Model(inputs, outputs)

# decay=0.0 dropped: it matched the default, and newer Keras optimizers raise an error for it
adam = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

model.compile(optimizer=adam, loss='mse', metrics=['mae'])

But the result has the wrong shape. With the code above, I get the following architecture:


[Screenshot of the model summary (not preserved)]

If I understand correctly, the first dimension (None) of the input layer corresponds to the batch size. When I use the transpose operation, it is applied to all dimensions of this shape, so after the transpose and the multiplication I get a result of shape (9, 1, 9). I don't think this is correct: I want the product of the transposed input vector with itself for every vector in the batch, so the result should have shape (None, 9, 9).
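The axis bookkeeping described above can be reproduced without building a model. The sketch below uses NumPy purely as a stand-in: np.transpose with no axes argument reverses all axes, just like tf.transpose without perm, and np.matmul treats leading axes as batch axes, like tf.linalg.matmul. The batch size of 4 is an arbitrary assumption.

```python
import numpy as np

# Stand-in for a Keras batch: 4 samples, each of shape (1, 9), like Input(shape=(1, 9)).
batch = np.arange(4 * 9, dtype=np.float64).reshape(4, 1, 9)

# With no axes argument, transpose reverses ALL axes, including the batch axis.
reversed_axes = np.transpose(batch)
print(reversed_axes.shape)  # (9, 1, 4)

# Swapping only the last two axes keeps axis 0 as the batch axis.
per_sample_t = np.transpose(batch, axes=(0, 2, 1))
print(per_sample_t.shape)  # (4, 9, 1)

# matmul treats the leading axis as a batch axis: (4, 9, 1) @ (4, 1, 9) -> (4, 9, 9).
outer = per_sample_t @ batch
print(outer.shape)  # (4, 9, 9)

# Each slice is the outer product of that sample's vector with itself.
assert np.allclose(outer[2], np.outer(batch[2, 0], batch[2, 0]))
```

In other words, the fix is to transpose only the per-sample (last two) axes and leave the batch axis alone.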

Computing this product outside the model and feeding it in as an input is not suitable, because I need both the original input vector and the product inside the model for later operations (the architecture above is incomplete and only serves as an example).

How can I get the correct result? What is the correct way to multiply matrices and vectors in TF when the operation should be applied to every vector (matrix) in the batch?

Solution:

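Use tf.linalg.matmul with transpose_a=True. Unlike tf.transpose without a perm argument, which reverses all axes including the batch axis, transpose_a transposes only the last two axes of each input, and the leading axis is treated as a batch dimension: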

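import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(1,9))
# transpose_a=True transposes only the last two axes of each sample, so the
# batch axis is kept: (None, 9, 1) @ (None, 1, 9) -> (None, 9, 9)
outputs = tf.keras.layers.Dense(1, activation='linear')(tf.linalg.matmul(inputs, inputs, transpose_a=True))

model = tf.keras.Model(inputs, outputs)

# decay=0.0 dropped: it matched the default, and newer Keras optimizers raise an error for it
adam = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

model.compile(optimizer=adam, loss='mse', metrics=['mae'])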
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_5 (InputLayer)           [(None, 1, 9)]       0           []                               
                                                                                                  
 tf.linalg.matmul_3 (TFOpLambda  (None, 9, 9)        0           ['input_5[0][0]',                
 )                                                                'input_5[0][0]']                
                                                                                                  
 dense_4 (Dense)                (None, 9, 1)         10          ['tf.linalg.matmul_3[0][0]']     
                                                                                                  
==================================================================================================
Total params: 10
Trainable params: 10
Non-trainable params: 0
__________________________________________________________________________________________________
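As a quick sanity check outside the model, the sketch below runs the same operation eagerly on a random batch and compares it with two equivalent formulations: tf.transpose with an explicit perm that swaps only the last two axes, and tf.einsum. The batch size of 4 and the variable names are arbitrary choices for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical batch of 4 samples with the question's per-sample shape (1, 9).
x = tf.random.normal((4, 1, 9))

# Answer's approach: transpose_a=True transposes only the last two axes of x,
# so each sample gives (9, 1) @ (1, 9) -> (9, 9) and the batch axis is kept.
via_matmul = tf.linalg.matmul(x, x, transpose_a=True)

# Equivalent: swap only the last two axes explicitly via perm.
via_perm = tf.transpose(x, perm=[0, 2, 1]) @ x

# Equivalent: per-sample outer product written with einsum
# (b = batch axis, k = the size-1 axis, i/j = feature axes).
via_einsum = tf.einsum("bki,bkj->bij", x, x)

print(via_matmul.shape)  # (4, 9, 9)
np.testing.assert_allclose(via_matmul.numpy(), via_perm.numpy(), rtol=1e-6)
np.testing.assert_allclose(via_matmul.numpy(), via_einsum.numpy(), rtol=1e-6)
```

The einsum form makes explicit that the batch axis b passes through unchanged. Note that the model code above relies on tf.keras accepting raw TensorFlow ops on symbolic inputs (the TFOpLambda layer in the summary); with Keras 3, the default tf.keras since TensorFlow 2.16, such ops may need to be wrapped in a layer instead.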