
Keras Callbacks - Extending Their Scope And Usage

As a transitioning piece between our two groups of blogs, this blog is still a work in progress.

It concerns the various callbacks that we can register during training sessions.

Automating our sessions is an important objective of ours, specifically so that we can fine-tune our training hyper-parameters. We will therefore keep adding to this blog as the rest of the next group's blogs materialize.

Just as before, we need to prep our environment to run any meaningful code:

from datetime import datetime
import pathlib as pth
import tensorflow as tf
import dataset as qd
import custom as qc
ks = tf.keras
kl = ks.layers

Updates to our model

And now we are ready to define our model.

As this blog focuses on the actual training process, our model can be reused directly from a previous blog:

def model_for(ps):
    x = [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    y = qc.ToRagged()(x)
    y = qc.Frames(ps)(y)
    embed = qc.Embed(ps)
    ye = qc.Encode(ps)(embed(y[:2]))
    yd = qc.Decode(ps)(embed(y[2:]) + [ye[0]])
    y = qc.Debed(ps)(yd)
    m = ks.Model(name='callbacks', inputs=x, outputs=y)
    m.compile(optimizer=ps.optimizer, loss=ps.loss, metrics=[ps.metric])
    print(m.summary())
    return m

Using callbacks

Once the model is defined, we adjust our main calling function.

At this point we define the callbacks that Keras should keep "in the loop" during our training session.
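To make "in the loop" concrete, here is a minimal sketch that records which Callback hooks Keras invokes around a fit call. The HookTracer class, the random toy data and the one-layer Dense model are illustrative assumptions, not part of our setup; the batch-level hooks are left out to keep the trace short.

```python
import numpy as np
import tensorflow as tf

ks = tf.keras


class HookTracer(ks.callbacks.Callback):
    """Hypothetical tracer: records train- and epoch-level hook calls in order."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append('train_begin')

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f'epoch_begin {epoch}')

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f'epoch_end {epoch}')

    def on_train_end(self, logs=None):
        self.events.append('train_end')


# Toy regression data and model, purely for illustration
x = np.random.rand(32, 4).astype('float32')
y = np.random.rand(32, 1).astype('float32')
model = ks.Sequential([ks.Input(shape=(4,)), ks.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')

tracer = HookTracer()
model.fit(x, y, epochs=2, batch_size=8, callbacks=[tracer], verbose=0)
print(tracer.events)
```

Every callback in the list passed to fit is driven through the same sequence of hooks, which is what lets our checkpointing and TensorBoard logging run side by side.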

Initially, we still want to include the standard Keras TensorBoard callback.

Additionally, we want to roll our own checkpointing. We choose to use the newer tf.train.Checkpoint and tf.train.CheckpointManager classes (see our blog regarding this topic).
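As a reminder of how the two classes interact, here is a small sketch that checkpoints a single variable instead of our model; the variable and the temporary directory are assumptions for illustration. The manager deletes older files according to max_to_keep, and restoring latest_checkpoint into a fresh object follows the same pattern main_graph uses for the model.

```python
import tempfile

import tensorflow as tf

# Throwaway directory; the post itself uses /tmp/q/model/<model name>
ckpt_dir = tempfile.mkdtemp()

# Any trackable object works; one variable keeps the sketch small
v = tf.Variable(0.0)
ckpt = tf.train.Checkpoint(v=v)
mgr = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

# Save five times; the manager keeps only the three newest checkpoints
for step in range(5):
    v.assign(float(step))
    mgr.save()

print(mgr.checkpoints)

# Restore the newest checkpoint into a fresh variable
v2 = tf.Variable(-1.0)
tf.train.Checkpoint(v=v2).restore(mgr.latest_checkpoint).expect_partial()
print(float(v2.numpy()))  # 4.0, the value at the last save
```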

For this we define a custom Keras Callback subclass called CheckpointCB. As this callback is only used to save or update our current checkpoint file, it only needs to override the on_epoch_end method.

In the override it simply calls the manager’s save method.
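Overriding a single hook is enough because the base ks.callbacks.Callback implements every hook as a no-op. A minimal sketch on assumed toy data shows what on_epoch_end receives: the zero-based epoch index and a logs dict with that epoch's metrics. EpochEndCB is a hypothetical name used only here.

```python
import numpy as np
import tensorflow as tf

ks = tf.keras


class EpochEndCB(ks.callbacks.Callback):
    """Overrides only on_epoch_end; all other hooks keep their no-op defaults."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def on_epoch_end(self, epoch, logs=None):
        # `logs` holds the epoch's aggregated metrics, e.g. the training loss
        self.seen.append((epoch, dict(logs or {})))


x = np.random.rand(16, 3).astype('float32')
y = np.random.rand(16, 1).astype('float32')
model = ks.Sequential([ks.Input(shape=(3,)), ks.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')

cb = EpochEndCB()
model.fit(x, y, epochs=3, batch_size=4, callbacks=[cb], verbose=0)
print([e for e, _ in cb.seen])  # [0, 1, 2]
print(cb.seen[-1][1])
```

The logs dict is what a more selective checkpointing callback could inspect before deciding to call the manager's save method.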

To be expanded…

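def main_graph(ps, ds, m):
    # Base directory for logs and checkpoints
    b = pth.Path('/tmp/q')
    b.mkdir(parents=True, exist_ok=True)
    # Timestamped TensorBoard log directory, one per session
    lp = datetime.now().strftime('%Y%m%d-%H%M%S')
    lp = b / f'logs/{lp}'
    # Checkpoint the model; the manager keeps only the 3 newest checkpoints
    c = tf.train.Checkpoint(model=m)
    mp = b / 'model' / f'{m.name}'
    mgr = tf.train.CheckpointManager(c, str(mp), max_to_keep=3)
    # if mgr.latest_checkpoint:
    #     vs = tf.train.list_variables(mgr.latest_checkpoint)
    #     print(f'\n*** checkpoint vars: {vs}')
    # Resume from the latest checkpoint; on the first run latest_checkpoint
    # is None and nothing is restored
    c.restore(mgr.latest_checkpoint).expect_partial()

    # Custom callback: save (and rotate) a checkpoint at the end of every epoch
    class CheckpointCB(ks.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            mgr.save()

    cbs = [
        CheckpointCB(),
        # Standard TensorBoard logging, with weight histograms every epoch
        ks.callbacks.TensorBoard(
            log_dir=str(lp),
            histogram_freq=1,
        ),
    ]
    m.fit(ds, callbacks=cbs, epochs=ps.num_epochs)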
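CheckpointCB is a closure over mgr, which ties it to main_graph. A hypothetical reusable variant, sketched below on a toy model, takes the manager as a constructor argument and numbers checkpoints by epoch through CheckpointManager.save(checkpoint_number=...). ManagedCheckpointCB, the toy data and the temporary directory are assumptions for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

ks = tf.keras


class ManagedCheckpointCB(ks.callbacks.Callback):
    """Hypothetical reusable CheckpointCB: the manager is injected, not captured."""

    def __init__(self, manager):
        super().__init__()
        self.manager = manager

    def on_epoch_end(self, epoch, logs=None):
        # Number checkpoints by 1-based epoch instead of the save counter
        self.manager.save(checkpoint_number=epoch + 1)


x = np.random.rand(20, 4).astype('float32')
y = np.random.rand(20, 1).astype('float32')
model = ks.Sequential([ks.Input(shape=(4,)), ks.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')

ckpt_dir = tempfile.mkdtemp()
mgr = tf.train.CheckpointManager(
    tf.train.Checkpoint(model=model), ckpt_dir, max_to_keep=3)
model.fit(x, y, epochs=5, batch_size=5,
          callbacks=[ManagedCheckpointCB(mgr)], verbose=0)

# Only the three newest checkpoints survive
print([os.path.basename(p) for p in mgr.checkpoints])  # ['ckpt-3', 'ckpt-4', 'ckpt-5']
```

One trade-off of epoch-based numbering: on a resumed run the epoch index starts again at 0 unless fit is given initial_epoch, whereas the default numbering follows the checkpoint's save counter, which is itself stored in the checkpoint.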

We may also need to update our parameters as they relate to our “callback objectives”.

To be expanded…

params = dict(
    dim_batch=5,
    dim_dense=150,
    dim_hidden=6,
    dim_stacks=2,
    dim_vocab=len(qd.vocab),
    loss=ks.losses.SparseCategoricalCrossentropy(from_logits=True),
    metric=ks.metrics.SparseCategoricalCrossentropy(from_logits=True),
    num_epochs=5,
    num_shards=2,
    optimizer=ks.optimizers.Adam(),
    width_dec=15,
    width_enc=25,
)

Training session

And now we are ready to start our training session.

The summary confirms the model's layers and connections. Adjusting the parameters, for example num_epochs, easily tailors the length of the session to our objectives.

ps = qd.Params(**params)
main_graph(ps, qc.dset_for(ps), model_for(ps))
  Model: "callbacks"
  __________________________________________________________________________________________________
  Layer (type)                    Output Shape         Param #     Connected to
  ==================================================================================================
  input_1 (InputLayer)            [(None,)]            0
  __________________________________________________________________________________________________
  input_2 (InputLayer)            [(None,)]            0
  __________________________________________________________________________________________________
  input_3 (InputLayer)            [(None,)]            0
  __________________________________________________________________________________________________
  input_4 (InputLayer)            [(None,)]            0
  __________________________________________________________________________________________________
  input_5 (InputLayer)            [(None,)]            0
  __________________________________________________________________________________________________
  input_6 (InputLayer)            [(None,)]            0
  __________________________________________________________________________________________________
  to_ragged (ToRagged)            [(None, None), (None 0           input_1[0][0]
                                                                   input_2[0][0]
                                                                   input_3[0][0]
                                                                   input_4[0][0]
                                                                   input_5[0][0]
                                                                   input_6[0][0]
  __________________________________________________________________________________________________
  frames (Frames)                 [(5, 25), (None,), ( 125         to_ragged[0][0]
                                                                   to_ragged[0][1]
                                                                   to_ragged[0][2]
  __________________________________________________________________________________________________
  embed (Embed)                   multiple             120         frames[0][0]
                                                                   frames[0][1]
                                                                   frames[0][2]
                                                                   frames[0][3]
  __________________________________________________________________________________________________
  encode (Encode)                 [(None, 25, 6), (Non 90516       embed[0][0]
                                                                   embed[0][1]
  __________________________________________________________________________________________________
  decode (Decode)                 [(None, 15, 6), (Non 54732       embed[1][0]
                                                                   embed[1][1]
                                                                   encode[0][0]
  __________________________________________________________________________________________________
  debed (Debed)                   (None, None, None)   140         decode[0][0]
                                                                   decode[0][1]
  ==================================================================================================
  Total params: 145,633
  Trainable params: 145,508
  Non-trainable params: 125
  __________________________________________________________________________________________________
  None
  Epoch 1/5
  20/20 [==============================] - 15s 733ms/step - loss: 1.3225 - sparse_categorical_crossentropy: 1.3311
  Epoch 2/5
  20/20 [==============================] - 0s 6ms/step - loss: 1.2037 - sparse_categorical_crossentropy: 1.2163
  Epoch 3/5
  20/20 [==============================] - 0s 6ms/step - loss: 1.2067 - sparse_categorical_crossentropy: 1.2187
  Epoch 4/5
  20/20 [==============================] - 0s 6ms/step - loss: 1.1113 - sparse_categorical_crossentropy: 1.1200
  Epoch 5/5
  20/20 [==============================] - 0s 6ms/step - loss: 1.0234 - sparse_categorical_crossentropy: 1.0336

A quick ls of our /tmp/q/model/callbacks checkpoint directory shows that the manager is indeed updating the checkpoint files and keeping only the last three, just as we expect.

To be expanded…