Recurrent#

This page explains the concept of a recurrent layer.

The key idea is a mechanism in which each input influences how the subsequent inputs are processed.

An RNN is essentially a single layer that is applied repeatedly along the sequence. At step \(t\), in addition to the input \(x_t\), it receives \(h_{t-1}\), the hidden state vector produced at the previous step.

Formally, the state update is defined as follows:

\[h_t = f(x_t W^T_1 + b_1 + h_{t-1} W^T_2 + b_2)\]

Where:

  • \(x_t\): input at the \(t\)-th step.

  • \(h_t\): vector that describes hidden state at the \(t\)-th step.

  • \(W_1\): weights associated with the input.

  • \(W_2\): weights associated with the state.

  • \(b_1\): bias associated with the input.

  • \(b_2\): bias associated with the state.

  • \(f\): activation function, typically a hyperbolic tangent.
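
As a quick illustration of the formula above, the following cell sketches a single recurrent step. This is only a minimal sketch: the sizes (an input of \(4\) elements and a state of \(2\) elements) and all variable names are made up for illustration and are not the ones used later on this page.

import torch

input_size, hidden_size = 4, 2

x_t = torch.rand(input_size)                # input at step t
h_prev = torch.zeros(hidden_size)           # state from the previous step, h_{t-1}
W_1 = torch.rand(hidden_size, input_size)   # input weights
b_1 = torch.rand(hidden_size)               # input bias
W_2 = torch.rand(hidden_size, hidden_size)  # state weights
b_2 = torch.rand(hidden_size)               # state bias

# h_t = f(x_t W_1^T + b_1 + h_{t-1} W_2^T + b_2), with f = tanh
h_t = torch.tanh(x_t @ W_1.T + b_1 + h_prev @ W_2.T + b_2)
h_t.shape  # a vector with hidden_size elements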

Realization in Python#

In this section, we implement, step by step, the computations performed by a recurrent layer and compare them with torch.nn.RNN as the reference.

The following cell defines the parameters of the recurrent procedure that we will use as an example.

We have:

  • A sequence of \(10\) elements: \(\left\{ x_1, x_2, \dots, x_{10} \right\}\).

  • Each element is a vector of \(5\) elements: \(x_t \in \mathbb{R}^5\).

  • We are working with a batch containing \(15\) sequences.

  • The state vector is a \(3\)-element vector: \(h_t \in \mathbb{R}^3\).

  • The activation function \(f(x)\) is the hyperbolic tangent.

import torch

samples_size = 15    # number of sequences in the batch
element_size = 5     # size of each element x_t
sequence_size = 10   # number of elements in each sequence
state_size = 3       # size of the hidden state h_t
activation = torch.nn.Tanh()

# (batch, sequence, element) layout, matching batch_first=True in torch.nn.RNN
input_data = torch.rand(samples_size, sequence_size, element_size)

For given inputs:

  • Input weights \(W_1\) should be a \(3 \times 5\) matrix.

  • Input bias \(b_1\) should be a vector with \(3\) elements.

  • State weights \(W_2\) should be a \(3 \times 3\) matrix.

  • State bias \(b_2\) should be a vector with \(3\) elements.

W_1 = torch.rand(state_size, element_size)
b_1 = torch.rand(state_size)
W_2 = torch.rand(state_size, state_size)
b_2 = torch.rand(state_size)

The computation of \(x_1 W^T_1 + b_1 + h_0 W^T_2 + b_2\), with the initial state initialized to zeros (\(h_0 = \mathbf{0} \in \mathbb{R}^3\) for every sample), takes the following form:

# h_0: a zero state for each of the samples
state = torch.zeros(samples_size, state_size)
(input_data[:, 0, :] @ W_1.T) + b_1 + (state @ W_2.T) + b_2
tensor([[1.0187, 0.5179, 2.1283],
        [2.3602, 1.7332, 3.0126],
        [1.7238, 1.3923, 2.8326],
        [1.6351, 1.1317, 2.6059],
        [1.5822, 1.4294, 3.1752],
        [1.5521, 0.8518, 2.3209],
        [2.0734, 1.5617, 2.4171],
        [2.1795, 1.8786, 3.2781],
        [1.9290, 1.4330, 3.1887],
        [1.8527, 1.4238, 3.0310],
        [1.9417, 1.5560, 2.9881],
        [1.5617, 0.9667, 2.4146],
        [2.4599, 2.0704, 3.4121],
        [2.0325, 1.6429, 3.2050],
        [1.4736, 0.9799, 2.1185]])

Applying the activation \(f\) to this result gives \(h_1\) for each of the \(15\) samples.

The implementation of the full recurrent procedure for all \(10\) elements of the sequence is provided in the following cell:

states = [state]

# Iterate over the sequence dimension, computing h_t from x_t and h_{t-1}
for i in range(input_data.shape[1]):
    res = activation(
        (input_data[:, i, :] @ W_1.T) + b_1
        + (states[-1] @ W_2.T) + b_2
    )
    states.append(res)

As a result, we have \(11\) states (including the initial one). Each state is a matrix with \(15\) rows, one row per sample.

len(states)
11
states[3]
tensor([[0.9638, 0.9522, 0.9989],
        [0.9718, 0.9704, 0.9994],
        [0.9982, 0.9973, 0.9999],
        [0.9974, 0.9940, 0.9999],
        [0.9912, 0.9944, 0.9998],
        [0.9909, 0.9858, 0.9996],
        [0.9976, 0.9981, 0.9999],
        [0.9956, 0.9934, 0.9998],
        [0.9956, 0.9870, 0.9998],
        [0.9964, 0.9967, 0.9999],
        [0.9839, 0.9770, 0.9994],
        [0.9890, 0.9855, 0.9996],
        [0.9906, 0.9791, 0.9996],
        [0.9985, 0.9980, 0.9999],
        [0.9965, 0.9937, 0.9997]])

Then we take the first rows of all states as the per-step states of the first sample, the second rows as those of the second sample, and so on. This can be realized by stacking along the \(-2\) dimension. Torch also returns the final states of all samples separately, so we keep them as the second output.

my_ans = (
    torch.stack(states[1:], dim=-2),  # per-step states: (samples, sequence, state)
    states[-1][None]                  # final state with an extra leading dimension
)
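
As a quick sanity check (this cell is not strictly required), we can inspect the shapes of the assembled outputs: stacking the \(10\) per-step states should give a \((15, 10, 3)\) tensor, and the final state with an extra leading dimension should be \((1, 15, 3)\), matching the layout returned by torch.nn.RNN with batch_first=True.

# Expected: (torch.Size([15, 10, 3]), torch.Size([1, 15, 3]))
my_ans[0].shape, my_ans[1].shape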

Now, we will do the same using the ready-made torch.nn.RNN class. Before computing, we need to set the weights of the instance to match those used in the custom procedure:

rnn = torch.nn.RNN(element_size, state_size, batch_first=True)

with torch.no_grad():
    # Copying the parameters that were used earlier:
    rnn.weight_ih_l0.copy_(W_1)
    rnn.bias_ih_l0.copy_(b_1)
    rnn.weight_hh_l0.copy_(W_2)
    rnn.bias_hh_l0.copy_(b_2)

    torch_ans = rnn(input_data)

The following cell verifies that both outputs match; torch.testing.assert_close raises an error if the tensors differ beyond the default tolerances.

torch.testing.assert_close(torch_ans[0], my_ans[0])
torch.testing.assert_close(torch_ans[1], my_ans[1])