
Training a TensorFlow model with an intermediate function call in the training loop

I am trying to train a simple neural network whose input data comes from a MATLAB Simulink simulation and whose output is then fed back into a different MATLAB Simulink simulation. My code is as follows:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random


def get_pid_values():
    # call simulink model that just produces PID values

    return random.random()


def get_plant(intermediate_val):
    # get plant output.
    return random.random()


class CustomDataGen(tf.keras.utils.Sequence):
    
    def __init__(self, df, X_col,
                 batch_size,
                 input_size=(1,),
                 shuffle=True):
        
        self.df = df.copy()
        self.X_col = X_col
        self.batch_size = batch_size
        self.input_size = input_size
        self.shuffle = shuffle
        
        self.n = len(self.df)
    
    def __get_input(self, index):
        # Need to adjust this to support retrieving ref voltage.
        return self.df[self.X_col].iloc[index]
    
    def on_epoch_end(self):
        if self.shuffle:
            self.df = self.df.sample(frac=1).reset_index(drop=True)
    
    def __getitem__(self, index):
        X = self.__get_input(index)
        return X
    
    def __len__(self):
        return self.n // self.batch_size


def get_model(input_shape, hidden, output_shape):
    inputs = keras.layers.Input(shape=input_shape)
    x = layers.Dense(hidden, activation="relu")(inputs)
    x = layers.Dense(hidden, activation='relu')(x)
    outputs = layers.Dense(output_shape)(x)
    model = keras.Model(inputs=inputs, outputs=outputs, name="pid-modifier")
    return model


loss_object = tf.keras.losses.MeanSquaredError()

def loss(y_ref, y_plant):
  y_ = y_plant
  y = y_ref
  return loss_object(y_true=y, y_pred=y_)


if __name__ == "__main__":

    # Hyperparameters
    lr = 0.01
    num_epochs = 1
    hidden_size = 4
    net_input_size = 1
    net_output_size = 1
    batch_size = 1
    reference_fpath = "Run2_rThrottleTarget.csv"

    references = pd.read_csv(reference_fpath)

    data_generator = CustomDataGen(df=references, X_col='Throttle', batch_size=1)

    # Keep results for plotting
    train_loss_results = []


    # Initialize optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    # error initial condition
    err = 0

    # instantiate model
    model = get_model(input_shape=(2,), hidden=hidden_size, output_shape=net_output_size)

    for epoch in range(num_epochs):

        for ref in data_generator:

            with tf.GradientTape() as tape:
                tape.watch(model.trainable_variables)
                
                # Get pid values
                pid = get_pid_values()
                
                # Group ref with pid voltage for input
                net_input = tf.constant([[ref, pid]])
                
                # Get the adjusted voltage from the network
                intermediate_val = model(net_input)

                # Get the plant output based on the adjusted value.
                plant = get_plant(intermediate_val)

                plant = tf.constant([plant], dtype=tf.float64)
                ref = tf.constant([ref], dtype=tf.float64)
                
                # Calculate loss 
                loss_value = loss(ref, plant)

            grads = tape.gradient(loss_value, model.trainable_weights)

            optimizer.apply_gradients(zip(grads, model.trainable_weights))

            # Record the loss so it can be plotted after training
            train_loss_results.append(float(loss_value))

            err = ref - plant


        if epoch % 50 == 0:
            print("Epoch {:03d}: Loss: {:.3f}".format(epoch, float(loss_value)))
    

    fig, ax = plt.subplots(1, figsize=(12, 8))
    fig.suptitle('Training Metrics')

    ax.set_ylabel("Loss", fontsize=14)
    ax.plot(train_loss_results)

    plt.show()

For the moment I am just mocking the Simulink calls by returning a random number. My problem arises when I take the model output, pass it to the function that mocks the Simulink call, and then calculate my loss:

# Get the adjusted voltage from the network
intermediate_val = model(net_input)

# Get the plant output based on the adjusted value.
plant = get_plant(intermediate_val)

plant = tf.constant([plant], dtype=tf.float64)
ref = tf.constant([ref], dtype=tf.float64)

# Calculate loss 
loss_value = loss(ref, plant)

I get the error ValueError: No gradients provided for any variable. I have figured out that if I pass the model's output directly to the loss function, everything works fine. My question is: how can I have the intermediate step of passing my model's output to another function and use the returned value to calculate the loss?
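In stripped-down form (a toy model and made-up numbers, just to reproduce the error), the failing pattern is:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])

with tf.GradientTape() as tape:
    out = model(tf.constant([[0.5, 0.3]]))
    # .numpy() leaves the graph: the output becomes a plain Python number
    plant = out.numpy().item() * 0.9
    loss = (1.0 - tf.constant(plant)) ** 2

print(tape.gradient(loss, model.trainable_variables))  # [None, None]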


Solution:

A gradient exists between intermediate_val and model.trainable_variables because the tape recorded every TensorFlow operation that produced it. The tape, however, cannot back-propagate through get_plant: its return value (like ref) was not computed by TensorFlow operations, so to the tape plant is just a constant with no gradient attached.
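As an aside (this is not the reinforcement-learning route discussed below): if the plant's response is reasonably smooth in the network output, you can wrap the external call with tf.custom_gradient and supply a finite-difference estimate, so the tape has something to propagate. This is only a sketch under that smoothness assumption, with a stand-in plant:

import tensorflow as tf

def get_plant(u):
    # stand-in for the external Simulink call; scalar float in, scalar float out
    return u * 0.8

@tf.custom_gradient
def plant_op(u):
    # run the black-box plant outside the graph (eager mode only)
    y = tf.constant(get_plant(u.numpy().item()), dtype=tf.float32)

    def grad(dy):
        # central finite difference as a stand-in for d(plant)/d(u)
        eps = 1e-3
        u0 = u.numpy().item()
        dplant_du = (get_plant(u0 + eps) - get_plant(u0 - eps)) / (2 * eps)
        return dy * dplant_du

    return y, grad

# inside the training loop:
# plant = plant_op(tf.reshape(intermediate_val, []))
# loss_value = (ref - plant) ** 2   # gradients now reach the model weights

This keeps the problem supervised, but it costs two extra plant evaluations per step and assumes a deterministic plant; with a noisy or expensive simulation, the reinforcement-learning route below fits better.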

Since the model knows nothing about the relationship between the loss and the output that generated it, this becomes a reinforcement-learning problem, which can be handled with the tensorflow-agents (TF-Agents) module.

There is a YouTube tutorial on this, "Everything You Need To Master Actor Critic Methods | Tensorflow 2 Tutorial". It covers a particular network architecture, but its gradient-calculation method is exactly the same as in your case, and the code is easily adaptable.
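To get the flavor of that approach without pulling in the full tf-agents stack, here is a minimal REINFORCE-style sketch. It assumes the network outputs the mean of a Gaussian over the adjusted voltage; SIGMA, the reward definition, and the plant stub are illustrative choices, not anything taken from your setup:

import tensorflow as tf

SIGMA = 0.1  # fixed exploration noise (illustrative choice)

def get_plant(u):
    # stand-in for the external Simulink call
    return u * 0.8

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="relu", input_shape=(2,)),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

def train_step(ref, pid):
    with tf.GradientTape() as tape:
        mean = model(tf.constant([[ref, pid]], dtype=tf.float32))
        # sample an action around the network's prediction
        action = tf.stop_gradient(mean + SIGMA * tf.random.normal(tf.shape(mean)))
        # Gaussian log-probability of that action (up to an additive constant)
        log_prob = -0.5 * ((action - mean) / SIGMA) ** 2
        # the reward comes from the black box; no gradient flows through it
        reward = -(ref - get_plant(action.numpy().item())) ** 2
        # REINFORCE: nudge the mean toward actions that earned higher reward
        loss = -log_prob * reward
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss.numpy().item()

The gradient now flows only through log_prob, which is exactly the trick the actor-critic tutorial relies on; the tape never needs to see inside get_plant.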
