In this blog we continue to train our computer to “understand” elementary symbolic arithmetic.
We slightly change our approach, however. Instead of having a fixed input/context for our encoder/decoder stacks, we follow the idea of “sliding contexts” from this paper.
In addition, we continue to architect our model with a mixture of “validated” Keras layers as well as light-weight Modules containing bare TF ops.
Our objective is to ultimately arrive at a model representable by the graph and the runnable example.
Just as before, we need to prep our environment to run any meaningful code:
import tensorflow as tf
import dataset as qd
import layers as ql
ks = tf.keras
kl = ks.layers
Before we start, we need to increase our dataset slightly as the training steps are becoming more meaningful.
Calling the dump_dset function with a parameters instance will update our stored sharded binary files:
def dump_dset(ps):
    ps.max_val = 10000
    ps.num_samples = 1000  # 100000
    ps.num_shards = 10
    fs = [f for f in qd.dump(ps)]
    ps.dim_batch = 100
    for i, _ in enumerate(qd.load(ps, fs).map(adapter)):
        pass
    print(f'dumped {i + 1} batches of {ps.dim_batch} samples each')
    return fs
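A hypothetical one-off invocation, assuming the qd.Params class instantiated from the params dict used in the training session below:

fs = dump_dset(qd.Params(**params))
print(fs)  # paths of the freshly dumped shard files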
For verification purposes, loading our previously created metadata from the sources gives us:
print(qd.vocab)
print(qd.SPC, qd.SEP, qd.STP)
(' ', ':', '|', 'x', 'y', '=', ',', '+', '-', '*', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
0 1 2
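Token ids simply follow positions in qd.vocab. For illustration, a hypothetical defs feature such as "x=2,y=3" would tokenize to:

toks = [qd.vocab.index(c) for c in 'x=2,y=3']
print(toks)  # [3, 5, 12, 6, 4, 5, 13]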
“Formatter” expanded
We also need to expand our previously used formatter.
As we intend to concatenate subsequent inputs in our “sliding context”, we need to end the result feature, res, of our samples with our STP = "|" token.
We have started to use tf.debugging.asserts to increase our confidence in the correctness of our data. Later we will be able to switch these out for the familiar Python asserts.
Our formatter comes with another significant adjustment.
We intend to feed both our encoder and our decoder with inputs. Namely, the encoder gets the concatenated defs and op features, while the decoder gets either a fully or a partially blanked res.
Our dataset will supply enc, dec and tgt (the full correct result of the math expression in the sample) tensors. The rand_blank function does the quick (inline) random blanking, or masking, of the arithmetic result to be fed into our decoder as the dec tensor.
@tf.function
def formatter(d):
    ds = tf.RaggedTensor.from_sparse(d['defs'])
    n = ds.nrows()
    os = tf.RaggedTensor.from_sparse(d['op'])
    tf.debugging.assert_equal(n, os.nrows())
    ss = tf.fill([n, 1], qd.SEP)
    # encoder input: defs and op, joined by separator tokens
    enc = tf.concat([ds, ss, os, ss], axis=1)
    rs = tf.RaggedTensor.from_sparse(d['res'])
    tf.debugging.assert_equal(n, rs.nrows())
    # target: the result, terminated by our stop token
    tgt = tf.concat([rs, tf.fill([n, 1], qd.STP)], axis=1)

    def rand_blank(x):
        y = x.flat_values
        mv = tf.shape(y)[0]
        s = mv // 2
        # overwrite roughly half of the tokens, at random positions,
        # with 0 (our SPC token)
        i = tf.random.uniform([s], maxval=mv, dtype=tf.int32)[:, None]
        y = tf.tensor_scatter_nd_update(y, i, tf.zeros([s], dtype=tf.int32))
        return x.with_flat_values(y)

    return {'enc': enc, 'dec': rand_blank(tgt), 'tgt': tgt}
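As a quick sanity check, here is the blanking in isolation on a toy ragged tensor. The blanked positions are random, so the printed values are merely illustrative:

x = tf.ragged.constant([[12, 13], [14, 15, 16]], dtype=tf.int32)
y = x.flat_values
s = tf.shape(y)[0] // 2
i = tf.random.uniform([s], maxval=tf.shape(y)[0], dtype=tf.int32)[:, None]
y = tf.tensor_scatter_nd_update(y, i, tf.zeros([s], dtype=tf.int32))
print(x.with_flat_values(y))  # e.g. [[12, 0], [14, 15, 0]]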
In order for our dataset to be usable, we also need to update our adapter.
We continue to split our input ragged tensors into their components, and as we now have 3 ragged inputs: enc, dec and tgt, the total number of dense input tensors to our model will be 6.
The adapter needs to also supply our tgt dense tensor to the canned loss and metrics components that drive the gradient calculations.
In addition, we chose to add tgt, or rather its two components, to our inputs as well. This duplication gives us the chance to feed correct arithmetic results into our “sliding context”.
@tf.function
def adapter(d):
    enc, dec, tgt = d['enc'], d['dec'], d['tgt']
    return (
        (
            enc.flat_values,
            enc.row_splits,
            dec.flat_values,
            dec.row_splits,
            tgt.flat_values,
            tgt.row_splits,
        ),
        tgt.to_tensor(),
    )
Dataset adapter
Our new dataset creator function, dset_for, is as follows.
We have added an optionally overridable adapter argument to be used later.
def dset_for(ps, adapter=adapter):
    ds = tf.data.TFRecordDataset(list(qd.files(ps)))
    ds = ds.take(100).batch(ps.dim_batch)
    fs = {
        'defs': tf.io.VarLenFeature(tf.int64),
        'op': tf.io.VarLenFeature(tf.int64),
        'res': tf.io.VarLenFeature(tf.int64),
    }
    ds = ds.map(lambda x: tf.io.parse_example(x, fs)).map(qd.caster)
    return ds.map(formatter).map(adapter)
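For a quick sanity check of the resulting structure, we can pull a single batch, again assuming the params dict defined below:

ps = qd.Params(**params)
for x, y in dset_for(ps).take(1):
    print(len(x), y.shape)  # the 6 dense inputs and the dense tgt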
As we now have three pairs of input tensors that we need to convert back into RaggedTensors, we quickly add a ToRagged convenience layer. It can be seamlessly eliminated once the Keras Inputs start properly supporting the ragged=True keyword argument:
class ToRagged(kl.Layer):
    @tf.function
    def call(self, x):
        ys = []
        for i in range(3):
            i *= 2
            fv, rs = x[i:i + 2]
            ys.append(tf.RaggedTensor.from_row_splits(fv, rs))
        return ys
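A quick round-trip check of the layer on a toy ragged input:

x = tf.ragged.constant([[1, 2], [3]], dtype=tf.int32)
fv, rs = x.flat_values, x.row_splits
print(ToRagged()([fv, rs, fv, rs, fv, rs])[0])  # [[1, 2], [3]]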
Frames layer
The Frames layer is the new significant addition to our code.
With every new encoder input sequence of tokens xe, it first concatenates the prev stored context with xe and then stores the result in ye.
Then it updates the prev variable with the concatenation of ye and the passed in correct arithmetic result xt. The resulting prev is to be used in the next cycle.
The computations are slightly more complex due to using the raggedness of the inputs to satisfy the continuous, seamless “sliding context” requirement.
The layer also returns the “row_lengths” tensors for both enc and dec inputs. They will be used later for propagating the input token sequences’ “raggedness”.
The entire Frames layer works exclusively with tokens, as we don’t want to keep stale embedding calculations around in our “sliding context”.
class Frames(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps, dtype=tf.int32)  # , dynamic=True)
        s = (ps.dim_batch, ps.width_enc)
        kw = dict(initializer='zeros', trainable=False, use_resource=True)
        self.prev = self.add_weight('prev', shape=s, **kw)

    @tf.function
    def call(self, x):
        xe, xd, xt = x
        # concat the stored context with the new input and keep the
        # last width_enc tokens of every row
        ye = tf.concat([self.prev, xe], axis=1)
        el = tf.cast(xe.row_lengths(), dtype=tf.int32)
        ye = tf.gather_nd(ye, self.calc_idxs(el))
        # pad the decoder input to the fixed decoder width
        c = self.ps.width_dec - xd.bounding_shape(axis=1, out_type=tf.int32)
        yd = tf.pad(xd.to_tensor(), [[0, 0], [0, c]])
        dl = tf.cast(xd.row_lengths(), dtype=tf.int32)
        # append the correct results and slide again for the next cycle
        p = tf.concat([ye, xt], axis=1)
        tl = tf.cast(xt.row_lengths(), dtype=tf.int32)
        p = tf.gather_nd(p, self.calc_idxs(tl))
        self.prev.assign(p)
        return [ye, el, yd, dl]

    def calc_idxs(self, lens):
        b, w = self.ps.dim_batch, self.ps.width_enc
        y = tf.broadcast_to(tf.range(b)[:, None], [b, w])
        i = tf.range(w)[None, :] + lens[:, None]
        y = tf.stack([y, i], axis=2)
        return y
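To make the index math concrete, here is a minimal standalone sketch, assuming a batch of 2 and a width of 4 (we call to_tensor to keep the sketch simple). The gather keeps the last width tokens of each concatenated row:

b, w = 2, 4
lens = tf.constant([1, 2], dtype=tf.int32)
rows = tf.broadcast_to(tf.range(b)[:, None], [b, w])
cols = tf.range(w)[None, :] + lens[:, None]
idxs = tf.stack([rows, cols], axis=2)
prev = tf.zeros([b, w], dtype=tf.int32)
xe = tf.ragged.constant([[7], [8, 9]], dtype=tf.int32)
ye = tf.concat([prev, xe], axis=1).to_tensor()
print(tf.gather_nd(ye, idxs))  # [[0, 0, 0, 7], [0, 0, 8, 9]]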
As our Frames layer returns fixed-width dense tensors once again, we can re-adjust our carried-over Embed layer to use the straight embedding_lookup instead.
class Embed(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        s = (ps.dim_vocab, ps.dim_hidden)
        self.emb = self.add_weight('emb', shape=s)

    @tf.function
    def call(self, x):
        y, lens = x
        y = tf.nn.embedding_lookup(self.emb, y)
        # scale the embeddings by sqrt(dim_hidden)
        y *= y.shape[-1]**0.5
        return [y, lens]
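A quick shape check of the lookup, assuming a dim_vocab of 20 and a dim_hidden of 6:

emb = tf.random.normal([20, 6])
y = tf.nn.embedding_lookup(emb, tf.constant([[3, 5, 12]]))
print(y.shape)  # (1, 3, 6): one 6-dim vector per token id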
Encode, decode and debed layers
We update the Encode and Decode layers by adding tf.function decorators to their call methods.
class Encode(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        self.width = ps.width_enc
        self.encs = [Encoder(self, f'enc_{i}') for i in range(ps.dim_stacks)]

    @tf.function
    def call(self, x):
        y = x
        for e in self.encs:
            y = e(y)
        return y

class Decode(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        self.width = ps.width_dec
        self.decs = [Decoder(self, f'dec_{i}') for i in range(ps.dim_stacks)]

    @tf.function
    def call(self, x):
        y, ye = x[:-1], x[-1]
        for d in self.decs:
            y = d(y + [ye])
        return y
Our Debed layer is also largely a carry-over, with the adjustment for the now fixed-width tensors.
class Debed(ql.Layer):
    def __init__(self, ps):
        super().__init__(ps)
        self.dbd = Dense(self, 'dbd', [ps.dim_hidden, ps.dim_vocab])

    @tf.function
    def call(self, x):
        y, lens = x
        s = tf.shape(y)
        y = tf.reshape(y, [s[0] * s[1], -1])
        y = self.dbd(y)
        y = tf.reshape(y, [s[0], s[1], -1])
        # trim the logits to the longest actual row in the batch
        y = y[:, :tf.math.reduce_max(lens), :]
        return y
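The trimming is what produces the fully dynamic (None, None, None) output shape in the model summary below. A standalone illustration:

lens = tf.constant([2, 3])
y = tf.random.normal([2, 15, 20])
print(y[:, :tf.math.reduce_max(lens), :].shape)  # (2, 3, 20)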
Updated lightweight modules
We update the Encoder and Decoder modules by adding tf.function decorators to their __call__ methods:
class Encoder(tf.Module):
    def __init__(self, layer, name):
        super().__init__(name=name)
        with self.name_scope:
            self.reflect = Attention(layer, 'refl')
            self.conclude = Conclusion(layer, 'conc')

    @tf.function
    def __call__(self, x):
        y = x
        y = self.reflect(y + [y[0]])
        y = self.conclude(y)
        return y

class Decoder(tf.Module):
    def __init__(self, layer, name):
        super().__init__(name=name)
        with self.name_scope:
            self.reflect = Attention(layer, 'refl')
            self.consider = Attention(layer, 'cnsd')
            self.conclude = Conclusion(layer, 'conc')

    @tf.function
    def __call__(self, x):
        y, ye = x[:-1], x[-1]
        y = self.reflect(y + [y[0]])
        y = self.consider(y + [ye])
        y = self.conclude(y)
        return y
The same applies to our new Attention module:
class Attention(tf.Module):
    def __init__(self, layer, name):
        super().__init__(name=name)
        h = layer.ps.dim_hidden
        self.scale = 1 / (h**0.5)
        with self.name_scope:
            self.q = layer.add_weight('q', shape=(h, h))
            self.k = layer.add_weight('k', shape=(h, h))
            self.v = layer.add_weight('v', shape=(h, h))

    @tf.function
    def __call__(self, x):
        x, lens, ctx = x
        # only the last `off` positions hold fresh tokens to query with
        off = tf.math.reduce_max(lens)
        q = tf.einsum('bni,ij->bnj', x[:, -off:, :], self.q)
        k = tf.einsum('bni,ij->bnj', ctx, self.k)
        y = tf.einsum('bni,bmi->bnm', q, k)
        # use lens
        y = tf.nn.softmax(y * self.scale)
        v = tf.einsum('bni,ij->bnj', ctx, self.v)
        y = tf.einsum('bnm,bmi->bni', y, v)
        y = tf.concat([x[:, :-off, :], y], axis=1)
        return [y, lens]
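To trace the shapes, here is a standalone sketch of the einsum math, with identity matrices standing in for the learned weights and assuming batch 1, width 4, dim_hidden 6 and lens = [2]:

h = 6
x = tf.random.normal([1, 4, h])    # current fixed-width states
ctx = tf.random.normal([1, 4, h])  # context to attend over
wq = wk = wv = tf.eye(h)           # stand-ins for the q, k, v weights
off = 2                            # pretend reduce_max(lens) == 2
q = tf.einsum('bni,ij->bnj', x[:, -off:, :], wq)  # (1, 2, 6)
k = tf.einsum('bni,ij->bnj', ctx, wk)             # (1, 4, 6)
a = tf.nn.softmax(tf.einsum('bni,bmi->bnm', q, k) / h**0.5)  # (1, 2, 4)
v = tf.einsum('bni,ij->bnj', ctx, wv)
y = tf.einsum('bnm,bmi->bni', a, v)               # (1, 2, 6)
print(tf.concat([x[:, :-off, :], y], axis=1).shape)  # (1, 4, 6)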
The same applies to our new Conclusion module as well:
class Conclusion(tf.Module):
    def __init__(self, layer, name):
        super().__init__(name=name)
        self.layer = layer
        ps = layer.ps
        w = layer.width * ps.dim_hidden
        with self.name_scope:
            s = [w, ps.dim_dense]
            self.inflate = Dense(layer, 'infl', s, activation='relu')
            s = [ps.dim_dense, w]
            self.deflate = Dense(layer, 'defl', s, bias=False)

    @tf.function
    def __call__(self, x):
        y, lens = x
        w = self.layer.width
        d = self.layer.ps.dim_hidden
        # flatten each sample, expand through dim_dense, then restore
        y = tf.reshape(y, [-1, w * d])
        y = self.inflate(y)
        y = self.deflate(y)
        y = tf.reshape(y, [-1, w, d])
        return [y, lens]
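With the parameters below, width_enc * dim_hidden happens to be 25 * 6 = 150, the same as dim_dense. A minimal shape trace of the encoder-side Conclusion:

y = tf.random.normal([5, 25, 6])  # a batch of encoder states
y = tf.reshape(y, [-1, 25 * 6])   # (5, 150): one flat row per sample
# inflate: (5, 150) -> (5, dim_dense); deflate: back to (5, 150)
y = tf.reshape(y, [-1, 25, 6])    # restored to (5, 25, 6)
print(y.shape)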
To add the tf.function decorator to our Dense module, we simply inherit from the previous version:
class Dense(ql.Dense):
    @tf.function
    def __call__(self, x):
        return super().__call__(x)
Updates for our model
Our model instance needs to be updated as well to use the newly defined components.
Another significant change is that the “row_lengths” tensors (received directly from the ragged tensors) now accompany all the fixed-width input and output dense tensors.
Once again, we were able to return to using dense tensors for our inputs, despite the “raggedness” of our samples, because we adopted the “sliding context” strategy, smoothly concatenating an entire “history” of inputs and correct arithmetic results into our “working set”:
def model_for(ps):
    x = [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    y = ToRagged()(x)
    y = Frames(ps)(y)
    embed = Embed(ps)
    ye = Encode(ps)(embed(y[:2]))
    yd = Decode(ps)(embed(y[2:]) + [ye[0]])
    y = Debed(ps)(yd)
    m = ks.Model(inputs=x, outputs=y)
    m.compile(optimizer=ps.optimizer, loss=ps.loss, metrics=[ps.metric])
    print(m.summary())
    return m
Our parameters need to be expanded with values for the now fixed widths of both our encoder and decoder stacks.
params = dict(
    dim_batch=5,
    dim_dense=150,
    dim_hidden=6,
    dim_stacks=2,
    dim_vocab=len(qd.vocab),
    loss=ks.losses.SparseCategoricalCrossentropy(from_logits=True),
    metric=ks.metrics.SparseCategoricalCrossentropy(from_logits=True),
    num_epochs=5,
    num_shards=2,
    optimizer=ks.optimizers.Adam(),
    width_dec=15,
    width_enc=25,
)
Training session
Firing up a training session lets us confirm the model's layers and connections. A listing of a short session follows.
We can easily adjust the parameters to tailor the length of the sessions to our objectives.
ps = qd.Params(**params)
import masking as qm
qm.main_graph(ps, dset_for(ps), model_for(ps))
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_7 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
input_8 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
input_9 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
input_10 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
input_11 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
input_12 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
to_ragged_1 (ToRagged) [(None, None), (None 0 input_7[0][0]
input_8[0][0]
input_9[0][0]
input_10[0][0]
input_11[0][0]
input_12[0][0]
__________________________________________________________________________________________________
frames_1 (Frames) [(5, 25), (None,), ( 125 to_ragged_1[0][0]
to_ragged_1[0][1]
to_ragged_1[0][2]
__________________________________________________________________________________________________
embed_1 (Embed) multiple 120 frames_1[0][0]
frames_1[0][1]
frames_1[0][2]
frames_1[0][3]
__________________________________________________________________________________________________
encode_1 (Encode) [(None, 25, 6), (Non 90516 embed_1[0][0]
embed_1[0][1]
__________________________________________________________________________________________________
decode_1 (Decode) [(None, 15, 6), (Non 54732 embed_1[1][0]
embed_1[1][1]
encode_1[0][0]
__________________________________________________________________________________________________
debed_1 (Debed) (None, None, None) 140 decode_1[0][0]
decode_1[0][1]
==================================================================================================
Total params: 145,633
Trainable params: 145,508
Non-trainable params: 125
__________________________________________________________________________________________________
None
Epoch 1/5
20/20 [==============================] - 20s 1s/step - loss: 2.6308 - sparse_categorical_crossentropy: 2.6164
Epoch 2/5
20/20 [==============================] - 0s 11ms/step - loss: 2.1488 - sparse_categorical_crossentropy: 2.1325
Epoch 3/5
20/20 [==============================] - 0s 11ms/step - loss: 1.8967 - sparse_categorical_crossentropy: 1.8844
Epoch 4/5
20/20 [==============================] - 0s 13ms/step - loss: 1.7398 - sparse_categorical_crossentropy: 1.7248
Epoch 5/5
20/20 [==============================] - 0s 13ms/step - loss: 1.5818 - sparse_categorical_crossentropy: 1.5736
With our TensorBoard callback in place, the model’s fit method will generate the standard summaries that TB can conveniently visualize.
If you haven’t run the code below, an already generated graph is here.
#%load_ext tensorboard
#%tensorboard --logdir /tmp/q/logs
We can also switch over to the new eager execution mode.
Once again, this is particularly convenient for experimentation, as all ops are executed immediately. Here is a much-shortened eager session:
# import ragged as qr
# qr.main_eager(ps, dset_for(ps), model_for(ps))
This concludes our blog. To see how to use the autograph features with our model, please click through to the next blog.