Sequence Processing

1 List representation

xs = [
    {"name": "Joe", "grade": 80},
    {"name": "Jack", "grade": 85},
    {"name": "Jill", "grade": 90},
]
xs
[{'name': 'Joe', 'grade': 80},
 {'name': 'Jack', 'grade': 85},
 {'name': 'Jill', 'grade': 90}]
import torch
tensors = torch.tensor([
    [1, 80.],
    [2, 85.],
    [3, 90.],
])
tensors, tensors.shape
(tensor([[ 1., 80.],
         [ 2., 85.],
         [ 3., 90.]]),
 torch.Size([3, 2]))
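
The tensor encodes the same records numerically: the first column is an integer id standing in for each name and the second column is the grade. A small illustration (the name_to_id lookup table below is my own, not part of the notes):

name_to_id = {"Joe": 1, "Jack": 2, "Jill": 3}
torch.tensor(
    [[name_to_id[x["name"]], x["grade"]] for x in xs],
    dtype=torch.float,
)
# tensor([[ 1., 80.],
#         [ 2., 85.],
#         [ 3., 90.]])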

1.1 List processing with map / broadcast

def MAP(f, sequence):
    return [f(x) for x in sequence]
def extract_grade(x):
    return x["grade"] / 100
MAP(extract_grade, xs)
[0.8, 0.85, 0.9]
tensors[:, 1] / 100.
tensor([0.8000, 0.8500, 0.9000])
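
As a quick sanity check (my own addition, reusing MAP and tensors from above), the broadcast result agrees with the list-based map:

torch.allclose(
    torch.tensor(MAP(extract_grade, xs)),
    tensors[:, 1] / 100.,
)
# True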

2 List processing with reduce

def REDUCE(update_fn, sequence, init_state):
    s = init_state
    for x in sequence:
        s = update_fn(s, x)
    return s
from typing import NamedTuple
class State(NamedTuple):
    total: float
    count: int
    
def update_state(s:State, x)->State:
    return State(s.total + x['grade'], s.count + 1)
REDUCE(update_state, xs, State(0, 0))
State(total=255, count=3)
def AVG(s:State)->float:
    return s.total / s.count
AVG(
    REDUCE(
        update_state, xs, State(0, 0)
    )
)
85.0
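
The hand-written REDUCE mirrors Python's built-in functools.reduce, which takes the same three arguments; a quick comparison (my addition):

from functools import reduce
reduce(update_state, xs, State(0, 0))
# State(total=255, count=3)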

2.0.1 Reduce with tensor operations

s0 = torch.tensor([0.0, 0.0])
update_fn = lambda s, t: torch.tensor([
    s[0] + t[1],
    s[1] + 1,
])

output_fn = lambda s: s[0]/s[1]
REDUCE(update_fn, tensors, s0)
tensor([255.,   3.])
output_fn(
    REDUCE(update_fn, tensors, s0)
)
tensor(85.)

We can express update_fn as a linear layer of the form \[ s' = As + Bt + c \] where \(s'\) is the updated state.

With \(s\in\mathbb{R}^2\) and \(t\in\mathbb{R}^2\), \[ \left[\begin{array}{c} s[0] + t[1] \\ s[1] + 1 \end{array}\right] = \left[\begin{array}{cc} 1 & 0 \\ 0 & 1 \end{array}\right]s + \left[\begin{array}{cc} 0 & 1 \\ 0 & 0 \end{array}\right]t + \left[\begin{array}{c} 0 \\ 1 \end{array}\right] \]

A = torch.tensor([[1., 0.],
                  [0., 1.]])
B = torch.tensor([[0., 1.],
                  [0., 0.]])
c = torch.tensor([0., 1.])

REDUCE(lambda s, t: A@s+B@t+c, tensors, s0)
tensor([255.,   3.])
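
The same affine update can also be written with nn.Linear layers whose weights are set to A, B, and c, foreshadowing how the RNN below parameterizes the update (a sketch of my own; the names lin_A and lin_B are not from the notes):

from torch import nn

lin_A = nn.Linear(2, 2, bias=False)   # plays the role of A
lin_B = nn.Linear(2, 2, bias=True)    # plays the role of B and c
with torch.no_grad():
    lin_A.weight.copy_(A)
    lin_B.weight.copy_(B)
    lin_B.bias.copy_(c)

REDUCE(lambda s, t: lin_A(s) + lin_B(t), tensors, s0)
# same result: total 255 over 3 elements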

2.1 Sequence processing with recurrent networks

Suppose we have a sequence of vectors:

  • Each element is \(x\in\mathbb{R}^n\),
  • There are \(L\) elements in the sequence.

\[\mathrm{seq}\in\mathbb{R}^{L\times n}\]


We want to design a neural network \(F\) that can accept sequences of different lengths \(L\) as inputs.

\[ F(\mathrm{seq})\in\mathbb{R}^k \]


The design of \(F\) involves:

  1. an initial internal state \(s_0\in\mathbb{R}^k\),
  2. and an update function \(h:\mathbb{R}^k\times\mathbb{R}^n\to\mathbb{R}^k\).

\[ F(\mathrm{seq}) = \mathbf{REDUCE}(h, \mathrm{seq}, s_0) \]
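
Unrolling the reduce, \(F\) applies \(h\) once per element of the sequence:

\[ F(\mathrm{seq}) = h(\cdots h(h(s_0, x_1), x_2)\cdots, x_L) \]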

2.2 Building an RNN from scratch

Let’s implement the model as a PyTorch module.

import torch
from torch import nn
from my import add_method  # add_method: helper decorator from the accompanying local `my` module
class MyRNN(nn.Module):
    def __init__(self, dim_input, dim_state, activation_fn):
        super().__init__()
        self.A = nn.Linear(dim_state, dim_state)  # state-to-state map (the A in As + Bt + c)
        self.B = nn.Linear(dim_input, dim_state)  # input-to-state map (the B in As + Bt + c)
        self.activation = activation_fn
@add_method(MyRNN)
def forward(self, input_sequence, init_state):
    s = init_state
    for x in input_sequence:
        # one reduce step: s <- activation(A s + B x), using a temporary batch axis of size 1
        s = self.activation(
            self.A(s[None, :]) + self.B(x[None, :])
        ).squeeze(0)
    return s

This is the non-batch processing version. Namely,

  • input_sequence is just one sequence, of the shape (L, dim_input), where L is the length of the sequence.
  • init_state is a single vector of the shape (dim_state,).

There is no axis for batch.

Let’s try it out with some random inputs.

Suppose we capture cursor movements on the screen and model them as variable-length sequences of 2D vectors. The learning task is to classify them into four known categories.

Given enough training data, we can use an RNN to perform the classification.

dim_input = 2
dim_state = 10
num_categories = 4
# A sequence of 5 vectors
seq_1 = torch.randn((5, dim_input))

seq_1
tensor([[ 0.7020, -0.7205],
        [-0.2819,  0.5182],
        [-0.3640, -0.9410],
        [-1.1234, -0.8819],
        [ 1.2646,  1.3496]])
# An initial state
init_state = torch.zeros((dim_state,))
rnn = MyRNN(
    dim_input=dim_input,
    dim_state=dim_state,
    activation_fn=nn.ReLU())

final_state = rnn(seq_1, init_state)
final_state
tensor([1.0819, 1.2778, 0.1784, 0.0000, 0.0000, 1.2444, 0.0000, 0.0000, 0.2301,
        0.0000], grad_fn=<SqueezeBackward1>)
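
Because MyRNN has no batch axis, a batch of variable-length sequences can still be handled by looping over them and stacking the final states (a sketch of my own, with made-up lengths):

sequences = [torch.randn((L, dim_input)) for L in (3, 5, 7)]
final_states = torch.stack([
    rnn(seq, init_state) for seq in sequences
])
final_states.shape
# torch.Size([3, 10])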

Since we want to do 4-category classification, we add a logistic regression head that computes the logits from the final state.

classifier = nn.Linear(dim_state, num_categories)
classifier(final_state[None, :])
tensor([[-0.1173,  0.8585,  0.4113,  0.3788]], grad_fn=<AddmmBackward0>)
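
To read off a prediction (my addition, reusing classifier and final_state from above), the logits can be turned into probabilities and an argmax category:

logits = classifier(final_state[None, :])
probs = logits.softmax(dim=-1)
probs, probs.argmax(dim=-1)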

Once sufficient training data has been collected, we can train the complete sequence classifier with an appropriate loss function, an optimizer, and the usual training loop.
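
As a hedged sketch of what that training could look like (the random placeholder dataset, labels, and hyper-parameters below are my own, not from the notes), one pass with cross-entropy loss and Adam:

# Hypothetical labelled data: 16 variable-length cursor traces with random labels.
train_data = [
    (torch.randn((int(torch.randint(3, 8, ())), dim_input)),
     int(torch.randint(0, num_categories, ())))
    for _ in range(16)
]

optimizer = torch.optim.Adam(
    list(rnn.parameters()) + list(classifier.parameters()), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for seq, label in train_data:
    final = rnn(seq, init_state)                    # run MyRNN over one sequence
    logits = classifier(final[None, :])             # (1, num_categories)
    loss = loss_fn(logits, torch.tensor([label]))   # cross-entropy against the label
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()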

2.3 Introducing PyTorch's built-in RNN

PyTorch provides a built-in RNN module, nn.RNN, which performs batch processing on sequences.

rnn = nn.RNN(
    input_size=dim_input,
    hidden_size=dim_state,
    num_layers=1,
    nonlinearity='relu',
    batch_first=True,
)
input_batch = seq_1[None, :, :]
input_batch.shape
torch.Size([1, 5, 2])
(all_states, final_state) = rnn(input_batch)
all_states.shape
torch.Size([1, 5, 10])

The first returned value all_states is a tensor containing the state produced by the RNN at every step of the reduce operation. Since the input sequence has length 5, all_states contains 5 state vectors, so its shape is (1, 5, 10).

The general case of the all_states shape is

  • (batch_size, length, dim_out) if we use batch_first=True.
  • (length, batch_size, dim_out) if we use batch_first=False (default).
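
For comparison (a small check of my own), the default batch_first=False layout expects the sequence axis first:

rnn_seq_first = nn.RNN(
    input_size=dim_input,
    hidden_size=dim_state,
    num_layers=1,
    nonlinearity='relu',
)  # batch_first defaults to False
out, h = rnn_seq_first(input_batch.transpose(0, 1))  # (length, batch_size, dim_input)
out.shape, h.shape
# (torch.Size([5, 1, 10]), torch.Size([1, 1, 10]))
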
final_state.shape
torch.Size([1, 1, 10])

The second return value final_state is the final state vector, which corresponds to the very last state vector in the all_states tensor.

The shape of final_state is given by (num_layers, batch_size, dim_state). In our case, it is (1, 1, 10).

(all_states[0, -1] == final_state[0])
tensor([[True, True, True, True, True, True, True, True, True, True]])

This verifies that the last state vector in all_states is indeed the final state vector final_state.
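
Finally, as a closing connection (my own addition), the built-in RNN's final state can feed the same classification head defined earlier; the (num_layers, batch_size, dim_state) layout means we index the last layer first:

logits = classifier(final_state[-1])   # final_state[-1] has shape (1, dim_state)
logits.shape
# torch.Size([1, 4])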