Collect Metrics on your Model Training Runs with a Custom Callback

Tags: AI, python, data
Author: Jonah Murray
Published: January 22, 2024

Have you ever wanted to keep track of your model training metrics, for yourself or for your whole organization? With Keras’ custom callback functionality, you can collect data on every single model training run and glean valuable metrics from it. Here’s how:

The AI model training process can be laborious, and as you iterate across different model configurations, layer combinations, and hyperparameters, the one perfect combination you were looking for can get lost. If we log every training run to a database, we can easily put our finger on the run we want to use in production, or at least the one we want to return to for further iteration. We can also compare metrics across different models, architectures, datasets, teams, and more.

One thing to keep in mind: although I am showing you how to track some obvious values like number of epochs, training duration, accuracy, and loss, you can track whatever you think is valuable to you or your organization. This is simply a basic end-to-end example of tracking model training metrics, meant as inspiration or a jumping-off point.

I personally love visualizing data, so I’m going to take the opportunity to show you how to use Bokeh to visualize these metrics, but of course you can use whatever tool you like to analyze this data further!

In this article I will show you how to:

  • Write a custom callback to be used in any Keras model
  • Write key model training metrics to a database from a callback
  • Run a few different model training examples (with varying parameters) to populate the database
  • Visualize and analyze the data

Keras is one of the most widely used open source deep learning frameworks. It has a handy feature called custom callbacks, which let you run your own code at specific points during model training, evaluation, and prediction.

Example of the Custom Callback class from keras.io:

from tensorflow import keras

class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

Then pass it to the model’s fit, evaluate, and predict methods through the callbacks argument:

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])

That’s it! Now you can write whatever code you want into the callback’s methods, and it will be executed when the corresponding action runs on your model.

Depending on how complex your callback gets, you may need to pass values from your training code into the callback (for example, through its constructor). That means extra setup steps for anyone else using your callback, which matters if you’re deploying it across a larger organization.

As the documentation shows, you can hook into the start and end of training, evaluation (test), and prediction, as well as the start and end of each training epoch and of each batch. In this example we will use the start and end of training and the end of each epoch.
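To make the hook order concrete, here is a framework-free sketch. `toy_fit` and `RecordingCallback` are hypothetical stand-ins, not Keras internals; they only mimic the order in which `model.fit` calls the training hooks (ignoring validation, which additionally triggers the test hooks before each `on_epoch_end`).

```python
class RecordingCallback:
    """Hypothetical callback that records which hook ran, in order."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append(f"batch_begin:{batch}")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end:{batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


def toy_fit(callbacks, epochs, batches_per_epoch):
    """Hypothetical training loop that fires hooks in the same order as fit()."""
    for cb in callbacks:
        cb.on_train_begin(logs={})
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch, logs={})
        for batch in range(batches_per_epoch):  # batch index restarts every epoch
            for cb in callbacks:
                cb.on_train_batch_begin(batch, logs={})
            # ... a real loop would run the forward and backward passes here ...
            for cb in callbacks:
                cb.on_train_batch_end(batch, logs={})
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={})
    for cb in callbacks:
        cb.on_train_end(logs={})


recorder = RecordingCallback()
toy_fit([recorder], epochs=2, batches_per_epoch=1)
print(recorder.events)
```

Since the metrics callback only cares about the whole run and each epoch, it only needs `on_train_begin`, `on_epoch_end`, and `on_train_end`.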

Let’s do this.

Before we start on the callback, we need to set up the database. Even if you already have a database, you will still need to create tables for this new data. For this example I set up a simple local Postgres database on my Mac. To create the database itself, I did the following:

Install Postgres with Homebrew and start the server:

brew install postgresql@14
brew services start postgresql@14

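Create a user for myself and a database called MT_METRICS (Postgres folds unquoted identifiers to lowercase, so it is stored as mt_metrics, which is the name used from here on):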

(Please be careful with your database credentials when doing this in the real world)

psql postgres
CREATE ROLE jonah WITH LOGIN PASSWORD 'getting started';
CREATE DATABASE MT_METRICS;
GRANT ALL PRIVILEGES ON DATABASE MT_METRICS TO jonah;
\c mt_metrics jonah 

Now to create the tables. I created two: models, which holds some information about each model, and training, which keeps track of each training run. The training table references models through model_id as a foreign key.

CREATE TABLE models (
    model_id SERIAL PRIMARY KEY,
    model_name VARCHAR ( 50 ) UNIQUE NOT NULL,
    created_date timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
    model_type VARCHAR ( 50 ) NOT NULL,
    model_pkg VARCHAR ( 50 ) NOT NULL
);

CREATE TABLE training (
    id SERIAL PRIMARY KEY,
    model_id INT,  -- nullable, so ON DELETE SET NULL can keep orphaned runs
    training_date timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
    training_duration INT NOT NULL,
    epochs INT NOT NULL,
    learning_rate DECIMAL NOT NULL,
    accuracy DECIMAL NOT NULL,
    loss DECIMAL NOT NULL,
    accuracy_list DECIMAL[],
    loss_list DECIMAL[],
    CONSTRAINT fk_model
        FOREIGN KEY(model_id)
            REFERENCES models(model_id)
            ON DELETE SET NULL
);

As you can see, the models table stores the model name, a type that says a bit more about what the model does, and the package used to create the model (ours will be Keras every time, but this could be useful later if you train models with other packages). The training table stores training duration (in seconds), epochs, learning rate, final accuracy, and final loss. We also collect accuracy and loss for every epoch in accuracy_list and loss_list.
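One detail worth checking with ON DELETE SET NULL is that the referencing column has to be nullable: if model_id were declared NOT NULL, deleting a model would fail, because the database would try to set model_id to NULL in that model’s training rows. The sketch below demonstrates this with Python’s built-in sqlite3 module as a stand-in for Postgres (SQLite only enforces foreign keys after `PRAGMA foreign_keys = ON`); the trimmed-down tables and the `make_db` and `try_delete_model` helpers are hypothetical.

```python
import sqlite3


def make_db(model_id_column):
    """Create trimmed-down models/training tables; only model_id's definition varies."""
    conn = sqlite3.connect(":memory:")
    conn.execute("PRAGMA foreign_keys = ON")  # SQLite enforces FKs only when enabled
    conn.execute(
        "CREATE TABLE models (model_id INTEGER PRIMARY KEY, model_name TEXT UNIQUE NOT NULL)"
    )
    # Only hard-coded, trusted column definitions are interpolated here.
    conn.execute(
        f"""CREATE TABLE training (
                id INTEGER PRIMARY KEY,
                {model_id_column},
                FOREIGN KEY (model_id) REFERENCES models(model_id) ON DELETE SET NULL
            )"""
    )
    conn.execute("INSERT INTO models (model_id, model_name) VALUES (1, 'pima2')")
    conn.execute("INSERT INTO training (model_id) VALUES (1)")
    return conn


def try_delete_model(conn):
    """Delete model 1 and report whether the database allowed it."""
    try:
        conn.execute("DELETE FROM models WHERE model_id = 1")
        return "deleted"
    except sqlite3.IntegrityError as exc:
        return f"rejected: {exc}"


strict = make_db("model_id INTEGER NOT NULL")
print(try_delete_model(strict))  # SET NULL collides with NOT NULL

nullable = make_db("model_id INTEGER")
print(try_delete_model(nullable))
print(nullable.execute("SELECT model_id FROM training").fetchall())  # orphaned run kept
```

If you would rather remove a model’s training runs along with the model, ON DELETE CASCADE is the other common choice.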

Now that we have our database, let’s fill out our callback for our metrics purposes.

We will use environment variables to supply the model name and type. As long as you, or anyone in your organization, sets these variables before training, every run of the same model is recorded under the same entry in the models table, no matter how many times it is trained.

I recommend using Jupyter notebooks for this work as it makes it super easy to declare objects and see the output immediately.

Example of setting environment variables from Python:

import os

os.environ['MODEL_NAME'] = "Pima2"
os.environ['MODEL_TYPE'] = "Binary Classification"
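If either variable is missing, `os.environ.get` returns None, and the NOT NULL columns in the models table reject the insert, but only after training has already finished. A minimal sketch of failing fast instead, assuming the same MODEL_NAME and MODEL_TYPE variables; `require_model_identity` is a hypothetical helper, not part of the original code.

```python
import os


def require_model_identity(env=os.environ):
    """Return (model_name, model_type), raising early if either is unset or empty."""
    missing = [key for key in ("MODEL_NAME", "MODEL_TYPE") if not env.get(key)]
    if missing:
        raise RuntimeError(
            f"Set these environment variables before training: {', '.join(missing)}"
        )
    return env["MODEL_NAME"], env["MODEL_TYPE"]


os.environ["MODEL_NAME"] = "Pima2"
os.environ["MODEL_TYPE"] = "Binary Classification"
print(require_model_identity())  # ('Pima2', 'Binary Classification')
```

Calling it at the top of a training script or notebook cell surfaces the mistake before any time is spent training.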

As I said earlier, this callback logs metrics at the start and end of training and at the end of each epoch, so we need to fill out the on_train_begin, on_train_end, and on_epoch_end methods. We also need to initialize and update a few other values to capture the training duration and the accuracy and loss per epoch. Here is the metrics callback (stored in MetricsCallbackKeras.py):

from tensorflow import keras
from datetime import datetime
import os
from db_util import write_training_metrics

class MetricsCallback(keras.callbacks.Callback):

    def __init__(self, cursor):
        super().__init__()
        self.cursor = cursor
        self.epochs = 0
        self.start = None
        self.end = None
        self.accuracy_list = []
        self.loss_list = []
        self.duration = 0

    def on_train_begin(self, logs=None):
        # Reset per-run state in case the same callback instance is reused
        self.epochs = 0
        self.accuracy_list = []
        self.loss_list = []
        self.start = datetime.now()

    def on_train_end(self, logs=None):
        self.end = datetime.now()
        self.duration = (self.end - self.start).total_seconds()
        self.write_metrics()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs += 1
        # Cast to plain Python floats so psycopg2 can adapt them
        self.accuracy_list.append(float(logs['accuracy']))
        self.loss_list.append(float(logs['loss']))

    def write_metrics(self):
        model_name = os.environ.get("MODEL_NAME")
        model_type = os.environ.get("MODEL_TYPE")
        # optimizer.lr is deprecated (and gone in Keras 3); learning_rate works in both
        learning_rate = float(self.model.optimizer.learning_rate.numpy())

        write_training_metrics(
            self.cursor,
            model_name,
            model_type,
            'keras',
            self.duration,
            self.epochs,             # epochs actually completed
            learning_rate,
            self.accuracy_list[-1],  # final epoch's accuracy
            self.loss_list[-1],      # final epoch's loss
            self.accuracy_list,
            self.loss_list
        )
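The bookkeeping inside the callback does not depend on TensorFlow, so it can be checked in isolation. The sketch below is a hypothetical, framework-free `RunRecorder` with the same hook names. It swaps `datetime.now()` for `time.monotonic()`, a clock that cannot go backwards and is not affected by system clock updates, and it resets its lists in `on_train_begin` so that reusing one instance across several training runs does not mix them together. The fake `logs` dictionaries stand in for what Keras passes to `on_epoch_end`.

```python
import time


class RunRecorder:
    """Hypothetical stand-in for the callback's bookkeeping, without Keras."""

    def on_train_begin(self, logs=None):
        # Reset per-run state so one instance can be reused across runs.
        self.epochs = 0
        self.accuracy_list = []
        self.loss_list = []
        self.duration = 0.0
        self._start = time.monotonic()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs += 1
        # float() turns NumPy scalars into plain Python floats for the DB driver.
        self.accuracy_list.append(float(logs["accuracy"]))
        self.loss_list.append(float(logs["loss"]))

    def on_train_end(self, logs=None):
        self.duration = time.monotonic() - self._start


rec = RunRecorder()
for run in range(2):  # two "training runs" with made-up metrics
    rec.on_train_begin()
    for epoch, (acc, loss) in enumerate([(0.5, 0.9), (0.6, 0.7)]):
        rec.on_epoch_end(epoch, logs={"accuracy": acc, "loss": loss})
    rec.on_train_end()

print(rec.epochs, rec.accuracy_list, rec.loss_list)  # 2 [0.5, 0.6] [0.9, 0.7]
```

Because the second run starts from a clean slate, the recorder reports two epochs rather than four.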

To keep the callback clean, I put the function that writes to the database in a separate file. It first inserts the model into the models table if it doesn’t already exist, then writes the training run to the training table (stored in db_util.py):

def write_training_metrics(
        cursor,
        model_name,
        model_type,
        model_pkg,
        duration,
        epochs,
        lr,
        accuracy,
        loss,
        accuracy_list,
        loss_list):
    # Insert the model if it is new; the no-op UPDATE on conflict lets
    # RETURNING hand back the existing model_id as well
    cursor.execute("""
        INSERT INTO models (
        model_name,
        model_type,
        model_pkg
        ) VALUES (%s, %s, %s)
        ON CONFLICT (model_name) DO UPDATE SET model_name=EXCLUDED.model_name
        RETURNING model_id
        """, (model_name, model_type, model_pkg))

    model_id = cursor.fetchone()[0]

    # Pass every value as a query parameter (not an f-string) so psycopg2
    # handles quoting and type adaptation; Python lists become Postgres arrays
    cursor.execute("""
        INSERT INTO training (
        model_id,
        training_duration,
        epochs,
        learning_rate,
        accuracy,
        loss,
        accuracy_list,
        loss_list
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        """,
        (model_id, duration, epochs, lr, accuracy, loss, accuracy_list, loss_list)
        )
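If you want to try the write logic without a Postgres server, here is a sketch against Python’s built-in sqlite3 module, a swapped-in stand-in rather than the setup used in this article. SQLite has no array type, so the per-epoch lists are stored as JSON text (an assumption of this sketch), and instead of ON CONFLICT ... RETURNING it uses INSERT OR IGNORE followed by a SELECT to find the model_id. `write_training_metrics_sqlite` is a hypothetical name.

```python
import json
import sqlite3

SCHEMA = """
CREATE TABLE models (
    model_id INTEGER PRIMARY KEY,
    model_name TEXT UNIQUE NOT NULL,
    model_type TEXT NOT NULL,
    model_pkg TEXT NOT NULL
);
CREATE TABLE training (
    id INTEGER PRIMARY KEY,
    model_id INTEGER REFERENCES models(model_id),
    training_duration REAL NOT NULL,
    epochs INTEGER NOT NULL,
    learning_rate REAL NOT NULL,
    accuracy REAL NOT NULL,
    loss REAL NOT NULL,
    accuracy_list TEXT,  -- JSON-encoded list (SQLite has no arrays)
    loss_list TEXT
);
"""


def write_training_metrics_sqlite(conn, model_name, model_type, model_pkg, duration,
                                  epochs, lr, accuracy, loss, accuracy_list, loss_list):
    """Hypothetical SQLite port of write_training_metrics."""
    # Create the model row only if this model_name has not been seen before.
    conn.execute(
        "INSERT OR IGNORE INTO models (model_name, model_type, model_pkg) VALUES (?, ?, ?)",
        (model_name, model_type, model_pkg),
    )
    (model_id,) = conn.execute(
        "SELECT model_id FROM models WHERE model_name = ?", (model_name,)
    ).fetchone()
    conn.execute(
        """INSERT INTO training (model_id, training_duration, epochs, learning_rate,
                                 accuracy, loss, accuracy_list, loss_list)
           VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
        (model_id, duration, epochs, lr, accuracy, loss,
         json.dumps(accuracy_list), json.dumps(loss_list)),
    )
    conn.commit()
    return model_id


conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)
for _ in range(2):  # two runs of the same model, with made-up metrics
    write_training_metrics_sqlite(conn, "pima2", "binary classification", "keras",
                                  1.5, 2, 0.001, 0.7, 0.6, [0.6, 0.7], [0.8, 0.6])
print(conn.execute("SELECT COUNT(*) FROM models").fetchone())    # (1,)
print(conn.execute("SELECT COUNT(*) FROM training").fetchone())  # (2,)
```

As in the Postgres version, repeated runs of the same model share one models row while each run gets its own training row.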

If you looked through the code, you saw that we’re passing a cursor object to do database operations. For a Postgres connection I used psycopg2. Example of creating a psycopg2 cursor:

import psycopg2

conn = psycopg2.connect(database="mt_metrics",
                        host="localhost",
                        user="jonah",
                        password="getting started",
                        port=5432)
conn.autocommit = True

cursor = conn.cursor()
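Since real credentials should not live in a notebook, one option is to read the connection settings from environment variables. A minimal sketch: PGHOST, PGPORT, PGDATABASE, PGUSER, and PGPASSWORD are the standard libpq variable names (libpq, which psycopg2 is built on, also falls back to them on its own when a parameter is omitted), while `db_settings_from_env` and its defaults are hypothetical.

```python
import os


def db_settings_from_env(env=os.environ):
    """Build psycopg2.connect() keyword arguments from libpq-style variables."""
    return {
        "dbname": env.get("PGDATABASE", "mt_metrics"),
        "host": env.get("PGHOST", "localhost"),
        "port": int(env.get("PGPORT", "5432")),
        "user": env.get("PGUSER", "jonah"),
        # No default: the password should come from the environment (or ~/.pgpass).
        "password": env.get("PGPASSWORD"),
    }


settings = db_settings_from_env({"PGPASSWORD": "example-only"})
print(settings["dbname"], settings["port"])  # mt_metrics 5432
```

With psycopg2 installed, `psycopg2.connect(**db_settings_from_env())` could then replace the hard-coded call above.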

Alright. Now that we have our database running, our environment variables set, our callback written, and our cursor created, we can run a model!

This is a common deep learning example: a small network trained with binary cross-entropy loss to predict diabetes from the Pima Indians Diabetes dataset. Download the dataset from this address and save the file as pima-indians-diabetes.csv in a directory called data.

from numpy import loadtxt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from pathlib import Path
from MetricsCallbackKeras import MetricsCallback

pima_indians_csv = Path('data', 'pima-indians-diabetes.csv')
# load the dataset
dataset = loadtxt(pima_indians_csv, delimiter=',')
# split into input (X) and output (y) variables
X = dataset[:,0:8]
y = dataset[:,8]

# define the keras model
model = Sequential()
model.add(Dense(12, input_shape=(8,), activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# compile the keras model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# fit the keras model on the dataset with the added callback
model.fit(X, y, epochs=15, batch_size=10, callbacks=[MetricsCallback(cursor)])

If everything went well, you should see some output with the epochs and the associated loss and accuracy. We should have some data in our database now! Let’s check it out.

If you’re using a notebook, you can reuse the connection object (conn) from above.

Read the data into a dataframe from a query:

import pandas as pd

query = """
SELECT t.model_id, m.model_name, m.model_type, t.training_date,
       t.training_duration, t.epochs, t.learning_rate, t.accuracy, t.loss,
       t.accuracy_list, t.loss_list
FROM training t
INNER JOIN models m ON t.model_id = m.model_id
"""
df = pd.read_sql_query(query, con=conn)
df

Output (the single row at index 0, shown one column per line):

model_id             1
model_name           pima2
model_type           binary classification
training_date        2024-01-16 16:05:02.561282
training_duration    1
epochs               15
learning_rate        0.001
accuracy             0.708333
loss                 0.631138
accuracy_list        [0.3919271 0.5963542 0.5703125 0.5651042 0.61848956 0.62630206 0.6458333 0.640625 0.6666667 0.6796875 0.68098956 0.67057294 0.67057294 0.6510417 0.7083333]
loss_list            [21.039516 2.0474312 1.2937006 0.95252365 0.80447465 0.7372218 0.68905145 0.7194483 0.6707931 0.62520367 0.66718054 0.6340042 0.6297135 0.6884937 0.63113797]

You should see your single training run as well as the model name and type you provided with your environment variables.

Awesome, so now we have a way to write model training metrics into a database! From here you can do whatever you want with it. Perhaps you want to dig into some very specific metrics and collect a buttload of data while you train a particular model. Or maybe you want to deploy something like this across an organization that trains models, so that you can collect metrics from different teams and projects and get a broad view of model training. For the rest of this article I am going to populate the database with some diverse model training data and then visualize it with Bokeh.

To populate the database, I grabbed a bunch of model training examples from the examples section of the Keras website. For each model, I wrote a brief blurb with who created the model, a description, and a link to the page with the model (I DID NOT CREATE ANY OF THESE MODELS), set the environment variables, and then built and trained the model with my metrics callback. I also wrapped each model in a loop that trains it several times, with a different number of epochs on each run. This populates the database with a few training runs per model, so we can compare the results. Check out my pop_db.ipynb notebook on GitHub for the full code, as it’s quite long.
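The loop that produces several runs per model can be sketched generically. `run_epoch_sweep` and its `train_fn` argument are hypothetical: `train_fn` stands in for whatever builds, compiles, and fits a model with the metrics callback attached, taking the epoch count as its only argument. Rebuilding the model inside `train_fn` matters, because calling `fit` again on an already trained model continues from its current weights.

```python
import os


def run_epoch_sweep(model_name, model_type, epoch_values, train_fn):
    """Train the same model once per epoch count, all under one MODEL_NAME."""
    os.environ["MODEL_NAME"] = model_name
    os.environ["MODEL_TYPE"] = model_type
    results = []
    for epochs in epoch_values:
        # In a real notebook, train_fn would rebuild the model and call
        # model.fit(..., epochs=epochs, callbacks=[MetricsCallback(cursor)]).
        results.append(train_fn(epochs))
    return results


# Stand-in train_fn that just reports what it was asked to do.
calls = run_epoch_sweep("Pima2", "Binary Classification", [5, 10, 15],
                        lambda epochs: (os.environ["MODEL_NAME"], epochs))
print(calls)  # [('Pima2', 5), ('Pima2', 10), ('Pima2', 15)]
```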

If we run the above query again and then look at the table, we should see the data. Or we can simply look at the shape to see how many rows we have.

df.shape

Output:

(16, 9)
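With several runs per model in the table, one way to put your finger on the best run for each model is a correlated subquery that keeps, for each model, the row with the highest accuracy (ties return more than one row). The sketch runs it against an in-memory sqlite3 database with a few toy rows (made-up values, not results from this article); the SELECT itself is plain SQL and should also work against the Postgres training and models tables.

```python
import sqlite3

BEST_RUN_SQL = """
SELECT m.model_name, t.epochs, t.accuracy
FROM training t
JOIN models m ON t.model_id = m.model_id
WHERE t.accuracy = (
    SELECT MAX(t2.accuracy) FROM training t2 WHERE t2.model_id = t.model_id
)
ORDER BY m.model_name
"""

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE models (model_id INTEGER PRIMARY KEY, model_name TEXT);
CREATE TABLE training (id INTEGER PRIMARY KEY, model_id INTEGER,
                       epochs INTEGER, accuracy REAL);
INSERT INTO models VALUES (1, 'model_a'), (2, 'model_b');
INSERT INTO training (model_id, epochs, accuracy) VALUES
    (1, 5, 0.62), (1, 15, 0.71), (2, 3, 0.80), (2, 6, 0.78);
""")
print(conn.execute(BEST_RUN_SQL).fetchall())
# [('model_a', 15, 0.71), ('model_b', 3, 0.8)]
```

Against the real database, `pd.read_sql_query(BEST_RUN_SQL, con=conn)` would return the same shape of result as a DataFrame.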

Now we have a database that’s populated with a range of data. We’re going to create an interactive Bokeh plot that lets us compare training runs across multiple axes, as well as zoom in on a single training run to see how accuracy and loss change across epochs. Bokeh is a versatile open source tool that lets you use Python to create amazing interactive visualizations.

Before we do that, let’s take a look at those lists we stored for accuracy and loss. We can graph them to see how the loss and accuracy changed throughout the training run.

First we need to convert the lists into lists of floats that Pandas will like:

df['accuracy_list'] = df['accuracy_list'].apply(pd.to_numeric, downcast='float')
df['loss_list'] = df['loss_list'].apply(pd.to_numeric, downcast='float')
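The conversion is needed because psycopg2 returns Postgres DECIMAL values as Python `decimal.Decimal` objects, and a DECIMAL[] column arrives as a Python list of them. If you would rather avoid pandas for this step, a list comprehension does the same job; `decimals_to_floats` is a hypothetical helper, and the sample values are the first three accuracies from the output shown earlier.

```python
from decimal import Decimal


def decimals_to_floats(values):
    """Convert a list of Decimal (as psycopg2 returns DECIMAL[]) to floats."""
    # A NULL array comes back from psycopg2 as None.
    return [float(v) for v in values] if values is not None else []


# What one accuracy_list cell might look like straight from psycopg2.
raw = [Decimal("0.3919271"), Decimal("0.5963542"), Decimal("0.5703125")]
print(decimals_to_floats(raw))  # [0.3919271, 0.5963542, 0.5703125]
```

In pandas, `df['accuracy_list'].apply(decimals_to_floats)` would be an equivalent alternative to the `pd.to_numeric` call.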

Let’s visualize just one of the training runs. I’m choosing index 2 just for this example.

import matplotlib.pyplot as plt

idx = 2
run = df.iloc[idx]  # one training run (row) from the DataFrame

y1 = run['accuracy_list']
y2 = run['loss_list']
x = list(range(1, len(y1) + 1))  # epoch numbers, one per stored value

plt.plot(x, y1, label='accuracy')
plt.plot(x, y2, label='loss')

plt.xlabel('Epochs')
plt.ylabel('Accuracy / Loss')
plt.legend()
plt.title('Accuracy & Loss Per Epoch')
plt.show()

Our Bokeh visualization shows these per-epoch metrics when you select a training run on the main plot.

For the full Bokeh plot, take a look at my code on GitHub. I highly recommend Bokeh for interactive data visualizations.

Check out an example deployed live here.

In conclusion, this is a fairly specific example of something to build on the MLOps side of things, but the exercise can teach you a lot about machine learning and the work that goes into it. Feel free to reach out to me if you have questions, comments, or gripes with how I did any of this! Thanks for reading!