
Custom Keras Layers Without The Drawbacks

In this blog post, we continue training our computer to “understand” elementary symbolic arithmetic.

We change our approach slightly, however. Instead of using a fixed input/context for our encoder/decoder stacks, we follow the idea of “sliding contexts” from this paper.

In addition, we continue to architect our model with a mixture of “validated” Keras layers and lightweight Modules containing bare TF ops.

Our ultimate objective is to arrive at a model represented by the graph and runnable example.

Just as before, we need to prepare our environment to run any meaningful code:

import tensorflow as tf
import dataset as qd
import layers as ql
ks = tf.keras
kl = ks.layers

Before we start, we need to enlarge our dataset slightly, as the training steps are becoming more meaningful.

Calling the dump_dset function with a parameters instance will update our stored sharded binary files:

def dump_dset(ps):
    ps.max_val = 10000
    ps.num_samples = 1000  # 100000
    ps.num_shards = 10
    # write the sharded sample files and collect their names
    fs = [f for f in qd.dump(ps)]
    ps.dim_batch = 100
    # iterate over the reloaded dataset to verify that it loads
    for i, _ in enumerate(qd.load(ps, fs).map(adapter)):
        pass
    print(f'dumped {i + 1} batches of {ps.dim_batch} samples each')
    return fs

For verification purposes, loading our already created metadata from the sources gives us:

print(qd.SPC, qd.SEP, qd.STP)
  (' ', ':', '|', 'x', 'y', '=', ',', '+', '-', '*', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
  0 1 2
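As a quick sanity check of the printed metadata, here is a minimal pure-Python sketch (not part of the dataset module) that treats a token's id as its index in the vocabulary tuple. The names VOCAB, TOKEN_IDS, encode and decode, as well as the layout of the sample string, are hypothetical:

```python
# Vocabulary as printed above; a token's id is its index in the tuple.
VOCAB = (' ', ':', '|', 'x', 'y', '=', ',', '+', '-', '*',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
SPC, SEP, STP = VOCAB.index(' '), VOCAB.index(':'), VOCAB.index('|')
TOKEN_IDS = {c: i for i, c in enumerate(VOCAB)}


def encode(text):
    """Map each character of a sample string to its token id."""
    return [TOKEN_IDS[c] for c in text]


def decode(ids):
    """Map token ids back to characters."""
    return ''.join(VOCAB[i] for i in ids)


# Hypothetical sample layout: defs, SEP, op, SEP, result, STP
print(encode('x=12,y=3:x+y:15|'))
# [3, 5, 11, 12, 6, 4, 5, 13, 1, 3, 7, 4, 1, 11, 15, 2]
```

With this mapping, SPC, SEP and STP come out as 0, 1 and 2, matching the indices printed above.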

“Formatter” expanded

We also need to expand our previously used formatter.

As we intend to concatenate consecutive inputs in our “sliding context”, we need to terminate the result feature, res, of each sample with our STP = "|" token.

We have started to use the tf.debugging asserts (e.g. tf.debugging.assert_equal) to increase our confidence in the correctness of our data. Later we will be able to replace these with the familiar Python asserts.

Our formatter comes with another significant adjustment.

We intend to feed both our encoder and our decoder with inputs. Namely, the encoder gets the concatenated defs and op features, while the decoder gets either a fully or a partially blanked res.

Our dataset will supply three tensors: enc, dec and tgt (the full, correct result of the math expression in the sample). The rand_blank function performs a quick (inline) random blanking, or masking, of the arithmetic result, which is then fed into our decoder as the dec tensor.
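To make the layout of the three tensors concrete, here is a per-sample sketch on plain lists of token ids. The function name format_sample and the sample values are hypothetical; the real formatter, which follows, works on whole batches of ragged tensors:

```python
SEP, STP = 1, 2  # token ids of ':' and '|', as printed above


def format_sample(defs, op, res):
    """Per-sample analogue of the formatter, on plain lists of token ids."""
    enc = defs + [SEP] + op + [SEP]  # encoder input: defs : op :
    tgt = res + [STP]                # full result, terminated by the stop token
    return enc, tgt


# Token ids for defs 'x=1', op 'x+x' and result '2' (hypothetical sample)
enc, tgt = format_sample([3, 5, 11], [3, 7, 3], [12])
print(enc, tgt)  # [3, 5, 11, 1, 3, 7, 3, 1] [12, 2]
```

The dec tensor is then simply a randomly blanked copy of tgt.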

def formatter(d):
    ds = tf.RaggedTensor.from_sparse(d['defs'])
    n = ds.nrows()
    os = tf.RaggedTensor.from_sparse(d['op'])
    tf.debugging.assert_equal(n, os.nrows())
    ss = tf.fill([n, 1], qd.SEP)
    # encoder input: defs, SEP, op, SEP
    enc = tf.concat([ds, ss, os, ss], axis=1)
    rs = tf.RaggedTensor.from_sparse(d['res'])
    tf.debugging.assert_equal(n, rs.nrows())
    # target: the full result, terminated by the STP token
    tgt = tf.concat([rs, tf.fill([n, 1], qd.STP)], axis=1)

    def rand_blank(x):
        # blank (set to 0, i.e. SPC) up to half of all tokens in the batch;
        # positions are drawn with replacement, so some may repeat
        y = x.flat_values
        mv = tf.shape(y)[0]
        s = mv // 2
        i = tf.random.uniform([s], maxval=mv, dtype=tf.int32)[:, None]
        y = tf.tensor_scatter_nd_update(y, i, tf.zeros([s], dtype=tf.int32))
        return x.with_flat_values(y)

    return {'enc': enc, 'dec': rand_blank(tgt), 'tgt': tgt}
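The blanking step can be mimicked in plain Python to see exactly what it does. This sketch assumes, as the TF code does, that token id 0 (SPC) is the blank, and it draws positions with replacement, so at most half of the tokens get blanked and the STP terminator itself may be blanked too:

```python
import random

SPC = 0  # the space token doubles as the blank


def rand_blank(rows, rng=random):
    """Blank up to half of all tokens across a batch of ragged rows.

    Mirrors the TF version: flatten the rows, draw len // 2 positions
    with replacement, overwrite them with SPC, then restore the rows.
    """
    flat = [t for row in rows for t in row]
    n = len(flat)
    for _ in range(n // 2):
        flat[rng.randrange(n)] = SPC
    out, start = [], 0
    for row in rows:
        out.append(flat[start:start + len(row)])
        start += len(row)
    return out


tgt = [[11, 12, 2], [13, 2]]
dec = rand_blank(tgt, random.Random(0))
```

Because only the flat values change, every row of dec keeps the length of the corresponding tgt row, which is what with_flat_values guarantees in the TF version.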

In order for our dataset to be usable, we also need to update our adapter.

We continue to split our input ragged tensors into their components (flat values and row splits). As we now have three ragged inputs, enc, dec and tgt, our model will receive six dense input tensors in total.

The adapter also needs to supply our dense tgt tensor to the canned loss and metrics components that drive the gradient calculations.

In addition, we chose to add tgt, or rather its two components, to our inputs as well. This duplication lets us feed the correct arithmetic results into our “sliding context”.

def adapter(d):
    enc, dec, tgt = d['enc'], d['dec'], d['tgt']
    return (

Dataset adapter

Our new dataset creator function, dset_for, is as follows.

We have added an optional, overridable adapter argument for later use.

def dset_for(ps, adapter=adapter):
    ds =
    ds = ds.take(100).batch(ps.dim_batch)
    fs = {
    ds = x:, fs)).map(qd.caster)

As we now have three pairs of input tensors that we need to convert back into RaggedTensors, we quickly add a ToRagged convenience layer. It can be seamlessly eliminated once Keras Inputs properly support the ragged=True keyword argument:

class ToRagged(kl.Layer):
    def call(self, x):
        # x holds three (flat_values, row_splits) pairs: enc, dec and tgt
        ys = []
        for i in range(0, 6, 2):
            fv, rs = x[i:i + 2]
            ys.append(tf.RaggedTensor.from_row_splits(fv, rs))
        return ys
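The adapter/ToRagged round trip relies only on the (flat_values, row_splits) encoding of a ragged batch. A plain-Python sketch of that encoding, with hypothetical helper names, shows why six dense tensors are enough to rebuild the three ragged ones:

```python
from itertools import accumulate


def to_components(rows):
    """Split ragged rows into (flat_values, row_splits)."""
    flat_values = [t for row in rows for t in row]
    row_splits = [0] + list(accumulate(len(row) for row in rows))
    return flat_values, row_splits


def from_row_splits(flat_values, row_splits):
    """Rebuild the ragged rows: row i is flat_values[splits[i]:splits[i + 1]]."""
    return [flat_values[a:b] for a, b in zip(row_splits[:-1], row_splits[1:])]


def flatten_all(*raggeds):
    """Hypothetical adapter analogue: each ragged batch becomes two dense lists."""
    xs = []
    for rows in raggeds:
        xs.extend(to_components(rows))
    return xs


def rebuild_all(xs):
    """ToRagged analogue: consume the dense inputs two at a time."""
    return [from_row_splits(xs[i], xs[i + 1]) for i in range(0, len(xs), 2)]


enc, dec, tgt = [[3, 5], [4]], [[0, 2], [0]], [[11, 2], [2]]
xs = flatten_all(enc, dec, tgt)
print(len(xs), xs[:2])  # 6 [[3, 5, 4], [0, 2, 3]]
```

The row splits are one element longer than the number of rows, which is why the flat values and the splits travel as separate model inputs with their own dtypes.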

Frames layer

The Frames layer is the new significant addition to our code.

For every new encoder input sequence of tokens, xe, it first concatenates the stored prev context with xe and keeps the last width_enc tokens of each row of the result in ye.

Then it updates the prev variable with the (similarly truncated) concatenation of ye and the passed-in correct arithmetic result, xt. The resulting prev is used in the next cycle.

The computations are slightly more involved because we use the raggedness of the inputs to satisfy the requirement of a continuous, seamlessly “sliding context”.

The layer also returns the “row_lengths” tensors for both enc and dec inputs. They will be used later for propagating the input token sequences’ “raggedness”.

The entire Frames layer works exclusively with tokens, as we don’t want to keep stale embedding calculations around in our “sliding context”.

class Frames(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps, dtype=tf.int32)  # , dynamic=True)
        s = (ps.dim_batch, ps.width_enc)
        kw = dict(initializer='zeros', trainable=False, use_resource=True)
        self.prev = self.add_weight('prev', shape=s, **kw)

    def call(self, x):
        xe, xd, xt = x
        ye = tf.concat([self.prev, xe], axis=1)
        el = tf.cast(xe.row_lengths(), dtype=tf.int32)
        ye = tf.gather_nd(ye, self.calc_idxs(el))
        c = - xd.bounding_shape(axis=1, out_type=tf.int32)
        yd = tf.pad(xd.to_tensor(), [[0, 0], [0, c]])
        dl = tf.cast(xd.row_lengths(), dtype=tf.int32)
        p = tf.concat([ye, xt], axis=1)
        tl = tf.cast(xt.row_lengths(), dtype=tf.int32)
        p = tf.gather_nd(p, self.calc_idxs(tl))
        # store the updated context for the next cycle
        self.prev.assign(p)
        return [ye, el, yd, dl]

    def calc_idxs(self, lens):
        b, w =,
        y = tf.broadcast_to(tf.range(b)[:, None], [b, w])
        i = tf.range(w)[None, ] + lens[:, None]
        y = tf.stack([y, i], axis=2)
        return y
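The index arithmetic in calc_idxs amounts to “append, then keep the last width tokens of each row”. Here is a plain-Python sketch of one Frames cycle on token ids, assuming every prev row is exactly width tokens long (slide and frames_step are hypothetical names):

```python
def slide(prev, new, width):
    """Append each row of `new` to its row of `prev`; keep the last `width` tokens.

    Matches the calc_idxs gather: a row of length width + n keeps
    positions n .. n + width - 1.
    """
    out = []
    for p, x in zip(prev, new):
        row = p + x
        n = len(x)
        out.append([row[n + j] for j in range(width)])
    return out


def frames_step(prev, xe, xt, width):
    """One Frames cycle on token ids; returns (ye, next_prev)."""
    ye = slide(prev, xe, width)         # encoder window: context + new input
    next_prev = slide(ye, xt, width)    # next context: window + correct result
    return ye, next_prev


prev = [[0, 0, 0, 0], [0, 0, 0, 0]]
ye, prev = frames_step(prev, [[3, 5], [4]], [[11, 2], [2]], width=4)
print(ye)    # [[0, 0, 3, 5], [0, 0, 0, 4]]
print(prev)  # [[3, 5, 11, 2], [0, 0, 4, 2]]
```

Since only token ids are kept, the context never holds stale embeddings, and the output width stays fixed no matter how ragged the inputs are.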

As our Frames layer returns fixed-width dense tensors once again, we can re-adjust our carried-over Embed layer to use the straight embedding_lookup instead.

class Embed(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        s = (ps.dim_vocab, ps.dim_hidden)
        self.emb = self.add_weight('emb', shape=s)

    def call(self, x):
        y, lens = x
        y = tf.nn.embedding_lookup(self.emb, y)
        y *= y.shape[-1]**0.5
        return [y, lens]
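In plain-Python terms, the lookup-and-scale step of Embed looks like the following sketch (embed and the toy table are hypothetical):

```python
import math


def embed(ids, table):
    """Look up one row of `table` per token id and scale it by sqrt(dim_hidden)."""
    scale = math.sqrt(len(table[0]))
    return [[v * scale for v in table[i]] for i in ids]


table = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]  # toy 2-token, 4-dim table
print(embed([1, 0], table))  # [[0.0, 2.0, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0]]
```

Scaling embeddings by the square root of the hidden size follows the convention of the original Transformer paper.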

Encode, decode and debed layers

We update the Encode and Decode layers with the addition of the tf.function decorators for the call methods.

class Encode(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        self.width = ps.width_enc
        self.encs = [Encoder(self, f'enc_{i}') for i in range(ps.dim_stacks)]

    @tf.function
    def call(self, x):
        y = x
        for e in self.encs:
            y = e(y)
        return y

class Decode(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        self.width = ps.width_dec
        self.decs = [Decoder(self, f'dec_{i}') for i in range(ps.dim_stacks)]

    @tf.function
    def call(self, x):
        y, ye = x[:-1], x[-1]
        for d in self.decs:
            y = d(y + [ye])
        return y

Our Debed layer is also largely a carry-over, with the adjustment for the now fixed-width tensors.

class Debed(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        self.dbd = Dense(self, 'dbd', [ps.dim_hidden, ps.dim_vocab])

    def call(self, x):
        y, lens = x
        s = tf.shape(y)
        y = tf.reshape(y, [s[0] * s[1], -1])
        y = self.dbd(y)
        y = tf.reshape(y, [s[0], s[1], -1])
        y = y[:, :tf.math.reduce_max(lens), :]
        return y
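The decoder side uses a pad-then-trim pattern: Frames right-pads the ragged dec rows to a fixed width, and Debed slices its output back to the longest real row. Here is a plain-Python sketch of the two steps on token rows (the TF version slices logits rather than tokens; the helper names are hypothetical):

```python
def pad_rows(rows, width, fill=0):
    """Right-pad each ragged row with `fill` up to a fixed width."""
    return [row + [fill] * (width - len(row)) for row in rows]


def trim_to_longest(batch, lens):
    """Keep only the positions up to the longest real row."""
    m = max(lens)
    return [row[:m] for row in batch]


dec = [[11, 2], [13, 14, 2]]
padded = pad_rows(dec, 5)
print(padded)                           # [[11, 2, 0, 0, 0], [13, 14, 2, 0, 0]]
print(trim_to_longest(padded, [2, 3]))  # [[11, 2, 0], [13, 14, 2]]
```

Trimming to the longest row keeps the output aligned with the dense tgt tensor the loss compares against, whose width is also set by its longest row.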

Updated lightweight modules

We update the Encoder and Decoder modules with the addition of the tf.function decorators for the __call__ methods:

class Encoder(tf.Module):
    def __init__(self, layer, name):
        super().__init__(name=name)
        with self.name_scope:
            self.reflect = Attention(layer, 'refl')
            self.conclude = Conclusion(layer, 'conc')

    @tf.function
    def __call__(self, x):
        y = x
        y = self.reflect(y + [y[0]])
        y = self.conclude(y)
        return y

class Decoder(tf.Module):
    def __init__(self, layer, name):
        super().__init__(name=name)
        with self.name_scope:
            self.reflect = Attention(layer, 'refl')
            self.consider = Attention(layer, 'cnsd')
            self.conclude = Conclusion(layer, 'conc')

    @tf.function
    def __call__(self, x):
        y, ye = x[:-1], x[-1]
        y = self.reflect(y + [y[0]])
        y = self.consider(y + [ye])
        y = self.conclude(y)
        return y

The same applies to our new Attention module:

class Attention(tf.Module):
    def __init__(self, layer, name):
        super().__init__(name=name)
        h =
        self.scale = 1 / (h**0.5)
        with self.name_scope:
            self.q = layer.add_weight('q', shape=(h, h))
            self.k = layer.add_weight('k', shape=(h, h))
            self.v = layer.add_weight('v', shape=(h, h))

    @tf.function
    def __call__(self, x):
        x, lens, ctx = x
        off = tf.math.reduce_max(lens)
        q = tf.einsum('bni,ij->bnj', x[:, -off:, :], self.q)
        k = tf.einsum('bni,ij->bnj', ctx, self.k)
        y = tf.einsum('bni,bmi->bnm', q, k)
        # use lens
        y = tf.nn.softmax(y * self.scale)
        v = tf.einsum('bni,ij->bnj', ctx, self.v)
        y = tf.einsum('bnm,bmi->bni', y, v)
        y = tf.concat([x[:, :-off, :], y], axis=1)
        return [y, lens]
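A plain-Python sketch of the same computation for a single sequence may help to follow the einsum indices. It assumes a single head and, like the TF code (note the # use lens placeholder), it does not mask padded positions; the helper names are hypothetical:

```python
import math


def matmul(a, b):
    """Plain-list matrix product: [n, k] x [k, m] -> [n, m]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]


def attend(x, ctx, wq, wk, wv, off):
    """Single-sequence analogue of Attention.__call__ (no masking of padding).

    Only the last `off` positions of x issue queries; the earlier positions
    are passed through unchanged and re-concatenated in front.
    """
    scale = 1 / math.sqrt(len(wq))
    q = matmul(x[-off:], wq)                    # [off, h]
    k = matmul(ctx, wk)                         # [m, h]
    v = matmul(ctx, wv)                         # [m, h]
    kt = [list(col) for col in zip(*k)]         # [h, m]
    weights = [softmax([scale * s for s in row]) for row in matmul(q, kt)]
    return x[:-off] + matmul(weights, v)        # [n, h]


eye = [[1.0, 0.0], [0.0, 1.0]]
x = [[0.0, 0.0], [5.0, 5.0]]
ctx = [[1.0, 2.0], [1.0, 2.0]]
y = attend(x, ctx, eye, eye, eye, off=1)
```

Because both context rows are identical here, the attention weights are equal and the last position becomes exactly that row. In the encoder the context is the input itself (self-attention), while the decoder's consider step uses the encoder output as its context.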

The same applies to our new Conclusion module as well:

class Conclusion(tf.Module):
    def __init__(self, layer, name):
        super().__init__(name=name)
        self.layer = layer
        ps =
        w = layer.width * ps.dim_hidden
        with self.name_scope:
            s = [w, ps.dim_dense]
            self.inflate = Dense(layer, 'infl', s, activation='relu')
            s = [ps.dim_dense, w]
            self.deflate = Dense(layer, 'defl', s, bias=False)

    @tf.function
    def __call__(self, x):
        y, lens = x
        w = self.layer.width
        d =
        y = tf.reshape(y, [-1, w * d])
        y = self.inflate(y)
        y = self.deflate(y)
        y = tf.reshape(y, [-1, w, d])
        return [y, lens]
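Note that Conclusion is not a position-wise feed-forward block: it flattens the whole fixed-width window, so every output position can depend on every input position. Here is a plain-Python sketch with a hypothetical dense helper and toy identity weights:

```python
def dense(vec, w, b=None, relu=False):
    """vec [n] x w [n, m] (+ b [m]), optionally followed by ReLU."""
    out = [sum(v * w[i][j] for i, v in enumerate(vec)) for j in range(len(w[0]))]
    if b is not None:
        out = [o + bj for o, bj in zip(out, b)]
    if relu:
        out = [max(0.0, o) for o in out]
    return out


def conclude(window, w_in, b_in, w_out):
    """Flatten a [width, hidden] window, inflate with ReLU, deflate, reshape back."""
    width, hidden = len(window), len(window[0])
    flat = [v for row in window for v in row]    # [width * hidden]
    h = dense(flat, w_in, b_in, relu=True)       # [dim_dense]
    y = dense(h, w_out)                          # [width * hidden], no bias
    return [y[i * hidden:(i + 1) * hidden] for i in range(width)]


eye4 = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
print(conclude([[1.0, -1.0], [2.0, 3.0]], eye4, [0.0] * 4, eye4))
# [[1.0, 0.0], [2.0, 3.0]]
```

The price of mixing across the whole window is that both weight matrices grow linearly with the window width, which is why the encoder stack, with its wider window, holds most of the model's parameters.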

To add the tf.function decorator to our Dense module, we simply inherit from the previous version:

class Dense(ql.Dense):
    @tf.function
    def __call__(self, x):
        return super().__call__(x)

Updates for our model

Our model also needs to be updated to use the newly defined components.

Another significant change is the addition of the “row_lengths” tensor (received directly from the ragged tensors) to all the now fixed-width input and output dense tensors.

Once again, we are able to return to dense tensors for our inputs, despite the “raggedness” of our samples, because the “sliding context” strategy smoothly concatenates an entire “history” of inputs and correct arithmetic results into our fixed-width “working set”:

def model_for(ps):
    # three (flat_values, row_splits) input pairs: enc, dec and tgt
    x = [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    y = ToRagged()(x)
    y = Frames(ps)(y)
    embed = Embed(ps)
    ye = Encode(ps)(embed(y[:2]))
    yd = Decode(ps)(embed(y[2:]) + [ye[0]])
    y = Debed(ps)(yd)
    m = ks.Model(inputs=x, outputs=y)
    m.compile(optimizer=ps.optimizer, loss=ps.loss, metrics=[ps.metric])
    return m

Our parameters need to be extended with values for the now fixed widths of both our encoder and decoder stacks.

params = dict(

Training session

By firing up our training session, we can confirm the model’s layers and connections. The listing of a short session follows.

We can easily adjust the parameters to tailor the length of the sessions to our objectives.

ps = qd.Params(**params)
import masking as qm
qm.main_graph(ps, dset_for(ps), model_for(ps))
  Model: "model_1"
  Layer (type)                    Output Shape         Param #     Connected to
  input_7 (InputLayer)            [(None,)]            0
  input_8 (InputLayer)            [(None,)]            0
  input_9 (InputLayer)            [(None,)]            0
  input_10 (InputLayer)           [(None,)]            0
  input_11 (InputLayer)           [(None,)]            0
  input_12 (InputLayer)           [(None,)]            0
  to_ragged_1 (ToRagged)          [(None, None), (None 0           input_7[0][0]
  frames_1 (Frames)               [(5, 25), (None,), ( 125         to_ragged_1[0][0]
  embed_1 (Embed)                 multiple             120         frames_1[0][0]
  encode_1 (Encode)               [(None, 25, 6), (Non 90516       embed_1[0][0]
  decode_1 (Decode)               [(None, 15, 6), (Non 54732       embed_1[1][0]
  debed_1 (Debed)                 (None, None, None)   140         decode_1[0][0]
  Total params: 145,633
  Trainable params: 145,508
  Non-trainable params: 125
  Epoch 1/5
  20/20 [==============================] - 20s 1s/step - loss: 2.6308 - sparse_categorical_crossentropy: 2.6164
  Epoch 2/5
  20/20 [==============================] - 0s 11ms/step - loss: 2.1488 - sparse_categorical_crossentropy: 2.1325
  Epoch 3/5
  20/20 [==============================] - 0s 11ms/step - loss: 1.8967 - sparse_categorical_crossentropy: 1.8844
  Epoch 4/5
  20/20 [==============================] - 0s 13ms/step - loss: 1.7398 - sparse_categorical_crossentropy: 1.7248
  Epoch 5/5
  20/20 [==============================] - 0s 13ms/step - loss: 1.5818 - sparse_categorical_crossentropy: 1.5736
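The parameter counts in the summary can be reproduced by hand. The sizes below are inferred from the summary itself (they are assumptions, since the params dict is not shown), and the biases on inflate and Debed follow from ql.Dense using a bias by default, which the Debed count of 140 suggests:

```python
# Sizes inferred from the summary above; assumptions, not read from params.
DIM_BATCH, WIDTH_ENC, WIDTH_DEC = 5, 25, 15   # Frames (5, 25), Decode (None, 15, 6)
DIM_HIDDEN, DIM_VOCAB = 6, 20                 # 20 tokens in the printed vocabulary
DIM_STACKS, DIM_DENSE = 2, 150                # consistent with both Encode and Decode


def attention_params(h):
    """q, k and v weight matrices, no biases."""
    return 3 * h * h


def conclusion_params(width, h, dense):
    """Inflate (with bias) plus deflate (bias=False) over a flattened window."""
    w = width * h
    return (w * dense + dense) + dense * w


frames = DIM_BATCH * WIDTH_ENC               # non-trainable 'prev' context
embed = DIM_VOCAB * DIM_HIDDEN
encode = DIM_STACKS * (attention_params(DIM_HIDDEN)
                       + conclusion_params(WIDTH_ENC, DIM_HIDDEN, DIM_DENSE))
decode = DIM_STACKS * (2 * attention_params(DIM_HIDDEN)
                       + conclusion_params(WIDTH_DEC, DIM_HIDDEN, DIM_DENSE))
debed = DIM_HIDDEN * DIM_VOCAB + DIM_VOCAB   # projection with bias
total = frames + embed + encode + decode + debed
print(frames, embed, encode, decode, debed, total)
# 125 120 90516 54732 140 145633
```

The 125 non-trainable parameters are exactly the dim_batch × width_enc token buffer of the Frames layer, so the trainable count is 145,633 − 125 = 145,508.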

With our TensorBoard callback in place, the model’s fit method will generate the standard summaries that TensorBoard can conveniently visualize.

If you haven’t run the code below, an already generated graph is here.

#%load_ext tensorboard
#%tensorboard --logdir /tmp/q/logs

We can also switch over to the new eager execution mode.

Once again, this is particularly convenient for experimentation, as all ops are immediately executed. And here is a much shortened eager session.

# import ragged as qr
# qr.main_eager(ps, dset_for(ps), model_for(ps))

This concludes our blog post. To see how to use the autograph features with our model, please continue with the next post.