Multi-Layer Perceptron (MLP)/Deep Neural Networks (DNN)

Last time, we briefly introduced MNIST, and provided some performance metrics of the SLP model on classifying the digits.


A subset of MNIST characters and their associated labels

While I was pleasantly (and perhaps naively) surprised at how well the SLP performed, the classification accuracy it achieved (around 86%) is considered unacceptably low, maybe even borderline blasphemous, in the ML community. So how can we improve performance?

Deep Neural Network with Hidden Layers/Nodes

The obvious step up from an SLP is a multi-layer neural net with hidden layers, shown below:


A deep neural network for MNIST with one hidden layer

With more nodes and hidden layers, we can capture subtleties present in the input data and substantially improve performance. So how does the basic algorithm change compared to the SLP model? The answer is: not a whole lot. The game is the same; we feedforward activations to the output layer through a filtration of weighting matrices, compute a cost function that describes how good our estimate is, and backpropagate through the hidden layers to adjust the weights using gradient descent to minimize said cost function. While the approach is the same, the math and implementation become a little more involved. I'll hash out some of the main concepts below:

Feedforward Activation

To motivate feedforward activation, consider a single neuron (call it 'neuron j') in layer l of the network (so its full identifier is n_j^l, meaning it is the jth neuron in the lth layer). The activation of this neuron depends on the activations of all the neurons in the previous layer, filtered through the weighting matrix connecting the two layers. Said another way, the input to n_j^l is the weighted sum of all the activations in the previous layer (call these a_k^{l-1} ) filtered through the jth row of the weighting matrix between layers l-1 and l . The activation a_j^l is this net input, plus a bias term, fed through some activation function \sigma :

a_j^l = \sigma\left(\sum_k w_{j,k}^la_k^{l-1} +b_j^l\right)  

Note that \sigma is traditionally used to denote the sigmoidal activation function (for obvious reasons), but I’m just using it to represent any general activation function. We do this for each layer in the network, starting with the input layer, all the way through to the output layer. The MATLAB code looks like this:

% %--------------------------------------------------------------------------
% % Feedforward to compute activations (this snippet sits inside the forward
% % loop over layers, indexed by i)
% %--------------------------------------------------------------------------
net{i} = (w{i} * a{i}')';                             % net input carried forward from layer i
if i < n_L-1
    a{i+1} = activation(net{i},activation_function);  % elementwise activation of the next layer
end

Note that, in MATLAB, we use cell arrays to store network states since subsequent layers don’t necessarily have the same dimensionality (i.e. we could have a different number of nodes in each layer), and it would be tedious to hard-code network properties for each individual layer. The activation() function applies the desired activation elementwise to net{i}. Because this sequence of operations is mainly just matrix-vector products, it’s easy to vectorize for fast computation.
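For concreteness, here is a minimal sketch of how the per-layer cell arrays might be allocated for an arbitrary architecture (the layer-size vector n_nodes and the 0.01 initialization scale are illustrative assumptions, not the exact defaults in my code; bias vectors would be stored the same way):

% Illustrative setup of per-layer cell arrays (names/values are assumptions)
n_nodes = [784, 60, 60, 60, 10];   % nodes per layer: input, hidden..., output
n_L = numel(n_nodes);              % total number of layers
w   = cell(n_L-1,1);               % one weighting matrix per adjacent layer pair
net = cell(n_L-1,1);               % net inputs (dimensions differ per layer)
a   = cell(n_L,1);                 % activations (dimensions differ per layer)
for i = 1:n_L-1
    w{i} = 0.01*randn(n_nodes(i+1), n_nodes(i));   % small random initial weights
end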

As before with the SLP, the activation of the output layer (call this a^{l=n_l} ) is the network output. This is what we compare to the true output in formulating a cost function for minimization.

Cost Function

The cost function basically tells us the ‘goodness of fit’ of our model in its current state, and it is something we want to optimize. There are many cost functions we can use (see this StackOverflow link), but the most common (and easiest to work with) is the quadratic cost (equivalent to maximum likelihood estimation under a Gaussian noise assumption):

C = \frac{1}{2}\sum_j(a_j^l-y_j)^2

where l is the output layer, and y is the target output vector. The nice thing about this cost function is the straightforwardness of the derivative, which will come in handy for backpropagation:

\nabla C = (a^{l=n_l}-y)
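As a quick numerical sanity check, here is a minimal MATLAB sketch of the cost and its gradient for a single sample (the vectors a_out and y below are made-up values for illustration):

a_out = [0.10; 0.75; 0.15];         % hypothetical output-layer activations
y     = [0; 1; 0];                  % hypothetical one-hot target vector
C     = 0.5*sum((a_out - y).^2);    % quadratic cost, here C = 0.0475
gradC = a_out - y;                  % gradient of C w.r.t. the output activations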

Backpropagation

Now we’ve arrived at the meat of the algorithm. The purpose of backpropagation is to adjust the weights between each layer to minimize the cost function. To do this, we need to understand how changes in the weights and biases ultimately change the cost function: i.e.

\frac{\partial C}{\partial w_{j,k}^l} \quad \text{and} \quad \frac{\partial C}{\partial b_j^l}

This basically requires applying the chain rule ad nauseam. Michael Nielsen describes this process much better and in far more depth than I ever could, so I won’t repeat the derivation, but the salient equations are reproduced below:

\delta^{n_L} = \nabla_a C \odot \sigma'(z^{n_L})

\delta^l = \left((w^{l+1})^T\delta^{l+1}\right) \odot \sigma'(z^l)

\frac{\partial C}{\partial b_j^l} = \delta_j^l

\frac{\partial C}{\partial w_{j,k}^l} = a_k^{l-1}\delta_j^l

where \odot is the Hadamard or element-wise product (i.e. for C = A \odot B, C[i][j] = A[i][j] \cdot B[i][j] ), \delta^l is the ‘error’ at each neuron (the sensitivity of the cost to that neuron’s net input), and z^l is the net input to layer l (z^l = w^l a^{l-1} + b^l, which is what the code stores in the net cell array).

Implemented in MATLAB, this looks like the following (assuming a quadratic cost function):

% derivative of the quadratic cost w.r.t. the output activations
% (sign convention: Y - a{end}, so the weight update below uses +rate)
error_vector = (Y - a{end});

% output-layer delta: error times the activation derivative at the output
sigma = dactivation(net{end},activation_function);
delta = error_vector.*sigma;

% %--------------------------------------------------------------------------
% % Backpropagate to adjust weights in the hidden layers
% %--------------------------------------------------------------------------
for i=n_L-1:-1:1
    dw_mean = delta'*a{i};        % gradient: deltas times incoming activations, summed over the minibatch
    w{i} = w{i} + rate.*dw_mean;  % weight update
    if i ~= 1
        % propagate the deltas back through the weights to the previous layer
        sigma = dactivation(net{i-1},activation_function);
        delta = (delta*w{i}).*sigma;
    end
end

So basically, the purpose of backpropagation is to make small tweaks to each weighting matrix at each time step to drive the cost function to zero or some other (ideally global) minimum. We repeat the feedforward/backpropagation process until some convergence criterion is met.
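To tie the pieces together, here is a minimal sketch of that outer loop (feedforward(), backpropagate(), and classification_accuracy() are placeholder names for illustration, not the actual interfaces in my code, and the stopping criteria are arbitrary):

% Sketch of the training loop; the helper functions below are placeholders
max_epochs = 100;                 % arbitrary epoch limit
target_acc = 96.5;                % arbitrary target accuracy (percent)
for epoch = 1:max_epochs
    [a, net] = feedforward(X, w, a, net, activation_function);       % forward pass
    w = backpropagate(Y, a, net, w, rate, activation_function);      % backward pass + weight update
    acc = classification_accuracy(a{end}, Y);                        % evaluate current fit
    if acc >= target_acc
        break;                    % convergence criterion met
    end
end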

Momentum, Annealing and Regularization

There are a few more bells and whistles we can add to improve performance. Adding momentum, adapting the learning rate, and regularizing the weights can improve performance both in terms of runtime and avoiding overfitting the data.

Momentum: Momentum can be thought of physically as giving the parameter update a velocity, with the momentum coefficient acting like frictional damping that dissipates the system’s kinetic energy, as the following two (excellent) .gifs show (taken from here). We see that, in addition to improving convergence speed, certain adaptive algorithms can ‘kick’ the current estimate out of shallow regions of the loss surface, such as saddle points:


Different adaptation strategies and their effects on performance: (left) contours of a loss surface and time evolution of different optimization algorithms, (right) a visualization of a saddle point in the optimization landscape, where the curvature along different dimensions has different signs

There are several flavors of momentum, and in my DNN implementation, I’ve implemented regular momentum, Nesterov momentum, and ADAM (which has the attractive property of adapting the learning rate per parameter). We’ll assess performance later.
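For reference, the standard textbook forms of these update rules are written below, with learning rate \eta, momentum coefficient \mu, and gradient \nabla C (the Adam form shown here omits the usual bias-correction step, as does my implementation below):

v \leftarrow \mu v - \eta\nabla C, \qquad w \leftarrow w + v \qquad \text{(regular momentum)}

v_{prev} \leftarrow v, \qquad v \leftarrow \mu v - \eta\nabla C, \qquad w \leftarrow w - \mu v_{prev} + (1+\mu)v \qquad \text{(Nesterov)}

m \leftarrow \beta_1 m + (1-\beta_1)\nabla C, \qquad v \leftarrow \beta_2 v + (1-\beta_2)(\nabla C)^2, \qquad w \leftarrow w - \eta\, m/(\sqrt{v}+\epsilon) \qquad \text{(Adam)}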

Annealing: Annealing the learning rate over time can also improve performance. With a high learning rate, the system contains too much kinetic energy and the parameter vector bounces around chaotically, unable to settle down into deeper, but narrower parts of the loss function. So one might elect to start with a high learning rate for fast convergence, and anneal the learning rate to a lower value over time to allow the parameter vector to settle into narrower parts of the loss function. There are a few flavors, and I’ve implemented exponential and 1/t annealing.
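As a rough sketch, the two annealing schedules can be written as follows (rate_0, the decay constant k, and the epoch counter are illustrative names and values, not the exact parameters in my code):

% Minimal sketch of learning-rate annealing (illustrative names/values)
rate_0 = 0.01;                        % initial learning rate (hypothetical)
k      = 0.1;                         % decay constant (hypothetical)
switch annealing
    case 'exponential'
        rate = rate_0*exp(-k*epoch);  % decay exponentially with epoch count
    case '1/t'
        rate = rate_0/(1 + k*epoch);  % decay as 1/t
    otherwise
        rate = rate_0;                % no annealing
end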

Regularization: Regularization is introduced to prevent overfitting. It is implemented by adding a term to the loss function that penalizes large values in the weighting matrices. In the case of L2 regularization, the cost function has an additional term \frac{1}{2}\lambda w^2 . In L1 regularization, the additional term is \lambda |w| . There’s also dropout regularization, which randomly drops a subset of hidden nodes at each training iteration, so that only the resulting subnetwork is trained on that pass. Note that, during backpropagation, we have to account for the additional regularization term when taking derivatives.
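Concretely, in the L2 case the regularized cost and its weight gradient become (the corresponding extra gradient term for L1 is \lambda\,\text{sgn}(w) ):

C_{reg} = C + \frac{\lambda}{2}\sum_{l,j,k}\left(w_{j,k}^l\right)^2, \qquad \frac{\partial C_{reg}}{\partial w_{j,k}^l} = \frac{\partial C}{\partial w_{j,k}^l} + \lambda w_{j,k}^l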

I implemented a few of the momentum, annealing and regularization approaches described here.

function w_update = update_weights(w_in,delta_w,layer,momentum,regularization,rate,n_L)
% Description: This function applies momentum or per-parameter adaptation
% based on the algorithm specified in 'momentum', plus optional regularization
%
% INPUTS:
% w_in: the weighting matrix of the current iteration [nxm matrix]
% delta_w: the gradient [nxm matrix]
% layer: the current layer in the DNN [1x1 scalar]
% momentum: the desired algorithm [string literal]
% regularization: the desired regularization method [string literal]
% rate: the learning rate [1x1 scalar]
% n_L: the total number of layers in the network [1x1 scalar]
%
% OUTPUTS:
% w_update: the updated weighting matrix [nxm matrix]

% Persistent variables to preserve states between function calls
persistent v;
persistent m;
persistent v_prev;
persistent init;

% momentum hyperparameters
mu = 0.5;           % For regular/Nesterov momentum
beta1 = 0.9;        % For ADAM
beta2 = 0.99;       % For ADAM
eps = 1E-8;         % For ADAM

% If first initialization, allocate space
if ~strcmp(momentum,'none') && isempty(init)
    v = cell(n_L);
    v_prev = cell(n_L);
    m = cell(n_L);
    init = 1;
end

% If first time accessing current layer, initialize place in cell
if ~strcmp(momentum,'none') && isempty(v{layer})
    v{layer} = zeros(size(delta_w));
    v_prev{layer} = zeros(size(delta_w));
    m{layer} = zeros(size(delta_w));
end

switch momentum
    case 'regular'
        v{layer} = -mu.*v{layer}+rate.*delta_w;
        w_update = w_in + v{layer};
    case 'nesterov'
        v_prev{layer} = v{layer};
        v{layer} = -mu*v{layer} + rate.*delta_w;
        w_update = w_in + mu*v_prev{layer} + (1+mu).*v{layer};
    case 'adam'
        m{layer} = beta1.*m{layer} + (1-beta1).*delta_w;
        v{layer} = beta2.*v{layer} + (1-beta2).*(delta_w.^2);
        w_update = w_in + rate.*m{layer}./(sqrt(v{layer})+eps);
    case 'none'
        w_update = w_in + rate.*delta_w;
end

lambda = 0.0005;    % regularization strength

% Modify the weight update depending on the regularization method
% (the penalty gradient is subtracted so that large weights decay toward zero)
switch regularization
    case 'L2'
        w_update = w_update - lambda.*w_in;        % L2 penalty gradient: lambda*w
    case 'L1'
        w_update = w_update - lambda.*sign(w_in);  % L1 penalty gradient: lambda*sign(w)
    otherwise
        % no regularization
end
end

MATLAB function for applying momentum/adaptive updates and regularization to the weights

The plots below show the effect of adaptive learning rates (per parameter) on convergence time for the wdbc dataset. We see that, using ADAM, we converge to 98.5% training accuracy in roughly 1/30th of the time it takes to converge with no adaptation (0.5 seconds vs. 15 seconds). Quite a substantial improvement! It’s also interesting that, in both cases, the algorithm gets stuck at a local minimum in the parameter space (associated with 55% training accuracy), but ADAM is able to ‘kick’ it out of this minimum very quickly, while with no parameter adaptation it takes a few hundred epochs.


Adaptive parameter update on wdbc dataset: (left) no adaptation, (right) ADAM adaptation

In the following two bar plots, I’ve compared the different momentum algorithms in terms of number of epochs until convergence, with the following network properties:

  • iris Dataset (2 hidden layers, 4 nodes per layer, minibatch 12, 0.004 rate)
  • wdbc Dataset (2 hidden layers, 16 nodes per layer, minibatch 32, 0.004 rate)


Effects of rate adaptation/momentum: (left) iris dataset, (right) wdbc dataset

We see that, for both training sets, the Adam algorithm improves performance substantially over no momentum, regular momentum and Nesterov momentum.

Running the DNN on MNIST

After completing my code, I created a DNN with 3 hidden layers, each with 60 neurons, and trained it on the MNIST dataset. I just used the training set of 60,000 samples and split it up into training (42,000), testing (9,000), and validation (9,000) sets. I used minibatches of 128 samples and a learning rate of 0.005. The algorithm ran for about 46 seconds and took 20 epochs (iterations over all of the training data) to converge to a 99.2% training accuracy.


DNN MNIST classification: (left) RMSE, (right) classification error

We see from the error plots above that most of the learning takes place in the first 2 epochs or so, after which fine adjustments are made until the target accuracy is reached. Metadata for the training session is given below:

DNN Results
Data Set: MNIST
Number of Hidden Layers: 3
Number of Neurons per Hidden Layer: 60
Size of minibatch: 128
Hidden Activation Function: tanh
Output Activation Function: tanh
Adaptive Parameter Update: adam
Number of Input Features: 784
Number of Output Features/Labels: 10
Number of Training Instances: 48000
Number of total epochs: 5.066667e-02
Total Training Time: 46.328590 s
Training RMS Error: 0.001436
Test RMS Error: 0.006142
Validation RMS Error: 0.006131
Training Classification Success Rate: 99.166667 percent
Testing Classification Success Rate: 96.300000 percent
Validation Classification Success Rate: 96.366667 percent

So we get a testing accuracy of about 96.3%, which basically means that, for roughly every 30 digits in the test set, the algorithm will correctly classify 29 of them. This isn’t too bad, and the error is still decreasing, so there is an opportunity to improve this even more. For comparison, the state-of-the-art boasts a 99.7% accuracy, which doesn’t sound like much of a difference until you consider that the misclassification rate is now only about 3 out of every 1000 digits. It might as well be a world of difference.

Some test inputs are shown below. For 32 test inputs, and given our classification accuracy, we would expect one to be incorrect, which is the case. However, our algorithm can be forgiven for the misclassification (that 8 looks a lot like a 0…)


Random MNIST test inputs and associated predictions

As a final sanity check, I compared the performance of my MATLAB MLP algorithm with the performance of an identical MLP implemented in PyTorch (a popular Python-based machine learning framework):

import torch.nn as nn
import torch.nn.functional as F

class MLPNet(nn.Module):
    def __init__(self):
        super(MLPNet, self).__init__()
        self.fc1 = nn.Linear(28*28,60)
        self.fc2 = nn.Linear(60,60)
        self.fc3 = nn.Linear(60,60)
        self.fc4 = nn.Linear(60,10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.tanh(self.fc1(x))
        x = F.tanh(self.fc2(x))
        x = F.tanh(self.fc3(x))
        x = self.fc4(x)
        return x

The MNIST classification comparison between my MATLAB implementation and the PyTorch implementation is plotted below. We see that the performance of the two implementations is nearly indistinguishable. The PyTorch model was also trained for 20 epochs (total training time was about 200 seconds, a factor of 4 slower than the MATLAB implementation), reaching a final test-set classification accuracy of 96.7%.


MLP performance for MNIST classification: (left) MATLAB, (right) PyTorch

Some argue that MNIST is falling out of fashion for a variety of reasons and (speaking of fashion) have moved towards Fashion-MNIST, which replaces the digits with low-res images of clothing articles spanning 10 classes. With absolutely no tuning and the same parameters I used on the regular MNIST, I re-trained on Fashion-MNIST and achieved a testing accuracy of 87.1%. Some metadata is below:

DNN Results
Data Set: Fashion_MNIST
Number of Hidden Layers: 3
Number of Neurons per Hidden Layer: 60
Size of minibatch: 128
Activation Function: tanh
Adaptive Parameter Update: adam
Number of Input Features: 784
Number of Output Features/Labels: 10
Number of Training Instances: 42000
Number of total epochs: 10.6
Total Training Time: 3003.711874 s
Training RMS Error: 0.037384
Test RMS Error: 0.038780
Validation RMS Error: 0.038518
Training Classification Success Rate: 89.019048 percent
Testing Classification Success Rate: 87.066667 percent
Validation Classification Success Rate: 87.244444 percent

Comparing to some benchmarks here, we see that other MLP (multi-layer perceptron) algorithms are getting testing accuracies of between 87% and 88%, so we’re right on the money there. In fact, this seems to be an upper bound for vanilla MLPs. I would expect CNNs to perform much better, and they do (getting around 95% testing accuracy), so maybe we’ll revisit this dataset when I get to CNNs.

Some test inputs are shown below. In my opinion, the items that were misclassified are pretty ambiguous (that coat looks a lot like a dress to me), so it’s hard to place too much blame on the algorithm.


Random Fashion-MNIST test inputs and associated predictions

Given the results of my algorithm on MNIST and Fashion-MNIST, and comparing to benchmark results found on the web, I think it’s fair to say that my DNN works. MATLAB code can be found on my github. It’s fairly straightforward to run, and there are a number of optional parameters that can be tweaked and tuned at the command-prompt level; the following is how the training process is initialized (the README provides more details):

train_DNN(dataset,'numNodes',60,'numLayers',3,'rate',0.002,...
'minibatch',128,'activation','tanh','outputactivation', 'tanh',...
'annealing','exponential','momentum','adam','regularization','none',...
'maxEpochs',40,'targetAccuracy',96.5);

I’m about done with vanilla DNNs; I think it’s about time (foreshadowing, har har) to move on to the next thing… Recurrent Neural Nets…

 
