import torch
import matplotlib.pyplot as pl
import numpy as np
Regression
1 The data
x = torch.linspace(0, 1, 10)
y_true = 3 * x + torch.sin(6 * x) + 1
pl.plot(x, y_true, '--o');
2 A model
We start with a model that relates the output \(y\) to the input \(x\).
\[ y = f(x | W) \]
where \(W\) denotes one or more tunable parameters. In the case of linear regression, we have \(W = (a, b)\) such that
\[ f(x|W) = ax + b \]
Given some weights \(W = (a, b)\), we can assess how well the model fits the true values of \(y\) by computing the error:
\[ \mathrm{err} = \frac{\sum_{i}|y_i - f(x_i|W)|^2}{n} \]
Basically, the error is a function of the weights alone:
\[ \mathrm{err} : W \mapsto \mathrm{err}(W) \in \mathbb{R} \]
The error function can help us to improve the model by computing the gradient with respect to \(W\).
\[ W' = W - \epsilon\cdot \nabla\mathrm{err} \]
For sufficiently small \(\epsilon\), we are guaranteed to have \(\mathrm{err}(W') < \mathrm{err}(W)\), provided \(\nabla\mathrm{err} \neq 0\).
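For the linear model, the gradient of the squared-error formula above can be written out by hand; this is a useful sanity check against what autograd computes in the next section:
\[ \frac{\partial\,\mathrm{err}}{\partial a} = -\frac{2}{n}\sum_{i}\bigl(y_i - (a x_i + b)\bigr)\,x_i, \qquad \frac{\partial\,\mathrm{err}}{\partial b} = -\frac{2}{n}\sum_{i}\bigl(y_i - (a x_i + b)\bigr) \]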
3 PyTorch Model
W = torch.tensor([0., 0.], requires_grad=True)
y_pred = W[0]*x + W[1]
# note: this is the 2-norm of the residual divided by n, not the mean squared
# error defined above; since the square root is monotone, both have the same minimizer
err = torch.linalg.vector_norm(y_pred - y_true) / x.shape[0]
err
tensor(0.8162, grad_fn=<DivBackward0>)
err.backward()
grad = W.grad
grad
tensor([-0.1710, -0.3053])
3.1 Update the weights with a step
Let’s update the weights with a small step size:
\[ W_{n+1} = W_{n} - \epsilon\cdot \nabla\mathrm{err} \]
epsilon = 0.01
with torch.no_grad():
    W.sub_(epsilon * grad)
W
tensor([0.0017, 0.0031], requires_grad=True)
Q: Why do we need with torch.no_grad()?
A: By default, any computation that involves \(W\) is tracked so that gradients can propagate through it. But the update \(W \to W-\epsilon\cdot\nabla\mathrm{err}\) is not part of the model computation; it is the optimization step. So we have to disable gradient tracking while applying it.
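To see why, here is a small illustration (not part of the original notebook): attempting the same in-place update without the guard fails, because PyTorch refuses in-place operations on a leaf tensor that requires grad.

W_bad = torch.tensor([0., 0.], requires_grad=True)
try:
    W_bad.sub_(0.01 * torch.ones(2))    # no torch.no_grad() guard
except RuntimeError as exc:
    print(exc)   # "a leaf Variable that requires grad is being used in an in-place operation."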
#
# Previous error is 0.8162
#
y_pred = W[0]*x + W[1]
err = torch.linalg.vector_norm(y_pred - y_true) / x.shape[0]
err
tensor(0.8150, grad_fn=<DivBackward0>)
3.2 One more gradient update
W.grad.zero_()
err.backward()
grad = W.grad
with torch.no_grad():
    W.sub_(epsilon * grad)
W
tensor([0.0034, 0.0061], requires_grad=True)
Q: Why do we need W.grad.zero_()?
A: err.backward() accumulates the gradient into the existing W.grad, so by default it retains the gradient value from the previous update. Since we only care about the gradient of the current error, we need to clear W.grad before calling err.backward().
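A minimal sketch (added here for illustration) of the accumulation behavior:

w = torch.tensor(1., requires_grad=True)
(2 * w).backward()
print(w.grad)    # tensor(2.)
(2 * w).backward()
print(w.grad)    # tensor(4.): the second gradient was added to the first
w.grad.zero_()
(2 * w).backward()
print(w.grad)    # tensor(2.) again after clearing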
#
# Previous error is 0.8150
#
y_pred = W[0]*x + W[1]
err = torch.linalg.vector_norm(y_pred - y_true) / x.shape[0]
err
tensor(0.8138, grad_fn=<DivBackward0>)
4 A training loop
def init(W):
    with torch.no_grad():
        W.zero_()

def err(W):
    y = W[0] * x + W[1]
    return torch.linalg.norm(y - y_true) / x.shape[0]

def update_step():
    W.grad.zero_()
    e = err(W)
    e.backward()
    with torch.no_grad():
        W.sub_(epsilon * W.grad)
    return e.detach().item()

def report(i):
    with torch.no_grad():
        e = err(W)
    print("[{}] a={:.2f} b={:.2f}, err={:.2f}".format(i, W[0], W[1], e))

N = 2000
e = np.zeros(N)

init(W)
for i in range(N):
    e[i] = update_step()
    if i % (N//10) == 0:
        report(i)
report(N)
[0] a=0.00 b=0.00, err=0.81
[200] a=0.34 b=0.61, err=0.57
[400] a=0.67 b=1.19, err=0.35
[600] a=0.96 b=1.66, err=0.20
[800] a=1.12 b=1.87, err=0.16
[1000] a=1.17 b=1.90, err=0.16
[1200] a=1.21 b=1.90, err=0.15
[1400] a=1.23 b=1.88, err=0.15
[1600] a=1.26 b=1.87, err=0.15
[1800] a=1.28 b=1.86, err=0.15
[2000] a=1.30 b=1.85, err=0.15
5 Fitting the data
(a, b) = W.detach().numpy()
y_pred = a * x + b
Q: Why do we need W.detach()?
A: This is another way to avoid triggering gradient computation on W. W.detach() returns a tensor with the same values as W but detached from the computation graph; note that it shares W's storage rather than making a copy. Anything computed from the detached tensor will not participate in gradient computation.
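Both properties are easy to check (a quick illustration, not in the original):

snapshot = W.detach()
print(snapshot.requires_grad)                 # False: detached from the graph
print(snapshot.data_ptr() == W.data_ptr())    # True: shares storage with W, no copy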
pl.subplot(2,1,1)
pl.plot(x, y_true, '--')
pl.plot(x, y_pred, '-');
pl.subplot(2,1,2)
pl.plot(e);
6 PyTorch API
class LineFitting(torch.nn.Module):
    def __init__(self):
        super().__init__()
        W = torch.tensor([0, 0], dtype=torch.float32)
        self.W = torch.nn.Parameter(W)

    def forward(self, x):
        return self.W[0] * x + self.W[1]
#
# The model can be used as a function to evaluate the forward computation
#
line = LineFitting()
line(x)
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<AddBackward0>)
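For reference (not from the original notebook), the same model can be written with the built-in torch.nn.Linear; the only wrinkle is that Linear expects a trailing feature dimension:

linear = torch.nn.Linear(1, 1)      # y = w*x + b, with w and b registered as parameters
with torch.no_grad():
    linear.weight.zero_()           # match LineFitting's zero initialization
    linear.bias.zero_()
linear(x.unsqueeze(1)).squeeze(1)   # same output as line(x) above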
def train(x, y_true, model, LossFn, OptimizerFn, lr, epochs):
    loss = LossFn()
    optimizer = OptimizerFn(model.parameters(), lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        y = model(x)
        loss(y, y_true).backward()
        optimizer.step()
        if epoch % (epochs//10) == 0:
            with torch.no_grad():
                l = loss(y, y_true).numpy()
                W = next(model.parameters()).numpy()
                print(l, W)
Note that torch.nn.MSELoss is the mean of the squared errors, so its values are on a different scale from the err used earlier.
train(x, y_true, line, torch.nn.MSELoss, torch.optim.SGD, 0.1, 10)
6.661948 [0.27913672 0.49835616]
3.8140697 [0.48879483 0.8691274 ]
2.2304645 [0.6466221 1.1447786]
1.3497871 [0.76577795 1.3495169 ]
0.8599378 [0.85607487 1.5013919 ]
0.58739096 [0.9248301 1.6138622]
0.4356683 [0.97749996 1.6969628 ]
0.35112858 [1.0181533 1.7581764]
0.30394772 [1.0498246 1.803082 ]
0.27754354 [1.0747765 1.8358393]
def plot(x, y_true, model):
    with torch.no_grad():
        y = model(x)
    pl.plot(x, y_true, '--o')
    pl.plot(x, y, '-');
plot(x, y_true, line)
7 Fitting using power series
Let’s consider the following model:
\[ y_i = \sum_{k=0}^n w_k\cdot x_i^k \]
class PolyFit(torch.nn.Module):
    def __init__(self, degree):
        super().__init__()
        self.degree = degree
        self.W = torch.nn.Parameter(torch.zeros(degree))

    def forward(self, x):
        y = torch.zeros_like(x)
        for i in range(self.degree):
            y += self.W[i] * (x ** i)
        return y
poly = PolyFit(5)
train(x, y_true, poly, torch.nn.MSELoss, torch.optim.SGD, 0.05, 100_000)
6.661948 [0.24917808 0.13956836 0.10177316 0.08296516 0.07186685]
0.096582726 [ 1.5249301 4.626327 -5.4320383 -1.7490873 4.497506 ]
0.0407205 [ 1.3294919 6.8686876 -9.060711 -2.8561032 7.314907 ]
0.019431062 [ 1.2087605 8.256573 -11.31992 -3.5080438 9.038196 ]
0.011308028 [ 1.1341342 9.117172 -12.734032 -3.8792846 10.086427 ]
0.008199797 [ 1.087948 9.652406 -13.626566 -4.0774508 10.718156 ]
0.007001873 [ 1.0593315 9.986803 -14.197132 -4.1690907 11.093055 ]
0.006531924 [ 1.0415397 10.197327 -14.569023 -4.194967 11.309404 ]
0.0063396967 [ 1.0304453 10.331186 -14.817809 -4.180671 11.428036 ]
0.0062534497 [ 1.0234796 10.417755 -14.990616 -4.1417074 11.486474 ]
plot(x, y_true, poly);