Keras Callbacks

Keras callbacks are incredibly useful when training neural networks, especially when I'm running large models on Azure (these days) that I know will cost me time, computation and $.

I pretty much used to painfully run models through every single epoch before I discovered this gem. The official documentation describes callbacks best, but in short a callback is a hook into the training loop. The classic use is letting your model know when to stop training, e.g. once the monitored accuracy stops improving by more than a threshold you set. The callbacks I usually end up using are ModelCheckpoint and EarlyStopping.

Model checkpointing is used to save weights at a defined interval, so that the model can later be loaded in the state it was saved in. A great use case is saving only at stages determined by a value you monitor, such as validation accuracy (val_acc, or val_accuracy in newer Keras versions), loss (computed on the training set) or val_loss (computed on the validation set).

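The advantage of ModelCheckpoint over calling save_weights() or save() yourself is that it saves automatically while training runs (every epoch, or only when the monitored value improves), and it can save either the whole model or just the weights (via save_weights_only).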
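To make this concrete, here is a minimal sketch of ModelCheckpoint keeping only the best weights according to val_loss. It assumes the standalone `keras` package is installed (it ships with TensorFlow 2.x, or as Keras 3). The toy regression data, the tiny `build_model()` helper and the temporary file path are all hypothetical stand-ins for your own setup.

```python
import os
import tempfile

import numpy as np
import keras

# Toy regression data standing in for a real train/validation split.
rng = np.random.default_rng(0)
x_train = rng.random((256, 4)).astype("float32")
y_train = x_train.sum(axis=1, keepdims=True)
x_val = rng.random((64, 4)).astype("float32")
y_val = x_val.sum(axis=1, keepdims=True)


def build_model():
    # Hypothetical tiny linear model; replace with your own architecture.
    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    return model


model = build_model()

# Hypothetical location; Keras 3 requires the ".weights.h5" suffix
# when save_weights_only=True.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")

checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor="val_loss",      # the value that decides what "best" means
    mode="min",              # lower val_loss is better
    save_best_only=True,     # only overwrite the file when val_loss improves
    save_weights_only=True,  # False would save the whole model instead
)

history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=5,
    callbacks=[checkpoint],
    verbose=0,
)

# Later (e.g. in a new session): rebuild the architecture and load the best weights.
restored = build_model()
restored.load_weights(checkpoint_path)
```

Because only the weights are saved here, you need the same architecture available to load them into. That is the trade-off against saving the whole model.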

Detailed parameters here in the source code:

When we call fit(), evaluate() or predict() with a list of callbacks, Keras calls the following methods on each of them:

  • on_train_begin and on_train_end, at the beginning and end of training (fit()).
  • on_test_begin and on_test_end, at the beginning and end of evaluation (e.g. evaluate()).
  • on_predict_begin and on_predict_end, at the beginning and end of prediction (predict()).
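One way to see these hooks fire is a small custom callback that just records them. The sketch below is illustrative only. `HookRecorder` is a hypothetical name, the random data is a placeholder, and the code assumes the `keras` package (TensorFlow 2.x or Keras 3) is installed.

```python
import numpy as np
import keras


class HookRecorder(keras.callbacks.Callback):
    """Hypothetical callback that records which begin/end hooks Keras calls."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("on_train_begin")

    def on_train_end(self, logs=None):
        self.events.append("on_train_end")

    def on_test_begin(self, logs=None):
        self.events.append("on_test_begin")

    def on_test_end(self, logs=None):
        self.events.append("on_test_end")

    def on_predict_begin(self, logs=None):
        self.events.append("on_predict_begin")

    def on_predict_end(self, logs=None):
        self.events.append("on_predict_end")


x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = HookRecorder()
model.fit(x, y, epochs=1, verbose=0, callbacks=[recorder])  # train hooks
model.evaluate(x, y, verbose=0, callbacks=[recorder])       # test hooks
model.predict(x, verbose=0, callbacks=[recorder])           # predict hooks
print(recorder.events)
```

Since fit() is called here without validation data, recorder.events should list the train, test and predict pairs in that order.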

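In addition, Keras calls on_epoch_begin, on_epoch_end, on_batch_begin and on_batch_end during training (newer versions also expose the batch hooks as on_train_batch_begin and on_train_batch_end). In older Keras versions, the built-in BaseLogger callback used these hooks to accumulate the epoch averages of metrics. This gives us flexibility: EarlyStopping, for example, runs at the end of every epoch and compares the current value of the monitored metric with the best value seen so far.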
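The sketch below shows that end-of-epoch comparison in a custom callback. It is a simplified take on the check EarlyStopping performs, not its actual source. `BestValueTracker` is a hypothetical name, the data is a toy placeholder, and the `keras` package is assumed to be installed.

```python
import numpy as np
import keras


class BestValueTracker(keras.callbacks.Callback):
    """Hypothetical callback: at the end of every epoch, compare the
    monitored value with the best value seen so far."""

    def __init__(self, monitor="val_loss"):
        super().__init__()
        self.monitor = monitor
        self.best = np.inf      # lower is better for a loss
        self.best_epoch = None
        self.improved = []      # one True/False entry per epoch

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return  # e.g. no validation data was passed to fit()
        is_better = current < self.best
        self.improved.append(bool(is_better))
        if is_better:
            self.best = current
            self.best_epoch = epoch


rng = np.random.default_rng(0)
x = rng.random((128, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

tracker = BestValueTracker(monitor="val_loss")
history = model.fit(x, y, validation_split=0.25, epochs=5, verbose=0, callbacks=[tracker])
print(tracker.best_epoch, tracker.best)
```

EarlyStopping adds a patience counter and a min_delta margin on top of this comparison.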

Detailed parameters here in the source code:

You can use the EarlyStopping callback to stop training when val_acc stops increasing; otherwise the model may start to overfit the data. The same problem shows up as the loss continuing to decrease while val_loss increases or stays stagnant. What usually works is to start with a simple setup and plot the training and validation loss with and without early stopping.
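A simple starting point might look like the sketch below. It assumes the `keras` package and uses toy regression data. The values for patience, min_delta and epochs are illustrative guesses, not tuned recommendations. For a classifier compiled with metrics=["accuracy"], you would monitor "val_accuracy" with mode="max" instead.

```python
import numpy as np
import keras

# Toy regression data standing in for a real train/validation split.
rng = np.random.default_rng(0)
x_train = rng.random((256, 4)).astype("float32")
y_train = x_train.sum(axis=1, keepdims=True)
x_val = rng.random((64, 4)).astype("float32")
y_val = x_val.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=1e-4,             # smaller changes don't count as improvement
    patience=5,                 # epochs to wait without improvement before stopping
    restore_best_weights=True,  # roll back to the best epoch's weights at the end
)

history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=100,
    callbacks=[early_stopping],
    verbose=0,
)

# Number of epochs that actually ran (fewer than 100 if training stopped early).
epochs_run = len(history.history["loss"])
print(epochs_run, early_stopping.stopped_epoch)
```

The history.history["loss"] and history.history["val_loss"] lists are what you would plot to compare runs with and without the callback.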

A combination of both approaches:
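Here is a sketch of one way to combine the two callbacks. EarlyStopping ends the run, and ModelCheckpoint saves the whole model whenever val_loss improves. It assumes a recent Keras with the native ".keras" format. The data and file path are hypothetical placeholders.

```python
import os
import tempfile

import numpy as np
import keras

rng = np.random.default_rng(0)
x_train = rng.random((256, 4)).astype("float32")
y_train = x_train.sum(axis=1, keepdims=True)
x_val = rng.random((64, 4)).astype("float32")
y_val = x_val.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Hypothetical path; the ".keras" suffix selects the native full-model format.
best_model_path = os.path.join(tempfile.mkdtemp(), "best_model.keras")

callbacks = [
    # Stop once val_loss hasn't improved for 5 epochs and keep the best weights.
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    # Save the whole model (architecture, weights, optimizer state) whenever val_loss improves.
    keras.callbacks.ModelCheckpoint(filepath=best_model_path, monitor="val_loss", save_best_only=True),
]

history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=50,
    callbacks=callbacks,
    verbose=0,
)

# The best model can be reloaded later without rebuilding the architecture.
best_model = keras.models.load_model(best_model_path)
```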

Plenty more callbacks in the documentation:
