Recurrent Neural Networks
Intuition
For sequential modeling, the inputs can vary in length and depend on contextual information that a feed-forward neural network (the simplest kind of neural network) can't handle.
Contextual Example:
Japan is where I grew up, but I now live in Chicago. I speak fluent _______.
You can try to predict the missing word by looking at the words right before it (speak, fluent), but is that really enough to guess it? The most useful context is the very first word, "Japan", which appears much earlier in the passage.
Ordering Example:
The food was good, not bad at all
The food was bad, not good at all
Counting words doesn't work because order matters: the two sentences above contain exactly the same words, just in a different order.
How it works
An RNN works like a feed-forward neural network, except that between the input and output vectors it maintains a hidden state \(h_t\) that is updated by a recurrence relation at each time step \(t\): \(h_t = f_W(h_{t-1}, x_t)\), where the same weights \(W\) are reused at every step.
To further illustrate, the key difference between an RNN and a feed-forward NN is that an RNN consumes a sequence of inputs, rather than the single input of a feed-forward NN.
To update the hidden state programmatically, we typically use the tanh activation function, a nonlinearity that squashes values into the range \((-1, 1)\): \(h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t)\).
In the picture above, the network on the far left is feed-forward; the two networks on the right are RNNs.
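To make the recurrence concrete, here is a minimal NumPy sketch of a single RNN step unrolled over a toy sequence. The weight matrices W_hh, W_xh, W_hy, the sizes, and the random inputs are illustrative assumptions, not part of the original notes.

```python
import numpy as np

# Hypothetical sizes and randomly initialized weights, chosen only for illustration.
hidden_size, input_size, output_size = 8, 4, 3
rng = np.random.default_rng(0)
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))  # hidden -> hidden
W_xh = rng.normal(scale=0.1, size=(hidden_size, input_size))   # input  -> hidden
W_hy = rng.normal(scale=0.1, size=(output_size, hidden_size))  # hidden -> output

def rnn_step(h_prev, x_t):
    """Apply the recurrence h_t = tanh(W_hh @ h_prev + W_xh @ x_t)."""
    h_t = np.tanh(W_hh @ h_prev + W_xh @ x_t)
    y_t = W_hy @ h_t  # output at this time step
    return h_t, y_t

# Unroll over a toy sequence: the same weights are reused at every step,
# and h_t carries context forward from earlier inputs.
h = np.zeros(hidden_size)
sequence = [rng.normal(size=input_size) for _ in range(5)]
for x in sequence:
    h, y = rnn_step(h, x)
```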
Training
To minimize the loss, we can use gradient descent via backpropagation. However, this is much easier in a feed-forward neural network, where one input leads to one output.
The most notable difficulty is that an RNN must backpropagate through every time step of the recurrence (backpropagation through time). Multiplying gradients across many steps can make them explode or vanish, which leads to poor network weights.
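The geometric growth or decay is easy to see in isolation. The sketch below is a toy illustration, not actual training code: a scalar "gradient" is repeatedly multiplied by a fixed factor standing in for the Jacobian applied at every time step.

```python
# Toy illustration of backpropagation through time: the gradient is scaled
# by roughly the same factor at every step, so its magnitude changes
# geometrically with sequence length.
for factor, label in [(1.5, "exploding"), (0.5, "vanishing")]:
    grad = 1.0
    for _ in range(50):  # 50 time steps
        grad *= factor
    print(f"{label}: gradient after 50 steps ~ {grad:.3g}")
# exploding: ~6.4e+08, vanishing: ~8.9e-16
```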
Issues
Exploding Gradients (too high)
- Gradients keep growing at each step, so weight updates blow up and the network becomes more and more unstable
- Solution: Use gradient clipping (see the sketch after this list)
Vanishing Gradients (too low)
- Biases the network toward short-term dependencies (it leans on recent context and drops earlier information, such as "Japan" in the first example)
- Solution 1: Use the ReLU activation function, whose gradient is 1 for positive inputs, rather than saturating activations like sigmoid or tanh
- Solution 2: Initialize the weights to the identity matrix and the biases to 0, which helps keep the hidden state (and its gradients) from shrinking toward 0
- Solution 3: Add more logic to \(h_t\) by introducing gates that control what information is passed through; the LSTM is one example
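Following up on the exploding-gradient solution above, here is a minimal sketch of gradient clipping by global norm. The function name clip_by_global_norm, the max_norm value, and the stand-in gradient arrays are assumptions made for illustration; real frameworks provide their own clipping utilities.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale all gradients so their combined L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads

# Stand-in for gradients produced by backpropagation; global norm here is 13.
grads = [np.array([3.0, 4.0]), np.array([12.0])]
clipped = clip_by_global_norm(grads, max_norm=1.0)  # rescaled so the norm is 1.0
```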
Real World Applications
- Google Translate
- Sentence to Sentiment (Happy post, sad post?)
- Music generator
- Self-driving cars (trajectory prediction for surrounding objects)
- Environmental modeling (predict environmental markers such as climate, global warming)