# Language Modeling From Scratch — Deep Dive Into Activations, Gradients and BatchNormalization (Part 3)

Last Updated on February 22, 2024 by Editorial Team

**Author(s): Abhishek Chaudhary**

Originally published on Towards AI.

In the previous article we implemented a Multi Layer Perceptron with 2 Hidden Layers. We observed an improvement over our original model, which used a probability distribution to generate random names. In this article we'll take a deeper look into activations, gradients and batch normalization. These topics get an article of their own because it's important to develop a solid understanding of them before we move on to more complex architectures like RNNs and Transformers.

*I highly recommend reading the previous articles of the series.*

- High Dimensional Language Modeling Using Neural Network (pub.towardsai.net)
- Language Modeling From Scratch (pub.towardsai.net)

The code for this article can be found in the following jupyter notebook

## Setup

The starting code for this article is the same as in the previous article, just a cleaned-up version of it. As usual, we'll start by importing all the required modules.

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
import numpy as np
%matplotlib inline
```

Next we’ll import our dataset of names and build vocabulary out of it

```python
# read in all the words
words = open('names.txt', 'r').read().splitlines()

# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
len(words), vocab_size, itos, stoi
```

```
(32033,
 27,
 {1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i',
  10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q',
  18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y',
  26: 'z', 0: '.'},
 {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9,
  'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17,
  'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25,
  'z': 26, '.': 0})
```

Adding a utility method to create train, test and validation dataset

```python
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # crop and append

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y
```
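To make the sliding context window concrete, here is a small self-contained sketch that traces the same crop-and-append logic for a single name. The word `emma` and the tiny `stoi_demo` / `itos_demo` mappings are illustrative assumptions and are not part of the dataset code above.

```python
# Illustrative sketch: trace the context window for one word.
block_size_demo = 3
chars_demo = sorted(set("emma"))                 # ['a', 'e', 'm']
stoi_demo = {s: i + 1 for i, s in enumerate(chars_demo)}
stoi_demo['.'] = 0
itos_demo = {i: s for s, i in stoi_demo.items()}

pairs = []
context = [0] * block_size_demo                  # start with '...'
for ch in "emma" + ".":
    ix = stoi_demo[ch]
    # record the current context (as text) and the character to predict
    pairs.append((''.join(itos_demo[i] for i in context), ch))
    context = context[1:] + [ix]                 # crop and append

for ctx, target in pairs:
    print(f"{ctx} -> {target}")
```

Each printed line is one training example: three context characters on the left and the character to predict on the right.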

Creating training, testing and validation dataset using the above method. We’ll use 80%, 10% and 10% respectively

```python
import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1]) # 80%
Xdev, Ydev = build_dataset(words[n1:n2]) # 10%
Xte, Yte = build_dataset(words[n2:]) # 10%
```

```
torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])
```

Similar to the neural network discussed in the previous article, we'll create two layers: a hidden layer (HL1, with weights W1) and an output layer (HL2, with weights W2 and bias b2)

```python
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g)
W2 = torch.randn((n_hidden, vocab_size), generator=g)
b2 = torch.randn(vocab_size, generator=g)

parameters = [C, W1, W2, b2]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True
```

`11697`

Next we’ll setup the training loop with forward pass, backward pass and parameter update using learning rate, same as last article. We’ll also keep track of loss per step to later plot the results.

```python
# same optimization as last time
max_steps = 200000
batch_size = 32
lossi = []

def run_training_loop(break_on_first=False):
    for i in range(max_steps):
        # minibatch construct
        ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
        Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

        # forward pass
        emb = C[Xb] # embed the characters into vectors
        embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
        # Linear layer
        hpreact = embcat @ W1 #+ b1 # hidden layer pre-activation
        # Non-linearity
        h = torch.tanh(hpreact) # hidden layer
        logits = h @ W2 + b2 # output layer
        loss = F.cross_entropy(logits, Yb) # loss function

        # backward pass
        for p in parameters:
            p.grad = None
        loss.backward()

        # update
        lr = 0.1 if i < 100000 else 0.01 # step learning rate decay
        for p in parameters:
            p.data += -lr * p.grad

        # track stats
        if i % 10000 == 0: # print every once in a while
            print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
        lossi.append(loss.log10().item())

        if break_on_first:
            # stop after a single step and return the intermediate tensors
            return logits, h, hpreact
```

```python
run_training_loop()
```

```
      0/ 200000: 29.8979
  10000/ 200000: 2.6601
  20000/ 200000: 2.7236
  30000/ 200000: 2.2986
  40000/ 200000: 2.5010
  50000/ 200000: 2.0478
  60000/ 200000: 2.4736
  70000/ 200000: 2.4584
  80000/ 200000: 2.4450
  90000/ 200000: 2.1317
 100000/ 200000: 2.3553
 110000/ 200000: 2.4058
 120000/ 200000: 1.6513
 130000/ 200000: 1.9658
 140000/ 200000: 2.1240
 150000/ 200000: 2.0506
 160000/ 200000: 2.0208
 170000/ 200000: 2.3377
 180000/ 200000: 2.1702
 190000/ 200000: 2.0824
```

Adding a method to calculate loss on train and validation split

```python
@torch.no_grad() # this decorator disables gradient tracking
def split_loss(split):
    x,y = {
        'train': (Xtr, Ytr),
        'val': (Xdev, Ydev),
        'test': (Xte, Yte),
    }[split]
    emb = C[x] # (N, block_size, n_embd)
    embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
    hpreact = embcat @ W1 # + b1
    h = torch.tanh(hpreact) # (N, n_hidden)
    logits = h @ W2 + b2 # (N, vocab_size)
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())
```

```python
split_loss('train')
split_loss('val')
plt.plot(lossi)
```

```
[<matplotlib.lines.Line2D at 0x124cf4650>]
```

Now that we have the initial setup done, let’s look into the optimizations that we can perform.

## Overconfident Softmax

The first problem with the above model is the "hockey stick" shape of the loss over time, as is evident from the above graph. During the first iteration we saw a loss of 29.8979, which is far too high.

## What should be the expected loss for the first iteration?

For the first iteration, we’d expect the model to have no prediction ability, so it should assign equal probability to each of the characters, which would be 1/27. Let’s look at what the loss corresponding to this value comes out to be.

`-torch.tensor(1/27.0).log()`

`tensor(3.2958)`

So the loss reported by the neural network on the first iteration should be around 3.29. Let's look at the logits produced by the output layer (HL2) on the first iteration

```python
# Reinitialize weights and then run the training loop again
logits, _, _ = run_training_loop(break_on_first=True)
logits[0]
```

```
      0/ 200000: 29.8979
tensor([ -7.0724,   4.4451,  10.1186,  13.6337,   8.1399,   2.1134,   3.3501,
         12.6121, -10.0490,  -6.7101,  -9.5156, -17.8216,   7.2949,  -1.5835,
         10.2453, -14.9260,  -4.5385, -16.9217,   5.0395, -11.4323,  17.1238,
         12.3386,  16.0655,   0.0758,   1.6998,  28.4674,   6.6073],
       grad_fn=<SelectBackward0>)
```

As we can see, the logits take on widely different values, which makes the softmax overconfidently predict a few characters where it should assign almost the same probability to every character. We can solve this problem by multiplying W2 by a constant, say 0.01, which shrinks the weights and hence the logits computed as (h @ W2 + b2). Let's look at the logits we get after changing W2 to W2 * 0.01

```python
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g)
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g)

parameters = [C, W1, W2, b2]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True
```

`11697`

After running the training loop for first iteration, let’s check the value of logits again

```python
logits, _, _ = run_training_loop(break_on_first=True)
logits[0]
```

```
      0/ 200000: 3.8155
tensor([ 0.5781, -0.8984, -0.1784,  1.5460, -1.6654, -1.3676,  0.2092,  0.8925,
        -1.1660,  0.3242,  0.0055,  0.4046,  1.0528, -0.4510,  0.5271, -0.6357,
         2.5709,  0.9907, -0.1271, -0.9460, -0.7640,  1.0952,  0.0044, -0.3396,
        -1.4397, -0.7380,  1.3038], grad_fn=<SelectBackward0>)
```

We can see that the values are closer to each other now. Now let’s run the entire training loop and check the final reported loss

```python
lossi = []
# reinitialize the weights and run the full training loop again
run_training_loop()
```

```
      0/ 200000: 3.8155
  10000/ 200000: 2.1191
  20000/ 200000: 2.6264
  30000/ 200000: 2.1872
  40000/ 200000: 2.4961
  50000/ 200000: 1.8205
  60000/ 200000: 2.1090
  70000/ 200000: 2.3420
  80000/ 200000: 2.3301
  90000/ 200000: 2.0682
 100000/ 200000: 2.2838
 110000/ 200000: 2.2106
 120000/ 200000: 1.6535
 130000/ 200000: 1.8890
 140000/ 200000: 2.0644
 150000/ 200000: 1.9606
 160000/ 200000: 1.9712
 170000/ 200000: 2.3565
 180000/ 200000: 2.0956
 190000/ 200000: 2.1049
```

```python
split_loss('train')
split_loss('val')
plt.plot(lossi)
```

```
train 2.0693342685699463
val 2.1324031352996826
[<matplotlib.lines.Line2D at 0x124c207d0>]
```

We can make two observations here

- Both training loss and validation loss are lower than previously reported
- The "hockey stick" shape of the loss curve is gone, because we now initialize W2 with more reasonable values. This gives better initial logits, so the softmax is no longer confidently wrong. The model also spends its iterations actually reducing the loss, rather than spending the early cycles squashing it from a very high value down to a reasonable one.
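The effect of logit scale on the initial loss can be reproduced in isolation. The sketch below uses synthetic logits and random targets (the names `uniform_logits`, `small_logits` and `large_logits` are illustrative assumptions, not part of the model above) to compare the cross-entropy of an untrained "model" at different logit scales.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 27
targets = torch.randint(0, vocab, (1000,))

uniform_logits = torch.zeros(1000, vocab)        # equal probability for every character
small_logits = torch.randn(1000, vocab) * 0.1    # nearly uniform predictions
large_logits = torch.randn(1000, vocab) * 10.0   # confident but uninformed predictions

loss_uniform = F.cross_entropy(uniform_logits, targets).item()
loss_small = F.cross_entropy(small_logits, targets).item()
loss_large = F.cross_entropy(large_logits, targets).item()
print(loss_uniform, loss_small, loss_large)
```

The uniform case reproduces the -log(1/27) baseline, small logits stay close to it, and large random logits are penalized heavily because the softmax puts most of its probability on the wrong character.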

## Tanh Saturation

The plot below shows what the tanh function looks like. It squashes its input into the range (-1, 1): for large positive inputs the output stays flat near 1, and for large negative inputs it stays flat near -1. In both of these flat regions the gradient is close to 0. Other activation functions have similar flat regions: sigmoid saturates on both sides, and ReLU is flat (with zero gradient) for all negative inputs.

```python
import numpy as np
x = np.linspace(-100, 100, 100)
y = np.tanh(x)
plt.plot(x, y, label='y = tanh(x)')
```
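The flat regions matter because of the local gradient of tanh, which is 1 - tanh(x)^2. As a minimal sketch on a handful of hand-picked inputs (the values are arbitrary and chosen only for illustration), we can compare what autograd computes with that formula:

```python
import torch

x = torch.tensor([-10.0, -2.0, 0.0, 2.0, 10.0], requires_grad=True)
t = torch.tanh(x)
t.sum().backward()                  # d(tanh)/dx for each element lands in x.grad

local_grad = 1 - t.detach() ** 2    # analytic derivative: 1 - tanh(x)^2
print(x.grad)
print(local_grad)
```

The gradient is 1 at x = 0 and shrinks towards 0 as |x| grows, so whatever gradient flows back from the loss is multiplied by a number close to 0 for saturated neurons.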

Now let's take a look at the activations from the first hidden layer, i.e., hpreact and the corresponding h values, and plot them to see which values actually occur.

`logits, h, hpreact = run_training_loop(break_on_first=True)`

`0/ 200000: 3.8508`

`h.shape`

`torch.Size([32, 200])`

```python
plt.hist(h.view(-1).tolist(), bins=50)
```

```
(array([2125.,  153.,  117.,   76.,   65.,   51.,   44.,   60.,   20.,
          26.,   17.,   25.,   22.,   30.,   30.,   20.,   28.,   16.,
          19.,   18.,   29.,   24.,   14.,   12.,   25.,   29.,   16.,
          20.,   28.,   21.,   24.,   18.,   14.,   22.,    6.,   38.,
          27.,   29.,   49.,   17.,   22.,   16.,   29.,   44.,   44.,
          39.,  109.,  114.,  174., 2385.]),
 array([-1.  , -0.96, -0.92, -0.88, -0.84, -0.8 , -0.76, -0.72, -0.68,
        -0.64, -0.6 , -0.56, -0.52, -0.48, -0.44, -0.4 , -0.36, -0.32,
        -0.28, -0.24, -0.2 , -0.16, -0.12, -0.08, -0.04,  0.  ,  0.04,
         0.08,  0.12,  0.16,  0.2 ,  0.24,  0.28,  0.32,  0.36,  0.4 ,
         0.44,  0.48,  0.52,  0.56,  0.6 ,  0.64,  0.68,  0.72,  0.76,
         0.8 ,  0.84,  0.88,  0.92,  0.96,  1.  ]),
 <BarContainer object of 50 artists>)
```

From the above figure, we can clearly see that values near -1 and 1 occur most frequently: 2,125 and 2,385 of the 6,400 activations fall in the two outermost bins. For all of these values there is almost no gradient when the network backpropagates, so the associated neurons learn very little from those examples.

Next, let's check which activations have |h| > 0.99

```python
plt.figure(figsize=(20, 10))
plt.imshow(h.abs() > 0.99, cmap='gray', interpolation='nearest')
```

```
<matplotlib.image.AxesImage at 0x121c60a50>
```

In the plot above, a lot of the image is white, which means many of the activations are in the flat region of tanh. We would be in a lot of trouble if any neuron had an entirely white column. That would be what we call a "dead neuron".

A dead neuron is a neuron for which none of the examples land in the active, non-flat region of tanh. Such a neuron never learns anything from the examples and is the machine learning equivalent of the biological term "brain dead".

Another interesting thing to note here is that neurons can become dead during optimization as well. For instance, when the learning rate is very high, an update can push the pre-activation (hpreact) values into the flat zone, making the neuron dead from that moment onwards.
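Rather than scanning the image by eye, the "entire column is white" condition can be checked directly. The sketch below is a hypothetical illustration on synthetic pre-activations (`hpreact_demo` is made up, and neuron 0 is deliberately forced into saturation); it is not run on the model above.

```python
import torch

torch.manual_seed(0)
batch, n_hidden_demo = 32, 200
hpreact_demo = torch.randn(batch, n_hidden_demo) * 5.0   # wide pre-activations, as with an unscaled W1
hpreact_demo[:, 0] = 20.0                                # force neuron 0 into the flat region for every example

h_demo = torch.tanh(hpreact_demo)
saturated = h_demo.abs() > 0.99                          # same boolean mask that the image shows
dead = saturated.all(dim=0)                              # True where a whole column is saturated

print("saturated fraction:", saturated.float().mean().item())
print("dead neurons:", dead.nonzero().flatten().tolist())
```

Applied to the real `h`, `dead.any()` would tell us whether any neuron is saturated for every example in the batch.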

As we have seen the issue arises due to high values of hpreact, so we can make use of the approach we used earlier and decrease the Weights W1.

```python
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * 0.2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g)

parameters = [C, W1, W2, b2]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True
```

`11697`

`logits, h, hpreact = run_training_loop(break_on_first=True)`

`0/ 200000: 3.3189`

```python
plt.hist(h.view(-1).tolist(), bins=50)
```

```
(array([160., 168., 172., 157., 165., 139., 146., 123., 116., 122., 108.,
        105.,  91., 108., 104.,  97., 110.,  83.,  99., 114., 106.,  97.,
         98.,  90., 113., 108., 111., 120., 104.,  93., 104., 130.,  98.,
        117., 108., 118., 109., 109., 121., 106., 133., 126., 140., 161.,
        156., 171., 196., 207., 238., 225.]),
 array([-9.99249458e-01, -9.59265115e-01, -9.19280772e-01, -8.79296429e-01,
        -8.39312086e-01, -7.99327743e-01, -7.59343400e-01, -7.19359057e-01,
        -6.79374714e-01, -6.39390371e-01, -5.99406028e-01, -5.59421685e-01,
        -5.19437342e-01, -4.79452999e-01, -4.39468656e-01, -3.99484313e-01,
        -3.59499969e-01, -3.19515626e-01, -2.79531283e-01, -2.39546940e-01,
        -1.99562597e-01, -1.59578254e-01, -1.19593911e-01, -7.96095681e-02,
        -3.96252251e-02,  3.59117985e-04,  4.03434610e-02,  8.03278041e-02,
         1.20312147e-01,  1.60296490e-01,  2.00280833e-01,  2.40265176e-01,
         2.80249519e-01,  3.20233862e-01,  3.60218205e-01,  4.00202549e-01,
         4.40186892e-01,  4.80171235e-01,  5.20155578e-01,  5.60139921e-01,
         6.00124264e-01,  6.40108607e-01,  6.80092950e-01,  7.20077293e-01,
         7.60061636e-01,  8.00045979e-01,  8.40030322e-01,  8.80014665e-01,
         9.19999008e-01,  9.59983351e-01,  9.99967694e-01]),
 <BarContainer object of 50 artists>)
```

```python
plt.figure(figsize=(20, 10))
plt.imshow(h.abs() > 0.99, cmap='gray', interpolation='nearest')
```

```
<matplotlib.image.AxesImage at 0x126356b50>
```

So what we did above is essentially shrink the pre-activations so that they fall in the non-flat region of tanh. The image above shows that very few of the neurons are saturated for the examples in the first iteration. Let's use this initialization for a full training run

```python
run_training_loop()
```

```
      0/ 200000: 3.3189
  10000/ 200000: 2.0652
  20000/ 200000: 2.5255
  30000/ 200000: 2.0096
  40000/ 200000: 2.2400
  50000/ 200000: 1.8046
  60000/ 200000: 2.0230
  70000/ 200000: 2.2471
  80000/ 200000: 2.3160
  90000/ 200000: 2.1416
 100000/ 200000: 2.1507
 110000/ 200000: 2.1757
 120000/ 200000: 1.5639
 130000/ 200000: 1.8486
 140000/ 200000: 2.1428
 150000/ 200000: 1.9197
 160000/ 200000: 2.0423
 170000/ 200000: 2.3999
 180000/ 200000: 2.0102
 190000/ 200000: 2.0702
```

```python
split_loss('train')
split_loss('val')
plt.plot(lossi)
```

```
train 2.0361340045928955
val 2.102936267852783
[<matplotlib.lines.Line2D at 0x12631a050>]
```

Summarizing the results so far:

- Initial results: `train 2.135850667953491`, `val 2.1770708560943604`
- Results after softmax fix: `train 2.0693342685699463`, `val 2.1324031352996826`
- Results after tanh fix: `train 2.0361340045928955`, `val 2.102936267852783`

We've managed to improve the results to a validation loss of 2.103.

## Autocalculating the scale-down factor

In the example above we used 0.2 to scale down W1. This might seem arbitrary, and it is. To arrive at a semi-principled approach, let's work out how much we need to scale down W1 using an example

```python
x = torch.randn(1000, 10)
w = torch.randn(10, 200)
y = x @ w
print(x.mean(), x.std(), y.mean(), y.std())
plt.figure(figsize=(20, 5))
plt.subplot(121)
plt.hist(x.view(-1).tolist(), bins=50, density=True);
plt.subplot(122)
plt.hist(y.view(-1).tolist(), bins=50, density=True);
```

```
tensor(0.0082) tensor(0.9978) tensor(-0.0081) tensor(3.2544)
```

The figure on the left shows `x` with std 0.9978 and the figure on the right shows `y` with std 3.2544. `y` takes on more extreme values. Let's see what happens when we multiply w by 0.2

```python
x = torch.randn(1000, 10)
w = torch.randn(10, 200) * 0.2
y = x @ w
print(x.mean(), x.std(), y.mean(), y.std())
plt.figure(figsize=(20, 5))
plt.subplot(121)
plt.hist(x.view(-1).tolist(), bins=50, density=True);
plt.subplot(122)
plt.hist(y.view(-1).tolist(), bins=50, density=True);
```

```
tensor(0.0256) tensor(1.0056) tensor(-0.0015) tensor(0.6493)
```

y now has std of 0.6493 and takes on less extreme values.

Based on the discussion in the Tanh Saturation section, we want most pre-activation values to stay roughly within [-1, 1], so we want the standard deviation of y to be about 1, the same as that of x. Working through the math, the scale factor for w comes out to 1/sqrt(fan_in), where fan_in is the number of inputs to the layer. Let's test this theory
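As a sketch of where that factor comes from, assume the inputs $x_i$ and weights $w_{ij}$ are independent and zero-mean, with $\operatorname{Var}(x_i) = 1$ and $\operatorname{Var}(w_{ij}) = \sigma_w^2$. These assumptions match the `torch.randn` experiments in this section, not necessarily a trained network. Each output is a sum of fan_in products:

$$
y_j = \sum_{i=1}^{\text{fan\_in}} x_i \, w_{ij},
\qquad
\operatorname{Var}(y_j) = \sum_{i=1}^{\text{fan\_in}} \operatorname{Var}(x_i)\,\operatorname{Var}(w_{ij}) = \text{fan\_in} \cdot \sigma_w^2
$$

Requiring $\operatorname{Var}(y_j) = 1$ gives $\sigma_w = 1/\sqrt{\text{fan\_in}}$. With fan_in = 10, unscaled weights predict a std of $\sqrt{10} \approx 3.16$ and weights scaled by 0.2 predict $0.2\sqrt{10} \approx 0.63$, which is in line with the measured values of 3.2544 and 0.6493.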

```python
x = torch.randn(1000, 10)
w = torch.randn(10, 200) * 1/(10 ** 0.5)
y = x @ w
print(x.mean(), x.std(), y.mean(), y.std())
plt.figure(figsize=(20, 5))
plt.subplot(121)
plt.hist(x.view(-1).tolist(), bins=50, density=True);
plt.subplot(122)
plt.hist(y.view(-1).tolist(), bins=50, density=True);
```

```
tensor(0.0028) tensor(1.0039) tensor(0.0020) tensor(1.0034)
```

The std of y is now roughly 1 and this is what we want while training our neural network. One paper that has delved deep into this topic and is often cited is Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification


The paper studies ReLU and leaky ReLU (PReLU) rather than tanh, but the same reasoning about keeping activation scales stable from layer to layer carries over to tanh, just with a different gain.

The paper concludes that since ReLU zeroes out inputs below 0 and only lets inputs > 0 "pass through", roughly half of the signal is discarded, and the weights need a standard deviation of *sqrt(2/fan_in)* (a gain of sqrt(2) on top of 1/sqrt(fan_in)) to compensate for that. They also show that with this scaling the backward pass is approximately taken care of as well, so we don't need any additional scaling to control gradients.
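The sqrt(2) gain can be checked numerically. The following is a minimal sketch on synthetic Gaussian data (all names and sizes are illustrative assumptions); it compares the root-mean-square of the ReLU outputs with and without the extra gain, since the second moment is the quantity that feeds into the next layer.

```python
import torch

torch.manual_seed(0)
fan_in = 100
x = torch.randn(10000, fan_in)

w_plain = torch.randn(fan_in, 200) / fan_in ** 0.5          # std = 1/sqrt(fan_in)
w_gain = torch.randn(fan_in, 200) * (2.0 / fan_in) ** 0.5   # std = sqrt(2/fan_in)

# Root-mean-square of the ReLU outputs, i.e. sqrt(E[relu(y)^2])
rms_plain = torch.relu(x @ w_plain).pow(2).mean().sqrt().item()
rms_gain = torch.relu(x @ w_gain).pow(2).mean().sqrt().item()
print(rms_plain, rms_gain)
```

Without the gain the signal shrinks to roughly 1/sqrt(2) of its input scale after a single ReLU layer, and with the gain it stays close to 1.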

## Kaiming Initialization

PyTorch implements the initialization from the above paper as `torch.nn.init.kaiming_normal_`. For tanh, `torch.nn.init.calculate_gain('tanh')` returns a gain of 5/3, which gives 5/3 * (1/sqrt(fan_in)) as the scale factor for w.

See the torch.nn.init page of the PyTorch 2.2 documentation (pytorch.org) for details.

Let's use it with our example

```python
x = torch.randn(1000, 10)
w = torch.randn(10, 200) * (5/3.0)/((10 ** 0.5))
y = x @ w
print(x.mean(), x.std(), y.mean(), y.std())
plt.figure(figsize=(20, 5))
plt.subplot(121)
plt.hist(x.view(-1).tolist(), bins=50, density=True);
plt.subplot(122)
plt.hist(y.view(-1).tolist(), bins=50, density=True);
```

```
tensor(-0.0036) tensor(1.0131) tensor(-0.0065) tensor(1.6528)
```
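For reference, roughly the same scale can be obtained from PyTorch's own helpers instead of hard-coding 5/3. This is a sketch rather than the code used in this article; the tensor shape is an assumption chosen to match `n_embd * block_size = 30` inputs and `n_hidden = 200` outputs.

```python
import torch

torch.manual_seed(0)
fan_in, fan_out = 30, 200                       # n_embd * block_size inputs, n_hidden outputs

gain = torch.nn.init.calculate_gain('tanh')     # 5/3 for tanh
manual_std = gain / fan_in ** 0.5               # the scale factor computed by hand

# kaiming_normal_ reads fan_in from the second dimension, following the
# (out_features, in_features) layout of nn.Linear weights.
w = torch.empty(fan_out, fan_in)
torch.nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='tanh')

print(gain, manual_std, w.std().item())
```

Note the layout: a weight stored as `(fan_in, fan_out)`, like `W1` in this article, would need to be transposed (or initialized with `mode='fan_out'`) for the helper to pick up the right fan.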

Using this modification in our weight initialization now and running training loop again

```python
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3.0)/((n_embd * block_size) ** 0.5)
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g)

parameters = [C, W1, W2, b2]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True
```

`11697`

```python
run_training_loop()
```

```
      0/ 200000: 3.8202
  10000/ 200000: 2.0395
  20000/ 200000: 2.5593
  30000/ 200000: 2.0449
  40000/ 200000: 2.3927
  50000/ 200000: 1.8230
  60000/ 200000: 2.0606
  70000/ 200000: 2.3983
  80000/ 200000: 2.2355
  90000/ 200000: 2.0813
 100000/ 200000: 2.1729
 110000/ 200000: 2.3369
 120000/ 200000: 1.6170
 130000/ 200000: 1.8658
 140000/ 200000: 2.0709
 150000/ 200000: 2.0079
 160000/ 200000: 1.9552
 170000/ 200000: 2.4202
 180000/ 200000: 2.0596
 190000/ 200000: 2.1027
```

```python
split_loss('train')
split_loss('val')
plt.plot(lossi)
```

```
train 2.0395455360412598
val 2.1068387031555176
[<matplotlib.lines.Line2D at 0x12af168d0>]
```

Summarizing the results so far:

- Initial results: `train 2.135850667953491`, `val 2.1770708560943604`
- Results after softmax fix: `train 2.0693342685699463`, `val 2.1324031352996826`
- Results after tanh fix: `train 2.0361340045928955`, `val 2.102936267852783`
- Results with kaiming init: `train 2.0395455360412598`, `val 2.1068387031555176`

We can see that we arrive at roughly the same position as after the tanh fix, but without using a magic number to scale down W1.

Careful initialization was crucial in the early years of neural networks, but it matters less today. With techniques like batch normalization, it's no longer as important to initialize neural networks perfectly. Let's take a look at why that's the case

## Batch Normalization


Batch normalization was introduced in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It was a very significant paper because it made training deep neural networks much more reliable. The core insight was to normalize the hidden states so that they are roughly unit Gaussian. Since we already know that we want the pre-activations to lie mostly within [-1, 1] for tanh to stay in its active region, we can directly normalize the pre-activations over the batch. Mathematically this looks like *(hpreact - hpreact.mean()) / hpreact.std()*, with the mean and standard deviation computed per neuron across the batch.

All the operations in the above expression are differentiable, so we can use it in the forward pass and let autograd compute its derivative during the backward pass.

One thing to note here is that we only want to force the pre-activations to be Gaussian at initialization, while leaving the network enough flexibility to shift and scale them during training. For this we'll make use of a learnable scale and shift, as described in the paper above.
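As a standalone illustration of the normalize-then-scale-and-shift step, here is a minimal sketch on synthetic pre-activations. The names `hpreact_demo`, `bngain_demo`, `bnbias_demo` and the small `eps` constant are assumptions made for this example; in particular, `eps` is a common numerical safeguard that the training code in this article does not use.

```python
import torch

torch.manual_seed(0)
hpreact_demo = torch.randn(32, 200) * 5.0 + 3.0      # badly scaled, shifted pre-activations

bngain_demo = torch.ones(1, 200)                     # learnable scale, starts at 1
bnbias_demo = torch.zeros(1, 200)                    # learnable shift, starts at 0
eps = 1e-5                                           # assumed constant to avoid division by zero

mean = hpreact_demo.mean(dim=0, keepdim=True)        # per-neuron mean over the batch
std = hpreact_demo.std(dim=0, keepdim=True)          # per-neuron std over the batch
normed = bngain_demo * (hpreact_demo - mean) / (std + eps) + bnbias_demo

print(normed.mean().item(), normed.std().item())
```

With the gain at 1 and the bias at 0, every neuron starts out with zero mean and unit standard deviation over the batch; during training, gradient descent is free to move both parameters away from those values.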

Let's re-initialize the parameters and redefine the training and evaluation code with batch normalization added

```python
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g)
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g)

# BatchNorm parameters
bngain = torch.ones((1, n_hidden)) # scale
bnbias = torch.zeros((1, n_hidden)) # shift

parameters = [C, W1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True
```

`12097`

```python
# same optimization as last time
max_steps = 200000
batch_size = 32
lossi = []

def run_training_loop(break_on_first=False):
    for i in range(max_steps):
        # minibatch construct
        ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
        Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

        # forward pass
        emb = C[Xb] # embed the characters into vectors
        embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
        # Linear layer
        hpreact = embcat @ W1 #+ b1 # hidden layer pre-activation
        # BatchNorm: normalize each neuron over the batch, then scale and shift
        hpreact = (hpreact - hpreact.mean(dim=0, keepdim=True))/(hpreact.std(dim=0, keepdim=True)) * bngain + bnbias
        # Non-linearity
        h = torch.tanh(hpreact) # hidden layer
        logits = h @ W2 + b2 # output layer
        loss = F.cross_entropy(logits, Yb) # loss function

        # backward pass
        for p in parameters:
            p.grad = None
        loss.backward()

        # update
        lr = 0.1 if i < 100000 else 0.01 # step learning rate decay
        for p in parameters:
            p.data += -lr * p.grad

        # track stats
        if i % 10000 == 0: # print every once in a while
            print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
        lossi.append(loss.log10().item())

        if break_on_first:
            # stop after a single step and return the intermediate tensors
            return logits, h, hpreact
```

```python
@torch.no_grad() # this decorator disables gradient tracking
def split_loss(split):
    x,y = {
        'train': (Xtr, Ytr),
        'val': (Xdev, Ydev),
        'test': (Xte, Yte),
    }[split]
    emb = C[x] # (N, block_size, n_embd)
    embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
    hpreact = embcat @ W1 # + b1
    hpreact = (hpreact - hpreact.mean(dim=0, keepdim=True))/(hpreact.std(dim=0, keepdim=True)) * bngain + bnbias
    h = torch.tanh(hpreact) # (N, n_hidden)
    logits = h @ W2 + b2 # (N, vocab_size)
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())
```

Now let’s train the model again

```python
run_training_loop()
```

```
      0/ 200000: 3.8443
  10000/ 200000: 2.0775
  20000/ 200000: 2.7031
  30000/ 200000: 2.1059
  40000/ 200000: 2.2516
  50000/ 200000: 1.9634
  60000/ 200000: 2.1270
  70000/ 200000: 2.3852
  80000/ 200000: 2.5304
  90000/ 200000: 2.1849
 100000/ 200000: 2.2979
 110000/ 200000: 2.3865
 120000/ 200000: 1.6278
 130000/ 200000: 2.0458
 140000/ 200000: 2.4157
 150000/ 200000: 2.0811
 160000/ 200000: 2.0468
 170000/ 200000: 2.4009
 180000/ 200000: 2.1332
 190000/ 200000: 2.1841
```

```python
split_loss('train')
split_loss('val')
plt.plot(lossi)
```

```
train 2.1137709617614746
val 2.1398322582244873
[<matplotlib.lines.Line2D at 0x12bc3b5d0>]
```

## Conclusion

Throughout this article we went over different approaches to keep the activations of a neural network well behaved, with roughly Gaussian pre-activations during training. We solved the "hockey stick" issue by scaling down the weights of the last layer, and the tanh saturation issue by scaling down the weights of the hidden layer.

Let's look at the results so far:

- Initial results: `train 2.135850667953491`, `val 2.1770708560943604`
- Results after softmax fix: `train 2.0693342685699463`, `val 2.1324031352996826`
- Results after tanh fix: `train 2.0361340045928955`, `val 2.102936267852783`
- Results with kaiming init: `train 2.0395455360412598`, `val 2.1068387031555176`
- Results with batch norm: `train 2.1137709617614746`, `val 2.1398322582244873`

The results obtained using batch norm are comparable to the other results, but we don't see any improvement. This is to be expected for such a simple neural network with essentially a single hidden layer. We were able to manually calculate the scale-down factor of W1, which performs better than batch norm here, but that would be tedious and difficult to do in a neural network with hundreds of hidden layers, and that's where batch norm really shines.

For larger neural networks, it is customary to follow each linear layer with a batch normalization layer to control the pre-activations at every point of the network.

Another interesting thing about batch norm is that our logits and hidden state activations are now a function not just of the input example but also of the other examples in the batch. This adds a bit of noise to each example's activations and acts as a regularizer, somewhat like "data augmentation", which makes it harder for the neural network to overfit any single example.
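The coupling between examples in a batch is easy to see in isolation. In the sketch below (synthetic data, with a hypothetical helper `batchnorm_tanh` that has no learnable scale or shift), the same example is placed into two different batches and ends up with different activations.

```python
import torch

torch.manual_seed(0)

def batchnorm_tanh(hpreact):
    # Normalize each neuron over the batch dimension, then apply tanh
    mean = hpreact.mean(dim=0, keepdim=True)
    std = hpreact.std(dim=0, keepdim=True)
    return torch.tanh((hpreact - mean) / std)

example = torch.randn(1, 5)                           # the example we track
batch_a = torch.cat([example, torch.randn(31, 5)])    # example plus 31 random batch-mates
batch_b = torch.cat([example, torch.randn(31, 5)])    # same example, different batch-mates

h_a = batchnorm_tanh(batch_a)[0]                      # activations of the tracked example in batch A
h_b = batchnorm_tanh(batch_b)[0]                      # activations of the tracked example in batch B
print(h_a)
print(h_b)
```

The two printed vectors differ even though the input row is identical, because the batch mean and standard deviation depend on the other 31 examples.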
