Recurrent#
This page explains the concept of a recurrent layer.
The key idea is to create a mechanism where each input affects the processing and outcome of subsequent inputs.
An RNN is essentially a single layer that is applied repeatedly along a sequence. At each step it combines the current input with \(h_{t-1}\), a state vector carried over from the previous step.
Strictly speaking, the state is updated as follows:

\[
h_t = f\left(x_t W_1^T + b_1 + h_{t-1} W_2^T + b_2\right)
\]

Where:
\(x_t\): input at the \(t\)-th step.
\(h_t\): vector that describes hidden state at the \(t\)-th step.
\(W_1\): weights associated with the input.
\(W_2\): weights associated with the state.
\(b_1\): bias associated with the input.
\(b_2\): bias associated with the state.
\(f\): activation function, typically a hyperbolic tangent.
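For intuition, here is a minimal sketch of a single recurrence step in PyTorch; the sizes and variable names are illustrative assumptions, not part of the example developed below.

import torch

input_size, hidden_size = 5, 3               # illustrative sizes
x_t = torch.rand(input_size)                 # input at step t
h_prev = torch.zeros(hidden_size)            # previous state h_{t-1}
W_1 = torch.rand(hidden_size, input_size)    # input weights
W_2 = torch.rand(hidden_size, hidden_size)   # state weights
b_1 = torch.rand(hidden_size)                # input bias
b_2 = torch.rand(hidden_size)                # state bias

# h_t = f(x_t W_1^T + b_1 + h_{t-1} W_2^T + b_2), with f = tanh
h_t = torch.tanh(x_t @ W_1.T + b_1 + h_prev @ W_2.T + b_2)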
Realization in Python#
In this section, we implement step by step the computations performed by a recurrent layer and compare them with torch.nn.RNN
as the reference.
The following cell defines the parameters of the recurrent procedure that we will use as an example.
We have:
A sequence of \(10\) elements: \(\left\{ x_1, x_2, \dots, x_{10} \right\}\).
Each element is a vector of \(5\) elements: \(x_t \in \mathbb{R}^5\).
We are working with a sample containing \(15\) sequences.
The state vector is a \(3\)-element vector: \(h_t \in \mathbb{R}^3\).
The activation function \(f(x)\) is the hyperbolic tangent.
import torch
samples_size = 15
element_size = 5
sequence_size = 10
state_size = 3
activation = torch.nn.Tanh()
input_data = torch.rand(samples_size, sequence_size, element_size)
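As a quick sanity check, the shape of the generated input reflects these choices: \(15\) samples, \(10\) sequence steps, and \(5\) features per element.

input_data.shape
torch.Size([15, 10, 5])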
For given inputs:
Input weights \(W_1\) should be a \(3 \times 5\) matrix.
Input bias \(b_1\) should be a vector with \(3\) elements.
State weights \(W_2\) should be a \(3 \times 3\) matrix.
State bias \(b_2\) should be a vector with \(3\) elements.
W_1 = torch.rand(state_size, element_size)
b_1 = torch.rand(state_size)
W_2 = torch.rand(state_size, state_size)
b_2 = torch.rand(state_size)
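Their shapes can be checked against the list above:

W_1.shape, b_1.shape, W_2.shape, b_2.shape
(torch.Size([3, 5]), torch.Size([3]), torch.Size([3, 3]), torch.Size([3]))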
The computation of \(x_1 W^T_1 + b_1 + h_{0} W^T_2 + b_2\), where the initial state is initialized to zeros (\(h_0 = \mathbf{0}\) for every sample), takes the following form:
state = torch.zeros(samples_size, state_size)
(input_data[:, 0, :] @ W_1.T) + b_1 + (state @ W_2.T) + b_2
tensor([[1.0187, 0.5179, 2.1283],
[2.3602, 1.7332, 3.0126],
[1.7238, 1.3923, 2.8326],
[1.6351, 1.1317, 2.6059],
[1.5822, 1.4294, 3.1752],
[1.5521, 0.8518, 2.3209],
[2.0734, 1.5617, 2.4171],
[2.1795, 1.8786, 3.2781],
[1.9290, 1.4330, 3.1887],
[1.8527, 1.4238, 3.0310],
[1.9417, 1.5560, 2.9881],
[1.5617, 0.9667, 2.4146],
[2.4599, 2.0704, 3.4121],
[2.0325, 1.6429, 3.2050],
[1.4736, 0.9799, 2.1185]])
As a result, we get the pre-activation values of \(h_1\) for each of the \(15\) samples; applying \(f\) to this matrix yields \(h_1\) itself.
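A minimal sketch of that final step, reusing the tensors defined above, would be:

h_1 = activation(
    (input_data[:, 0, :] @ W_1.T) + b_1 + (state @ W_2.T) + b_2
)
h_1.shape
torch.Size([15, 3])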
The implementation of the full recurrent procedure for all \(10\) elements of the sequence is provided in the following cell:
states = [state]
for i in range(input_data.shape[1]):
    res = activation(
        (input_data[:, i, :] @ W_1.T) + b_1
        + (states[-1] @ W_2.T) + b_2
    )
    states.append(res)
As a result, we have \(11\) states (including the initial one). Each state is a matrix with \(15\) rows, one row per sample.
len(states)
11
states[3]
tensor([[0.9638, 0.9522, 0.9989],
[0.9718, 0.9704, 0.9994],
[0.9982, 0.9973, 0.9999],
[0.9974, 0.9940, 0.9999],
[0.9912, 0.9944, 0.9998],
[0.9909, 0.9858, 0.9996],
[0.9976, 0.9981, 0.9999],
[0.9956, 0.9934, 0.9998],
[0.9956, 0.9870, 0.9998],
[0.9964, 0.9967, 0.9999],
[0.9839, 0.9770, 0.9994],
[0.9890, 0.9855, 0.9996],
[0.9906, 0.9791, 0.9996],
[0.9985, 0.9980, 0.9999],
[0.9965, 0.9937, 0.9997]])
Then we take all first rows as the states of the first sample, all second rows as the states of the second sample, and so on. This can be realized by stacking the states along dimension \(-2\). Torch also returns the last states of all samples separately, so we keep them as a second output.
my_ans = (
    torch.stack(states[1:], dim=-2),
    states[-1][None]
)
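The shapes of the two parts follow the conventions of torch.nn.RNN with batch_first=True: the stacked states are (samples, sequence length, state size), and the last state carries a leading dimension of 1 for the single layer.

my_ans[0].shape, my_ans[1].shape
(torch.Size([15, 10, 3]), torch.Size([1, 15, 3]))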
Now, we will do the same using the ready-made torch.nn.RNN
class. Before computing, we need to set the weights of the instance to match those used in the custom procedure:
rnn = torch.nn.RNN(element_size, state_size, batch_first=True)
with torch.no_grad():
    # Copying the parameters that were used earlier:
    rnn.weight_ih_l0.copy_(W_1)
    rnn.bias_ih_l0.copy_(b_1)
    rnn.weight_hh_l0.copy_(W_2)
    rnn.bias_hh_l0.copy_(b_2)
torch_ans = rnn(input_data)
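torch.nn.RNN returns the same pair: the states for every step first and the final state second, with shapes matching the manually assembled answer.

torch_ans[0].shape, torch_ans[1].shape
(torch.Size([15, 10, 3]), torch.Size([1, 15, 3]))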
The following cell verifies that both outputs are identical.
torch.testing.assert_close(torch_ans[0], my_ans[0])
torch.testing.assert_close(torch_ans[1], my_ans[1])