Autograph - Intuitive Data-Driven Control At Last

Continuing our series of blogs, we shift our focus to yet another new feature in TF: the autograph functionality.

Previously, as far as intuitive expression of code was concerned, “graph ops” efficiently solved complex calculations but failed at simple, sequential control flow.

By generating Python code on demand, autograph now transparently patches all the necessary graph ops together and packages the result into a “python op”.

While the newly generated ops are potentially faster than the code they replace, in this blog we are more interested in the new expressive power of the autograph package.

Specifically, we look at what becomes possible when decorating our functions with the new tf.function decorator, as doing so invokes the autograph functionality by default.

Our objective is to ultimately arrive at a model, as represented by the graph and the runnable example.

Just as before, we need to prep our environment to run any meaningful code:

import numpy as np
import tensorflow as tf
import dataset as qd
import custom as qc
import layers as ql
ks = tf.keras
kl = ks.layers
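
As a tiny, self-contained warm-up (our own sketch, not part of the model code), decorating a function with tf.function lets autograph rewrite ordinary Python control flow on tensor values into graph ops:

@tf.function
def sign_like(x):
    # this Python branch on a tensor value becomes graph-level control flow (tf.cond)
    if x > 0:
        y = tf.ones_like(x)
    else:
        y = -tf.ones_like(x)
    return y

sign_like(tf.constant(3.0))   # -> 1.0
sign_like(tf.constant(-2.0))  # -> -1.0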

Embed op

Next, we borrow the pos_timing function from our previous blogs and wrap it to return a constant “timing signal” tensor, depending on the width and depth arguments.

As our first task is to implement a “python branch” in our new Embed op, we will be using two different “timing” tensors, one for the encode input and the other for the decode input.

def pos_timing(width, depth):
    t = ql.pos_timing(width, depth)
    t = tf.constant(t, dtype=tf.float32)
    return t

The Embed layer will thus create the two constant tensors to be used in the subsequent calls of its call method.

Our model will call the shared Embed instance for both of our stacks. As we have decorated its call method with tf.function, we can use familiar and intuitive Python comparisons to branch on the value of tensors on-the-fly, during graph execution.

Clearly, our two stacks, while having the same depth, have different widths, and the constant “timing” tensors differ in width accordingly.

Yet we are still able to mix-and-match the otherwise incompatible tensors and successfully add them together, all depending on the actual width of our “current” input tensor:

class Embed(qc.Embed):
    def __init__(self, ps):
        super().__init__(ps)
        self.enc_p = pos_timing(ps.width_enc, ps.dim_hidden)
        self.dec_p = pos_timing(ps.width_dec, ps.dim_hidden)

    @tf.function(input_signature=[[
        tf.TensorSpec(shape=[None, None], dtype=tf.int32),
        tf.TensorSpec(shape=[None], dtype=tf.int32)
    ]])
    def call(self, x):
        y, lens = x
        y = tf.nn.embedding_lookup(self.emb, y)
        s = tf.shape(y)
        if s[-2] == self.ps.width_enc:
            y += tf.broadcast_to(self.enc_p, s)
        elif s[-2] == self.ps.width_dec:
            y += tf.broadcast_to(self.dec_p, s)
        else:
            pass
        y *= tf.cast(s[-1], tf.float32)**0.5
        return [y, lens]
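
To make the branching concrete, here is a hedged illustration of ours (the widths 25 and 15 and depth 6 are taken from the model summary below; the real code uses ps.width_enc, ps.width_dec and ps.dim_hidden): the width of the “current” batch alone decides which constant timing tensor gets broadcast and added.

enc_p = pos_timing(25, 6)            # encoder timing, width_enc
dec_p = pos_timing(15, 6)            # decoder timing, width_dec

@tf.function
def add_timing(y):
    s = tf.shape(y)
    if s[-2] == 25:                  # encoder-width input -> add enc_p
        y += tf.broadcast_to(enc_p, s)
    elif s[-2] == 15:                # decoder-width input -> add dec_p
        y += tf.broadcast_to(dec_p, s)
    else:
        pass
    return y

add_timing(tf.zeros([5, 25, 6]))     # takes the encoder branch
add_timing(tf.zeros([5, 15, 6]))     # takes the decoder branch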

Frames op

Next we demonstrate how on-the-fly “python ops” can also provide insights into inner processes and data flows.

We borrow our Frames layer from the previous blog and override its call method with a new tf.function-decorated version that, besides calling super().call(), also calls a new print_row Python function on every row in our batch.

Yes, we are calling a Python function and printing its results in a TF graph op while never leaving our intuitive and familiar Python environment! Isn’t that great?

The print_row function itself is simple: it iterates through the tokens of the “row”, looks up each one in our vocab “table” for the character it represents, and then “joins” all the characters and prints out the resulting string.

And, if we scroll down to the listing of our training session, we can actually see the “sliding context” of our samples as they fly by during our training.

Needless to say, the listing confirms that our Frames layer does a good job concatenating the varied-length sample inputs, the target results, and the necessary separators.

As a result, a simple Python function, usable during graph ops, provides us with invaluable insight deep into our inner processes and data flow.

class Frames(qc.Frames):
    @tf.function
    def call(self, x):
        y = super().call.python_function(x)
        tf.print()

        def print_row(r):
            tf.print(
                tf.numpy_function(
                    lambda ts: ''.join([qd.vocab[t] for t in ts]),
                    [r],
                    Tout=[tf.string],
                ))
            return r

        tf.map_fn(print_row, self.prev)
        return y
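
To see what the small lambda handed to tf.numpy_function is doing, here is a toy, purely illustrative version with a made-up vocabulary (the real qd.vocab comes from our dataset module):

toy_vocab = ['|', ':', 'x', 'y', '=', '5']    # made-up stand-in for qd.vocab
toks = np.array([2, 4, 5, 0])                 # token ids of one "row"
print(''.join(toy_vocab[t] for t in toks))    # prints: x=5|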

Deduce and Probe ops

Our next new layer is the partial Deduce layer, showcasing how data-driven control, from branching to searching, is intuitive at last.

This layer will be used in the next group of blogs as a replacement for our previous Debed layer. It contains a tensor-dependent for loop to iteratively replace our masked characters with “deduced” ones.

The future Probe layer, building on the Deduce scheme, implements an approximation of “Beam Search” (see the paper).

It effectively iterates through the hidden dimensions of the output and, based on parallel top-k searches comparing various choices for “debedding” the output, settles on an “optimal” debedding, and thus the final token output, for our decoder.

Without autograph, the data-driven looping/branching graph ops would have to be expressed in a much more convoluted manner:

"""
class Deduce(Layer):
    @tf.function
    def call(self, x):
        toks, *x = x
        if self.cfg.runtime.print_toks:
            qu.print_toks(toks, qd.vocab)
        y = self.deduce([toks] + x)
        n = tf.shape(y)[1]
        p = tf.shape(toks)[1] - n
        for i in tf.range(n):
            t = toks[:, :n]
            m = tf.equal(t, qd.MSK)
            if tf.equal(tf.reduce_any(m), True):
                t = self.update(t, m, y)
                if self.cfg.runtime.print_toks:
                    qu.print_toks(t, qd.vocab)
                toks = tf.pad(t, [[0, 0], [0, p]])
                y = self.deduce([toks] + x)
            else:
                e = tf.equal(t, qd.EOS)
                e = tf.math.count_nonzero(e, axis=1)
                if tf.equal(tf.reduce_any(tf.not_equal(e, 1)), False):
                    break
        return y
"""
class Probe(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        self.dbd = qc.Dense(self, 'dbd', [ps.dim_hidden, ps.dim_vocab])

    @tf.function
    def call(self, x):
        y, lens = x
        s = tf.shape(y)
        y = tf.reshape(y, [s[0] * s[1], -1])
        y = self.dbd(y)
        y = tf.reshape(y, [s[0], s[1], -1])
        y = y[:, :tf.math.reduce_max(lens), :]
        return y
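
Note how the two reshapes in Probe let the single dbd kernel of shape [dim_hidden, dim_vocab] be applied across the batch and time dimensions at once, while the final slice trims the padded time steps down to the longest actual sequence in the batch.

And to back up the claim about convoluted graph ops, here is a small sketch of ours (not from the original code) contrasting a data-dependent loop under autograph, in the spirit of the commented-out Deduce loop above, with the same loop spelled out against the explicit graph-op API:

# with autograph: a plain Python loop over a tensor range is rewritten for us
@tf.function
def sum_up(limit):
    total = tf.constant(0)
    for i in tf.range(limit):
        total += i
    return total

# without autograph: the same loop written manually with tf.while_loop
def sum_up_explicit(limit):
    _, total = tf.while_loop(
        cond=lambda i, total: i < limit,
        body=lambda i, total: (i + 1, total + i),
        loop_vars=(tf.constant(0), tf.constant(0)),
    )
    return total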

Our model needs to be updated as well to use the newly defined components.

Other than that, we are ready to start training:

def model_for(ps):
    x = [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    y = qc.ToRagged()(x)
    y = Frames(ps)(y)
    embed = Embed(ps)
    ye = qc.Encode(ps)(embed(y[:2]))
    yd = qc.Decode(ps)(embed(y[2:]) + [ye[0]])
    y = Probe(ps)(yd)
    m = ks.Model(inputs=x, outputs=y)
    m.compile(optimizer=ps.optimizer, loss=ps.loss, metrics=[ps.metric])
    print(m.summary())
    return m
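
Note that the single embed instance is indeed called twice above, once with the encoder frames and once with the decoder frames, so both branches of its tf.function-decorated call method get exercised within the same model.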

Training session

By firing up our training session, we can confirm the model’s layers and connections. The listing of a short session follows.

We can easily adjust the parameters to tailor the length of the sessions to our objectives.

ps = qd.Params(**qc.params)
ps.num_epochs = 1
import masking as qm
qm.main_graph(ps, qc.dset_for(ps).take(10), model_for(ps))
  Model: "model_3"
  __________________________________________________________________________________________________
  Layer (type)                    Output Shape         Param #     Connected to
  ==================================================================================================
  input_19 (InputLayer)           [(None,)]            0
  __________________________________________________________________________________________________
  input_20 (InputLayer)           [(None,)]            0
  __________________________________________________________________________________________________
  input_21 (InputLayer)           [(None,)]            0
  __________________________________________________________________________________________________
  input_22 (InputLayer)           [(None,)]            0
  __________________________________________________________________________________________________
  input_23 (InputLayer)           [(None,)]            0
  __________________________________________________________________________________________________
  input_24 (InputLayer)           [(None,)]            0
  __________________________________________________________________________________________________
  to_ragged_3 (ToRagged)          [(None, None), (None 0           input_19[0][0]
                                                                   input_20[0][0]
                                                                   input_21[0][0]
                                                                   input_22[0][0]
                                                                   input_23[0][0]
                                                                   input_24[0][0]
  __________________________________________________________________________________________________
  frames_3 (Frames)               [(5, 25), (None,), ( 125         to_ragged_3[0][0]
                                                                   to_ragged_3[0][1]
                                                                   to_ragged_3[0][2]
  __________________________________________________________________________________________________
  embed_3 (Embed)                 multiple             120         frames_3[0][0]
                                                                   frames_3[0][1]
                                                                   frames_3[0][2]
                                                                   frames_3[0][3]
  __________________________________________________________________________________________________
  encode_3 (Encode)               [(None, 25, 6), (Non 90516       embed_3[0][0]
                                                                   embed_3[0][1]
  __________________________________________________________________________________________________
  decode_3 (Decode)               [(None, 15, 6), (Non 54732       embed_3[1][0]
                                                                   embed_3[1][1]
                                                                   encode_3[0][0]
  __________________________________________________________________________________________________
  probe_3 (Probe)                 (None, None, None)   140         decode_3[0][0]
                                                                   decode_3[0][1]
  ==================================================================================================
  Total params: 145,633
  Trainable params: 145,508
  Non-trainable params: 125
  __________________________________________________________________________________________________
  None
  ["        y=76,x=34:y-x:42|"]
  ["       y=12,x=33:x*y:396|"]
  ["     y=-13,x=-80:x-y:-67|"]
  ["      x=51,y=-70:x-y:121|"]
  ["       y=-24,x=30:x-y:54|"]
        1/Unknown - 3s 3s/step - loss: 3.0014 - sparse_categorical_crossentropy: 3.0014
  ["x:42|x=-15,y=-38:y*x:570|"]
  ["x*y:396|x=36,y=72:y-x:36|"]
  ["y:-67|x=-93,y=55:y+x:-38|"]
  ["-y:121|x=2,y=-66:y-x:-68|"]
  ["x-y:54|x=-1,y=-59:x*y:59|"]
        2/Unknown - 3s 1s/step - loss: 2.9652 - sparse_categorical_crossentropy: 2.9652
  ["570|y=59,x=-78:x*y:-4602|"]
  [":36|y=-98,x=-78:y*x:7644|"]
  [":-38|x=-21,y=-36:y+x:-57|"]
  ["y-x:-68|x=21,y=40:y+x:61|"]
  ["x*y:59|y=31,x=-12:x+y:19|"]
        3/Unknown - 4s 1s/step - loss: 2.9888 - sparse_categorical_crossentropy: 2.9956
  ["y:-4602|y=59,x=66:y-x:-7|"]
  [":7644|y=21,x=67:x*y:1407|"]
  ["x:-57|x=-51,y=-69:x-y:18|"]
  [":61|y=49,x=-70:y*x:-3430|"]
  ["x+y:19|y=53,x=15:x*y:795|"]
        4/Unknown - 5s 1s/step - loss: 2.9879 - sparse_categorical_crossentropy: 2.9924
  [":y-x:-7|y=52,x=50:x-y:-2|"]
  ["1407|y=-86,x=40:y-x:-126|"]
  ["-y:18|x=48,y=-43:y-x:-91|"]
  [":-3430|x=99,y=-24:x+y:75|"]
  ["795|x=94,y=-79:x*y:-7426|"]
        5/Unknown - 6s 1s/step - loss: 2.9691 - sparse_categorical_crossentropy: 2.9697
  ["y:-2|x=17,y=-37:x*y:-629|"]
  ["126|x=99,y=-94:y*x:-9306|"]
  ["x:-91|y=-82,x=63:x+y:-19|"]
  [":75|x=-51,y=-79:x*y:4029|"]
  ["426|y=-67,x=-44:x*y:2948|"]
        6/Unknown - 7s 1s/step - loss: 2.9654 - sparse_categorical_crossentropy: 2.9654
  ["y:-629|y=72,x=28:y+x:100|"]
  ["306|y=93,x=-67:x*y:-6231|"]
  ["-19|y=83,x=-61:y*x:-5063|"]
  ["4029|x=-19,y=-63:x+y:-82|"]
  [":2948|y=-5,x=-31:y*x:155|"]
        7/Unknown - 8s 1s/step - loss: 2.9354 - sparse_categorical_crossentropy: 2.9323
  ["x:100|x=42,y=83:x*y:3486|"]
  [":-6231|x=-8,y=23:x-y:-31|"]
  ["*x:-5063|x=7,y=40:y+x:47|"]
  ["-82|y=-63,x=-35:y*x:2205|"]
  ["155|y=-68,x=-17:y*x:1156|"]
        8/Unknown - 9s 1s/step - loss: 2.9268 - sparse_categorical_crossentropy: 2.9247
  [":3486|x=97,y=30:y*x:2910|"]
  ["y:-31|y=-50,x=-71:y-x:21|"]
  ["y+x:47|x=44,y=59:x+y:103|"]
  ["*x:2205|x=23,y=66:y-x:43|"]
  ["1156|y=-90,x=76:y-x:-166|"]
        9/Unknown - 10s 1s/step - loss: 2.8991 - sparse_categorical_crossentropy: 2.8989
  [":2910|x=-20,y=72:x-y:-92|"]
  ["y-x:21|y=-1,x=91:y-x:-92|"]
  ["+y:103|x=-14,y=0:y+x:-14|"]
  [":43|x=-78,y=64:y*x:-4992|"]
  ["166|x=-81,y=16:y*x:-1296|"]
  10/10 [==============================] - 11s 1s/step - loss: 2.8581 - sparse_categorical_crossentropy: 2.8533

With our TensorBoard callback in place, the model’s fit method will generate the standard summaries that TB can conveniently visualize.

If you haven’t run the code below, an already generated graph is here.

#%load_ext tensorboard
#%tensorboard --logdir /tmp/q/logs
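
The callback wiring itself presumably happens inside the masking module's main_graph; as a hedged sketch, such a callback pointed at the log directory above could be created like this:

# illustrative only: a standard Keras TensorBoard callback writing to the
# directory the magics above point at, to be passed to fit(), e.g.
# m.fit(dset, callbacks=[tb_callback])
tb_callback = ks.callbacks.TensorBoard(log_dir='/tmp/q/logs')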

This concludes our blog; please see how to customize the losses and metrics driving the training by clicking through to the next blog.