RNN - Recurrent Neural Networks

Recurrent Neural Networks

Intuition

For sequence modeling, inputs can vary widely in length and depend on more context than a feed-forward neural network (the simplest kind of neural network) can handle.

Contextual Example:

Japan is where I grew up, but I now live in Chicago. I speak fluent _______.

You can try to predict the missing word by looking at the words right before it (speak, fluent), but is that really sufficient? The first word of the passage, "Japan", offers the most context.

Ordering Example:

The food was good, not bad at all

The food was bad, not good at all

Counting words doesn't work here because order matters. The two sentences above contain exactly the same words in a different order, and they mean opposite things.
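To make the ordering example concrete, here is a small sketch using only the Python standard library. The helper name bag_of_words is made up for this illustration; it shows that a word-count representation cannot tell the two sentences apart.

```python
from collections import Counter

def bag_of_words(sentence):
    # Lowercase, drop commas, split on whitespace, then count each word.
    return Counter(sentence.lower().replace(",", "").split())

a = "The food was good, not bad at all"
b = "The food was bad, not good at all"

# Identical word counts, even though the sentences differ.
same_counts = bag_of_words(a) == bag_of_words(b)
same_order = a == b
```

Here same_counts is True while same_order is False, which is why a model that only counts words loses the distinction between the two reviews.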

How it works

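An RNN works like a feed-forward neural network, except that between the input and output vectors it keeps a hidden state \(h_t\) that is updated by a recurrence relation at each step (indexed by time \(t\)): \(h_t = f_W(h_{t-1}, x_t)\), where \(x_t\) is the input at step \(t\) and the same weights \(W\) are reused at every step.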

Put differently, the key difference is that an RNN consumes a sequence of inputs, one per time step, while a feed-forward network takes a single input.

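In a basic RNN, the hidden state is updated by applying the tanh (hyperbolic tangent) activation function to a weighted combination of the previous hidden state \(h_{t-1}\) and the current input \(x_t\), for example \(h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t)\). tanh is a nonlinearity that squashes each value into the range \((-1, 1)\).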

In the picture above, the network on the far left is feed-forward. The two networks on the right are RNNs.
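The recurrence can be sketched in a few lines of plain Python. This is a minimal illustration, not a reference implementation: the function names, the weight names W_xh, W_hh and b_h, and the toy sizes (2 input features, 3 hidden units) are all assumptions chosen for the example.

```python
import math

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """One recurrence step: h_t = tanh(W_hh h_prev + W_xh x_t + b_h)."""
    h_t = []
    for i in range(len(h_prev)):
        total = b_h[i]
        total += sum(W_hh[i][j] * h_prev[j] for j in range(len(h_prev)))
        total += sum(W_xh[i][k] * x_t[k] for k in range(len(x_t)))
        h_t.append(math.tanh(total))
    return h_t

def rnn_forward(inputs, h0, W_xh, W_hh, b_h):
    """Run the same step (same weights) over every element of the sequence."""
    h = h0
    states = []
    for x_t in inputs:
        h = rnn_step(x_t, h, W_xh, W_hh, b_h)
        states.append(h)
    return states

# Hypothetical toy weights: 2 input features, 3 hidden units.
W_xh = [[0.5, -0.3], [0.1, 0.8], [-0.6, 0.2]]
W_hh = [[0.1, 0.0, 0.2], [0.0, 0.3, -0.1], [0.2, -0.2, 0.1]]
b_h = [0.0, 0.0, 0.0]

sequence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three time steps
states = rnn_forward(sequence, [0.0, 0.0, 0.0], W_xh, W_hh, b_h)
```

Each entry of states is the hidden state after one time step. Because the state at step t is computed from the state at step t-1, information from earlier inputs can influence later steps.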

Training

To minimize the loss, we can use gradient descent with backpropagation. However, this is much easier in a feed-forward neural network, where one input leads to one output.

The main difficulty is that the loss must be backpropagated through every time step of the recurrence (backpropagation through time). Because the same weights are applied at each step, the gradient is a product of many similar factors, which can make it explode or vanish and leave the network with poorly trained weights.
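A scalar toy case shows why repeated multiplication is the problem. The sketch below assumes a deliberately simplified linear recurrence \(h_t = w \cdot h_{t-1}\) (no activation, no input), so the gradient of the last state with respect to the first is just \(w\) multiplied by itself once per step. The function name is made up for this illustration.

```python
def gradient_through_time(w, steps):
    # Chain rule for h_t = w * h_{t-1}: one factor of w per time step.
    grad = 1.0
    for _ in range(steps):
        grad *= w
    return grad

vanishing = gradient_through_time(0.5, 50)  # equals 0.5 ** 50, effectively zero
exploding = gradient_through_time(1.5, 50)  # equals 1.5 ** 50, hundreds of millions
```

With a factor slightly below 1 the gradient all but disappears after 50 steps, and with a factor slightly above 1 it blows up. A real RNN multiplies by matrices and activation derivatives rather than a single number, but the same compounding effect applies.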

Issues

Exploding Gradients (too high)

  • Gradients grow larger with every step back in time, so weight updates become huge and training becomes unstable
    • Solution: Use gradient clipping (rescale the gradient when its norm exceeds a threshold)
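One common form of gradient clipping rescales the whole gradient vector when its L2 norm exceeds a threshold. The sketch below is a minimal pure-Python version; the function name clip_by_norm and the threshold value are assumptions for illustration, and deep learning libraries ship their own equivalents.

```python
import math

def clip_by_norm(grads, max_norm):
    # L2 norm of the gradient vector.
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)  # small enough, leave unchanged
    scale = max_norm / norm
    # Same direction, but the norm is reduced to max_norm.
    return [g * scale for g in grads]

# The norm of [3, 4] is 5, so every component is scaled by 1/5.
clipped = clip_by_norm([3.0, 4.0], max_norm=1.0)
```

Clipping keeps the direction of the update and only limits its size, so one very large gradient cannot throw the weights far off in a single step.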

Vanishing Gradients (too low)

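  • Gradients shrink toward zero as they flow back through time, which biases the network toward short-term dependencies (it leans on recent context and loses earlier information such as "Japan" in the first example)
    • Solution 1: Choose the activation function carefully (e.g. ReLU, whose derivative is 1 for positive inputs, whereas sigmoid and tanh saturate and their small derivatives shrink the gradient)
    • Solution 2: Initialize the recurrent weights to the identity matrix and the biases to 0, to help keep the gradients from shrinking to 0
    • Solution 3: Add more logic to the \(h_t\) update by adding gates that control what information is passed through. LSTM (long short-term memory) is one example.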
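The gating idea in Solution 3 can be illustrated with a deliberately simplified scalar sketch. This is not an LSTM: it has a single "keep" gate that depends only on the input, and all names (gated_step, run, w_keep, w_cand) and values are assumptions for the example. It only shows how a gate near 1 lets old information survive many steps.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gated_step(x_t, h_prev, w_keep, w_cand):
    keep = sigmoid(w_keep * x_t)         # gate in (0, 1): share of old state kept
    candidate = math.tanh(w_cand * x_t)  # new information proposed at this step
    return keep * h_prev + (1.0 - keep) * candidate

def run(h0, inputs, w_keep, w_cand):
    h = h0
    for x_t in inputs:
        h = gated_step(x_t, h, w_keep, w_cand)
    return h

inputs = [1.0] * 20  # twenty identical time steps

# Gate close to 1: the initial state (think "Japan") is mostly preserved.
remembered = run(0.9, inputs, w_keep=6.0, w_cand=-1.0)
# Gate close to 0: the initial state is overwritten by the candidate.
forgotten = run(0.9, inputs, w_keep=-6.0, w_cand=-1.0)
```

In a trained gated network the gate values are learned and depend on both the input and the previous state, so the network itself decides what to keep and what to overwrite.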

Real World Applications

  • Google Translate
  • Sentiment classification (is a post happy or sad?)
  • Music generation
  • Self-driving cars (trajectory detection of objects)
  • Environmental modeling (predict environmental markers such as climate, global warming)