Train Neural Networks with Hard Activation Functions Using Target Propagation Part 1
Author(s): Greg Short
Originally published on Towards AI.
A brief look
In the first tutorial of this series, we'll go over how to train neural networks with hard activation functions using target propagation. A key component of neural networks is their non-linear activation functions; these non-linearities are what make neural networks universal approximators. In general, activation functions need to be differentiable with informative gradients, because backpropagation (backprop) relies on gradients to determine the direction in which to update the weights. However, activation functions such as the sign function, which is flat everywhere it is differentiable, provide no gradient signal, yet they can be useful because of their increased interpretability.
As these activation functions are not easy to train with standard backprop, an alternative approach is required. One alternative is target propagation, which propagates targets backward through the network rather than error gradients. This is done by using the inverse function of each layer, or an approximation of it. The basic formula for target propagation is

$$t_{i-1} = f_i^{-1}(t_i)$$

where t is the target, i is the layer index, f is the layer operation, and f^{-1} is the inverse.
By doing this, one can bypass the need for continuously differentiable activation functions in a neural network.
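To make this concrete, here is a minimal sketch of that backward pass in plain PyTorch. The layers list and the inverse() method are hypothetical stand-ins for whatever exact or approximate inverse each layer provides:

import torch

def propagate_targets(layers: list, t_final: torch.Tensor) -> list:
    """Propagate a target backward through the network: t_{i-1} = f_i^{-1}(t_i)."""
    targets = [t_final]
    for layer in reversed(layers):
        # hypothetical inverse(): an exact inverse, a pseudoinverse,
        # or a learned decoder, depending on the layer
        targets.insert(0, layer.inverse(targets[0]))
    return targets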
Today, we'll cover the following:
- Hard activation functions
- Target propagation concept
- Target propagation implementation
- Target propagation evaluation
We'll use the framework Zenkai to implement target propagation. It's a framework built on PyTorch that supports learning approaches other than backprop. The code for this tutorial can be accessed here under notebooks/t2_targetpropagation, and the code for Zenkai, used in this tutorial, can be accessed here.
Hard activation functions
We'll use two hard activation functions: the sign function and a stochastic activation function. Without modification, these functions do not work well with backprop because their derivative is zero wherever it is defined and undefined at the discontinuities.
Sign Activation Function
The sign function outputs -1 if the input is less than 0, 0 if it equals 0, and 1 if it is greater than 0:

$$\operatorname{sign}(x) = \begin{cases} -1 & \text{if } x < 0 \\ 0 & \text{if } x = 0 \\ 1 & \text{if } x > 0 \end{cases}$$
Though technically it can output 0, inputs of exactly 0 are unlikely to be observed in intermediate layers, so it is effectively binary. Fig. 1 shows a graph of the sign function.
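For reference, a minimal PyTorch version of this activation could look like the following (a sketch; the module name is our own, not from the tutorial's code):

import torch
import torch.nn as nn

class Sign(nn.Module):
    """Outputs -1, 0, or 1 elementwise, matching the sign definition above."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sign(x)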
Stochastic Activation Function
There can be a variety of stochastic activation functions, but the one we'll use today is based on the Heaviside step function, which outputs 0 for all negative inputs and 1 for all non-negative inputs. Our activation will randomly output either 0 or 1 based on the value of the input.
To make the function stochastic, the input is first passed through a sigmoid. Then a threshold between 0 and 1 is sampled uniformly at random. If the sigmoid output is less than the threshold, the function outputs 0; otherwise, it outputs 1. This means that if the output of the sigmoid is 0.1, there is a 10% chance of outputting a 1 and a 90% chance of outputting a 0. The formula is

$$f(x) = \begin{cases} 1 & \text{if } \sigma(x) \ge r \\ 0 & \text{if } \sigma(x) < r \end{cases}$$

where σ(x) is the standard sigmoid function and r ~ U(0, 1).
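A minimal PyTorch sketch of this activation (again, the module name is our own):

import torch
import torch.nn as nn

class Stochastic(nn.Module):
    """Outputs 1 with probability sigmoid(x), else 0."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        p = torch.sigmoid(x)
        r = torch.rand_like(p)  # r ~ U(0, 1), sampled per element
        return (p >= r).float()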
What is target propagation?
Target propagation is an alternative to backprop that propagates targets backward through each layer of a neural network. It uses the inverse of the layer to get the targets for the previous layer. Since in most cases a layer will not be invertible, an approximation can be used instead.
The options for this are:
- the inverse, if the layer is invertible;
- an inverse approximation (such as the pseudoinverse), which can be used for linear transformations or other cases where an approximation to the inverse is defined (see the sketch after this list);
- a learned inverse approximation (such as an autoencoder), which requires learning the parameters of the approximation.
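As an illustration of the pseudoinverse option, here is a sketch of a target computation for a linear layer using torch.linalg.pinv. This is an illustration only, not Zenkai's API, and the function name is ours:

import torch
import torch.nn as nn

def pinv_step_x(linear: nn.Linear, t: torch.Tensor) -> torch.Tensor:
    """Map a target back through y = x @ W.T + b using the
    Moore-Penrose pseudoinverse of the weight matrix."""
    return (t - linear.bias) @ torch.linalg.pinv(linear.weight).T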
Target propagation is also considered to be more biologically plausible than backprop, as our brains are thought not to propagate gradients backward through multiple layers.
Today, we'll be using a learned approximation, but the other two could also be implemented with Zenkai. Fig. 3 is a diagram of what target propagation looks like. The network is bidirectional up until the final layer, which consists only of a feedforward component.
As the mapping from HN to Y should focus on predicting the output, we will still use backprop to get the targets for HN. The target for HN is obtained by subtracting the gradient on HN's output from the output itself:

$$t_n = y_n - lr \cdot dy_n$$

where t is the target for the hidden layer, y is the output, lr is the learning rate, dy is the gradient of the loss with respect to the output, and n is the index of the last layer in the network. The learning rate can be used to control how far the target is from the output, which helps stabilize training. We'll call this parameter 'x_lr' below.
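In code, this amounts to a single gradient step on the last layer's output. Here is a sketch, assuming y is part of the autograd graph of the loss; the function name is ours:

import torch

def last_layer_target(y: torch.Tensor, loss: torch.Tensor, x_lr: float) -> torch.Tensor:
    """Compute t_n = y_n - x_lr * dy_n using autograd."""
    dy, = torch.autograd.grad(loss, y, retain_graph=True)
    return (y - x_lr * dy).detach()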
How we'll implement target propagation
In this implementation, we'll use 4 layers.
- Layers 1–3: Autoencoders that learn to reconstruct the input, to make use of target propagation.
- Layer 4: A regular feedforward network for the final layer.
Layers 1β3 use autoencoders. Autoencoders are neural networks that learn how to reconstruct an input. They consist of two components: a feedforward network that maps the input onto an encoding and a feedback network that maps the encoding to a reconstruction of the input. The feedforward component will predict the target that is propagated backward, and the reconstruction component will predict the input to the layer from the output of the layer.
Typically, an autoencoder uses a regularization or constraint, such as 1) using fewer hidden units than input units, 2) applying L1 or L2 regularization to the encodings, or 3) adding noise, such as dropout, to the inputs.
We'll implement target propagation with Zenkai, a framework built on PyTorch that makes it easier to train neural networks or deep learning machines that do not rely on gradient descent, such as target propagation. Zenkai makes use of PyTorch's computational graph, but on the backward pass it calls the accumulate(), step_x(), and step() methods on the LearningMachine to give the user more control over the learning process. Which methods are called, and in what order, can be altered by changing the learning mode, but today we'll have them executed in the following order: accumulate, step_x, step.
The code below shows the AutoencoderLearner we will build with Zenkai. The AutoencoderLearner inherits from nn.Module, so it contains the forward() method. In addition, it has three core methods for learning that standard PyTorch modules lack: accumulate(), which accumulates updates to the parameters; step_x(), which computes the target for the preceding layer; and step(), which applies the accumulated updates. For target propagation, step_x() is the key component, as it passes the target through the feedback component of the autoencoder to calculate the target for the preceding layer.
import torch
from torch import Tensor

import zenkai
from zenkai import IO, State, iou


class AutoencoderLearner(zenkai.LearningMachine):
    """
    Use this to train an autoencoder. The step() method uses torch's
    optimizer to update the parameters.
    Note: Some of the initialization is not described here.
    """
    def __init__(self, in_features: int, out_features: int, ...):
        """
        The init function creates the network based on the parameters
        passed in, including the optimizer.
        Layer is an nn.Module that wraps nn.Linear, the activation, etc.
        """
        super().__init__()
        self.feedforward = Layer(
            in_features, out_features,
            ...
        )
        self.feedback = Layer(
            out_features, in_features, ...
        )
        # instantiate activation functions, rec_weight, etc.
        self._criterion = zenkai.NNLoss('MSELoss', reduction='mean')
        self._reverse_criterion = zenkai.NNLoss(rec_loss, reduction='mean')
        # The optimizer is used by the step() method to update the parameters
        self._optim = torch.optim.Adam(self.parameters(), lr=...)

    def accumulate(self, x: IO, t: IO, state: State):
        """
        Accumulate the gradients for the feedforward
        and feedback models.
        Args:
            x (IO): the input
            t (IO): the target
            state (State): the learning state
        """
        # reconstruct the input from the layer's output
        z = self.feedback(state._y.f)
        z_loss = self._reverse_criterion.assess(x, iou(z))
        t_loss = self._criterion.assess(state._y, t)
        # accumulate gradients for the prediction and reconstruction losses
        (t_loss + self.rec_weight * z_loss).backward()

    def step_x(self, x: IO, t: IO, state: State) -> IO:
        """Propagate the target through the feedback component
        to obtain the target for the preceding layer.
        Args:
            x (IO): the input
            t (IO): the target
            state (State): the learning state
        Returns:
            IO: the target for the incoming layer
        """
        # do target propagation by passing the target
        # through the feedback component
        return iou(self.feedback(t.f))

    def step(self, x: IO, t: IO, state: State):
        """
        Update the parameters based on the updates
        accumulated in the accumulate() method.
        Args:
            x (IO): the input
            t (IO): the target
            state (State): the learning state
        """
        # apply and then clear the accumulated updates
        self._optim.step()
        self._optim.zero_grad()

    def forward_nn(self, x: IO, state: State) -> Tensor:
        """Obtain the output of the learner.
        The forward() method will call this method.
        Args:
            x (IO): The input
            state (State): The learning state
        Returns:
            Tensor: The output of the function
        """
        return self.feedforward(x.f)
Since this is implemented with PyTorch, it is possible to connect the AutoencoderLearner to normal PyTorch modules.
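For instance, assuming forward() accepts and returns plain tensors (as forward_nn() above suggests), the feedforward pass of a stack of learners composes like any other PyTorch model. This is a sketch with placeholder arguments, not Zenkai's exact composition API; the backward target pass is wired up by the learning mode rather than shown here.

import torch
import torch.nn as nn

layer1 = AutoencoderLearner(784, 300)  # placeholder constructor arguments
layer2 = AutoencoderLearner(300, 300)
head = nn.Linear(300, 10)              # a plain PyTorch module on top

x = torch.randn(128, 784)              # a dummy minibatch
y = head(layer2(layer1(x)))            # standard PyTorch forward pass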
How to handle hard activation functions
One problem with using hard activation functions is that the feedback network's reconstruction activation should match the activation applied to the values it is trying to reconstruct. If that activation is a hard activation function, we run into the problem that its gradient is 0, which is the very problem we aimed to overcome in the first place.
Here are some conditions to consider:
- The feedback network must use a non-linear output activation function for each layer.
- The feedback network's output activation should match the corresponding feedforward layer's output activation.
- The feedforward network's output activation must not be a hard activation function. Otherwise, we would not eliminate the problem we are trying to solve, because we would still have to propagate gradients through a hard activation function that has no usable gradients.
Let's consider how to do this for the two hard activation functions we're using today.
Case 1: Sign
In the case of the sign function, let's use TanH on the outputs of a layer and Sign on the inputs of the following layer. For the feedforward network, this is mathematically equivalent, since sign(tanh(x)) = sign(x), but it allows the feedback network to train through TanH rather than a hard activation.
So, the forward component becomes
Sign => Dropout => Linear => BatchNorm => TanH
And the reverse component will be
Linear => BatchNorm => TanH
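In PyTorch, using the Sign module sketched earlier, the two components might look like this (sizes and dropout rate are placeholders):

import torch.nn as nn

forward_component = nn.Sequential(
    Sign(),              # hard activation applied to the previous layer's TanH output
    nn.Dropout(0.1),
    nn.Linear(300, 300),
    nn.BatchNorm1d(300),
    nn.Tanh(),
)
feedback_component = nn.Sequential(
    nn.Linear(300, 300),
    nn.BatchNorm1d(300),
    nn.Tanh(),           # matches the feedforward component's output activation
)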
Case 2: Stochastic
For the stochastic case, let's use LeakyReLU on the outputs of a layer.
Stochastic => Dropout => Linear => BatchNorm => LeakyReLU
And for the reverse component
Linear => BatchNorm => LeakyReLU
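Analogously, with the Stochastic module from earlier (again a sketch with placeholder sizes):

import torch.nn as nn

forward_component = nn.Sequential(
    Stochastic(),
    nn.Dropout(0.05),
    nn.Linear(300, 300),
    nn.BatchNorm1d(300),
    nn.LeakyReLU(),
)
# the feedback component mirrors the sign case,
# with LeakyReLU in place of TanH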
These approaches allow both the reverse component and the forward component to be non-linear functions while still making use of hard activation functions.
Evaluation
Next, let's evaluate target propagation. We will train four models: 1) a baseline using LeakyReLU and backprop, 2) one using target propagation with LeakyReLU for comparison, 3) one using target propagation with the sign activation, and 4) one using target propagation with the stochastic activation.
The machine architecture is given below. The baseline network will use the same number of units as the feedforward component of the target propagation network. One other thing to note is that the last layer does not use a hard activation on the input.
Machine Architecture:
- Layers:
– Layer 1: 300 units
– Layer 2: 300 units
– Layer 3: 300 units
– Layer 4: 10 units
- Dataset: FashionMNIST
- Activations:
– LeakyReLU
– Sign
– Stochastic
- Layer 4 x_lr: 1e-3
- Dropout rate:
– Layer 1: 0.1
– Layers 2, 3: 0.05
– Layer 4: 0.0
- Learning Algorithm Layers 1–3: Target Propagation with Adam
- Learning Algorithm Layer 4: Backprop with Adam
Here are the training parameters for the networks.
Training Parameters:
- Epochs: 20
- Learning rate: 1e-3
- Minibatch size: 128
- LR Scheduler (for autoencoder layers; see the sketch after this list)
– Step size: 40 iterations
– gamma: 0.9
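For reference, this schedule corresponds to PyTorch's StepLR wrapped around the Adam optimizer from the learner above (a sketch; params stands in for the autoencoder's parameters):

import torch

optim = torch.optim.Adam(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=40, gamma=0.9)
# calling scheduler.step() every iteration decays the learning
# rate by a factor of 0.9 every 40 iterations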
The plot of the training loss is shown in Fig. 4. You can see that the target propagation networks were able to learn, but the backprop-based learner using LeakyReLU learns much faster. This is to be expected: the target propagation network relies on fairly accurate reconstructions of each layer's inputs, and LeakyReLU outputs have higher representational capacity (i.e., more possible bits of information) than the other two activations. So the target propagation networks tend to take longer to start learning and are affected by inaccuracies in the reconstruction. You can also see that the loss is considerably noisier for the target propagation networks than for the baseline.
Here are the classification results on the test set after 20 epochs of training. The target propagation networks did not reach the level of the baseline. Since the training loss has not converged, better results can likely be achieved by training for more iterations.
- Baseline: 0.85
- TP LeakyReLU: 0.7876
- TP Sign: 0.7118
- TP Stochastic: 0.6723
Closing
We have now trained networks that use hard activation functions with target propagation and compared the results to a baseline network trained with backprop using LeakyReLU activations. Networks using the sign activation function and the stochastic activation function were both trained. The performance of LeakyReLU with backprop was decidedly better after 20 epochs, but learning was demonstrated with both the sign activation and the stochastic activation.
Target propagation does have some downsides. Primarily, with the implementation given here, it requires training two networks instead of one. Secondly, it can be difficult to train: if the reconstruction is poor, training can be hard to stabilize. While we've focused here on using target propagation with hard activation functions, its applications go beyond that, and we'll look into them in future posts. In the next article, we will look at an improvement to target propagation called difference target propagation.
References
[1] Bengio, Yoshua, et al., Towards Biologically Plausible Deep Learning (2015), http://arxiv.org/abs/1502.04156
[2] Short, Greg, Zenkai - Framework For Exploring Beyond Backpropagation (2023), https://arxiv.org/pdf/2311.09663