Keras Callbacks Tutorial for Training Your Neural Networks Efficiently
Last Updated on July 18, 2023 by Editorial Team
Author(s): Muttineni Sai Rohith
Originally published on Towards AI.
Training a neural network can take many hours or even a few days to complete, so we need some way to monitor and control the model. If training crashes after hours or days, all of that training time is wasted. And once we have picked a number of epochs and training has started, we may need to halt it early: to avoid overfitting, or because the loss has reached its minimum and is climbing again. So we need functions that monitor and control the model while it trains, and that is what this article is about: callbacks.
By definition, a callback is an object in Keras that can perform actions at various stages of training.
A callback can run before or after an epoch, before or after a single batch, and so on. We can use callbacks to stop training early, save the model to disk periodically, inspect the model's internal state and statistics during training, write logs after every batch, and more.
Callbacks Usage:
1. Defining the Callbacks
# EarlyStopping
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0.001, patience=3, verbose=1,
    mode='min'
)

# ModelCheckpoint
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss',
    mode='min', save_best_only=True, verbose=1
)
Don't worry if these arguments don't make sense yet; we will go through all of them in detail below.
2. Passing the Callbacks to model.fit()
model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test),
          callbacks=[early_stop, checkpoint])
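To make the two-step recipe concrete, here is the same idea as a runnable end-to-end sketch; the toy data, the tiny dense model, and the 'best_model.h5' filepath are all made up for illustration and are not part of the original example.

import numpy as np
import tensorflow as tf

# Toy binary-classification data, purely for illustration
X_train = np.random.rand(1000, 20)
y_train = np.random.randint(0, 2, size=(1000,))
X_test = np.random.rand(200, 20)
y_test = np.random.randint(0, 2, size=(200,))

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(20,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0.001, patience=3, verbose=1, mode='min')
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'best_model.h5', monitor='val_loss', mode='min', save_best_only=True, verbose=1)

model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test),
          callbacks=[early_stop, checkpoint])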
Let's go through some of the most used callbacks.
ModelCheckpoint:
We use this callback to save the model periodically so that we don't lose our training time if training crashes unexpectedly. We can also load the best intermediate weights later and continue training from the saved state.
Syntax:
tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    save_freq="epoch",
    options=None,
    initial_value_threshold=None,
    **kwargs
)
filepath: location to save the model.
monitor: metric to be monitored, e.g. 'val_loss', 'val_accuracy', 'loss', 'accuracy'.
verbose: if 1, displays a message when the callback takes an action; if 0, stays silent.
save_best_only: if True, saves the model only when it is the "best" so far, judged by the monitored metric.
save_weights_only: if True, only the weights are saved, not the full model.
mode: one of 'auto', 'min', 'max'. For accuracy it should be 'max' and for loss 'min'; with 'auto', the mode is inferred from the name of the monitored metric.
save_freq: if 'epoch', saves after every epoch; if an integer n, saves after every n batches.
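Putting a few of these arguments together: the snippet below is a sketch (the filepath pattern and the reloaded filename are hypothetical) that saves the full model whenever val_loss improves. Keras fills the {epoch} and {val_loss} placeholders in the filepath for you.

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='model-{epoch:02d}-{val_loss:.4f}.h5',  # formatted per epoch
    monitor='val_loss',
    mode='min',
    save_best_only=True,       # only keep checkpoints that improve val_loss
    save_weights_only=False,   # save the full model, not just the weights
    verbose=1,
)
model.fit(X_train, y_train, epochs=10,
          validation_data=(X_test, y_test), callbacks=[checkpoint])

# Later, reload a saved checkpoint (hypothetical filename) to resume or evaluate
best_model = tf.keras.models.load_model('model-05-0.3211.h5')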
TerminateOnNaN:
This callback terminates the training when a NaN loss occurs.
tf.keras.callbacks.TerminateOnNaN()
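There is nothing to configure; just add it to the callbacks list (model and data as defined earlier):

model.fit(X_train, y_train, epochs=10,
          validation_data=(X_test, y_test),
          callbacks=[tf.keras.callbacks.TerminateOnNaN()])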
EarlyStopping:
EarlyStopping is a callback used while training neural networks; it lets us specify a large number of training epochs and stop automatically once the model's performance stops improving on the validation dataset.
tf.keras.callbacks.EarlyStopping(
monitor="val_loss",
min_delta=0,
patience=0,
verbose=0,
mode="auto",
baseline=None,
restore_best_weights=False,
)
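A typical configuration (the values here are illustrative, not prescriptive) allows up to 100 epochs but stops after five epochs without a val_loss improvement of at least 0.001, rolling the model back to its best weights:

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,             # smaller changes don't count as improvement
    patience=5,                  # stop after 5 epochs without improvement
    restore_best_weights=True,   # roll back to the weights from the best epoch
    verbose=1,
)
model.fit(X_train, y_train, epochs=100,
          validation_data=(X_test, y_test), callbacks=[early_stop])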
For a detailed explanation of EarlyStopping, refer to my article: Keras EarlyStopping Callback to train the Neural Networks Perfectly (muttinenisairohith.medium.com).
ReduceLROnPlateau:
This callback reduces the learning rate (lr) when the monitored metric stops improving; models often benefit from a lower learning rate once learning stagnates. The callback monitors the specified metric, and if there is no improvement for 'patience' epochs, the learning rate is reduced.
Syntax:
tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss",
factor=0.1,
patience=10,
verbose=0,
mode="auto",
min_delta=0.0001,
cooldown=0,
min_lr=0,
**kwargs
)
monitor: metric to be monitored, e.g. 'val_loss', 'val_accuracy', 'loss', 'accuracy'.
factor: factor by which the learning rate is reduced; new_lr = lr * factor.
patience: number of epochs with no improvement after which the learning rate is reduced.
verbose: if 1, displays a message when the callback takes an action; if 0, stays silent.
mode: one of 'auto', 'min', 'max'. For accuracy it should be 'max' and for loss 'min'; with 'auto', the mode is inferred from the name of the monitored metric.
min_delta: threshold for measuring a new optimum; used to focus only on significant changes.
cooldown: number of epochs to wait before resuming normal operation after the learning rate has been reduced.
min_lr: lower bound on the learning rate.
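For example, the configuration below (illustrative values) halves the learning rate after three stagnant epochs and never drops below 1e-6:

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,      # new_lr = lr * 0.5
    patience=3,      # wait 3 epochs without improvement before reducing
    min_lr=1e-6,     # lower bound on the learning rate
    verbose=1,
)
model.fit(X_train, y_train, epochs=50,
          validation_data=(X_test, y_test), callbacks=[reduce_lr])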
LearningRateScheduler:
A simple callback used to adjust the learning rate over time. We write a function that changes the learning rate based on the epoch (or on some other condition) and pass it as an argument to this callback.
Syntax with an example:
def scheduler(epoch, lr):
    # Decay the learning rate every 10 epochs; otherwise keep it unchanged
    if epoch % 10 == 0:
        return lr * tf.math.exp(-0.1)
    return lr

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
This decays the learning rate every 10 epochs and leaves it unchanged in between.
TensorBoard:
TensorBoard is a visualization tool that ships with TensorFlow. This callback lets us visualize information about the training process: metrics, training graphs, activation histograms, and distributions of gradients. To use TensorBoard, we first need to set a log_dir where the TensorBoard files will be saved.
log_dir="logs"
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=log_dir, histogram_freq=1, write_graph=True)
log_dir: directory to which the log files are saved.
histogram_freq: frequency (in epochs) at which histograms and gradient maps are computed.
write_graph: whether to visualize the model graph in TensorBoard.
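Pass the callback to model.fit() like any other, then launch TensorBoard from a terminal and open the printed URL in your browser:

model.fit(X_train, y_train, epochs=10,
          validation_data=(X_test, y_test),
          callbacks=[tensorboard_callback])

# Then, from a terminal:
#   tensorboard --logdir logs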
We have discussed a few callbacks; there are others, such as the following (two of them are sketched below):
- BackupAndRestore: backs up and restores the training state.
- RemoteMonitor: streams events to a server.
- CSVLogger: streams epoch results to a CSV file.
- LambdaCallback: creates simple, custom callbacks on the fly.
- ProgbarLogger: prints metrics to stdout.
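As a quick taste of two of these, here is a sketch (the CSV filename and the printed message are made up) that logs every epoch to a file with CSVLogger and prints the validation loss with a LambdaCallback:

csv_logger = tf.keras.callbacks.CSVLogger('training_log.csv')  # one row per epoch

# A simple custom callback defined inline
print_val_loss = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: print(
        f"Epoch {epoch}: val_loss = {logs['val_loss']:.4f}")
)

model.fit(X_train, y_train, epochs=10,
          validation_data=(X_test, y_test),
          callbacks=[csv_logger, print_val_loss])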
So that's what Keras callbacks are all about. Make sure to use them the next time you train a neural network.
If you liked this article, check my feed for my other articles; they might be helpful.
Happy coding…
Published via Towards AI