When you let a machine learn too much, it may actually get worse. It is just like us humans: we start forgetting things, or even go a little crazy, when we are forced to study excessively.
All jokes aside, I have had some chances to play with deep and (reasonably) big neural networks, and I just found out exactly what was said above.
In the first experiment, I trained a feed-forward neural network on MNIST. The net has 3 hidden layers with 1024 – 1024 – 2048 hidden units, fully connected, trained by stochastic gradient descent with momentum, L2 weight decay and a decaying learning rate. The cost function is Cross Entropy. The net is similar to the one described here, but 2 layers deeper. The numbers of errors on the training/validation/test sets are displayed in the figure below.
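To make the training setup concrete, here is a minimal NumPy sketch of the update rule used above: SGD with momentum, where L2 weight decay simply adds a term proportional to the weight to the gradient. The learning-rate schedule and all constants here are assumptions for illustration, not the exact values used in the experiment.

```python
import numpy as np

def sgd_momentum_step(w, v, grad, lr, momentum=0.9, l2=1e-4):
    """One SGD-with-momentum update; L2 weight decay adds l2 * w to the gradient."""
    g = grad + l2 * w          # L2 term pulls every weight toward zero
    v = momentum * v - lr * g  # velocity accumulates a moving average of gradients
    return w + v, v

# Toy usage: minimize f(w) = 0.5 * w^2 (whose gradient is w)
# with a simple decaying learning rate (an assumed schedule).
w, v = np.array([5.0]), np.zeros(1)
for t in range(200):
    lr = 0.1 / (1 + 0.01 * t)
    w, v = sgd_momentum_step(w, v, grad=w, lr=lr)
print(float(w[0]))  # w has shrunk close to zero
```

In a real net the same step is applied to every weight matrix, with `grad` coming from backpropagation on a mini-batch.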
It went wild, eventually. After 700*2000 = 1400000 iterations, the number of errors jumped to almost 10000, which is 99% of the test set. The more we train it, the more stupid it gets.
Looking into the first few iterations, it can be seen that the net was actually learning well. The number of errors kept steadily decreasing, and after 50*2000 = 100000 iterations, the net achieved a reasonably good performance with around 170 errors on the test set. Remember that the best result reported so far on MNIST (without applying any data augmentation trick) is 160 errors. This is just to say that obtaining a competitive result on MNIST is quite simple with a deep neural network.
But what the heck was happening after 1400000 iterations? Why did the error increase so wildly? Backpropagation (with stochastic gradient descent) is supposed to keep decreasing the error, or at least to stay at the current best state. So why did it get worse in this case?
Well, the reason is that I used L2 weight decay. After training for a while, the Cross Entropy falls into a local optimum and the net stops learning. However, the L2 weight regularization is still there, so the net just keeps pushing all the weights toward zero. At some point, this totally ruins the net.
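This failure mode is easy to see in isolation. Once the data gradient is (near) zero at a local optimum, the only remaining update is the decay term `w -= lr * l2 * w`, a pure exponential shrink. The numbers below are hypothetical, chosen only to show the effect over many iterations:

```python
# With the data gradient at zero, L2 weight decay acts alone on each weight.
lr, l2 = 0.01, 0.001   # hypothetical learning rate and decay strength
w = 0.5                # a hypothetical well-trained weight
for _ in range(1_000_000):
    w -= lr * l2 * w   # w *= (1 - lr * l2) every iteration
print(w)               # the weight has collapsed toward zero
```

After a million such steps the weight is multiplied by roughly exp(-10), so every weight in the net, and with it everything the net learned, is wiped out.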
I then performed another experiment with the same network architecture, but using dropout as the regularization method, together with another weight-norm technique. I got very similar behavior.
The net just went wild after a while. Looking into the first few steps, we can see that with dropout, the training error decreases more slowly, and the test error gets very close to 160, the best reported result. It shows that our network has been overfitted, but that different regularization methods seem to be effective in controlling the behavior of the net. Of course, in any case, we can manually stop the training process and select the “just-right” model, or even stop early and pick a good model after the very first steps.
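The "stop early and pick a good model" idea can be sketched as a simple patience rule over the validation-error curve: keep the checkpoint with the lowest error so far, and stop once it has not improved for a few evaluations. The error values below are a made-up curve shaped like the runs above, not real measurements.

```python
def early_stop(val_errors, patience=3):
    """Return the index of the checkpoint to keep: stop once the validation
    error has not improved for `patience` consecutive evaluations."""
    best_i, best = 0, float("inf")
    for i, e in enumerate(val_errors):
        if e < best:
            best_i, best = i, e       # new best checkpoint
        elif i - best_i >= patience:
            break                     # no improvement for `patience` steps
    return best_i

# Hypothetical errors per evaluation: steady improvement, then the net goes wild.
errors = [500, 300, 200, 170, 165, 168, 172, 400, 9900]
print(early_stop(errors))  # -> 4, the checkpoint with 165 errors
```

This stops long before the late-stage collapse, which is exactly the manual model selection described above, made automatic.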
The nice thing is that, thanks to the recently published library for training neural nets on GPUs called deepnet, all of those experiments did not take much time. It actually took only 5 minutes to complete the first 80*2000 = 160000 iterations, and after that many iterations the network was already well-trained.
So the moral of the story is: do not study too much, or it might drive you crazy, just like an over-trained neural net.