4. Support Vector Machines#
At the end of this exercise you will know:
- How to train an SVM using Sequential Minimal Optimization (SMO)
- How to train an SVM using Gradient Descent (GD)
- How different SVM kernels perform
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
Summary of the mathematical formalism
Soft Margin SVM Lagrangian:
$$\mathcal{L}(\omega, b, \xi, \alpha, \mu)=\frac{1}{2}\|\omega\|^{2}+C \sum_{i=1}^{m} \xi_{i}-\sum_{i=1}^{m} \alpha_{i}\left[y^{(i)}\left(\omega^{\top} x^{(i)}+b\right)-1+\xi_{i}\right]-\sum_{i=1}^{m} \mu_{i} \xi_{i}$$
Primal problem:
$$\min_{\omega, b, \xi} \; \frac{1}{2}\|\omega\|^{2}+C \sum_{i=1}^{m} \xi_{i} \quad \text{s.t.} \quad y^{(i)}\left(\omega^{\top} x^{(i)}+b\right) \geq 1-\xi_{i}, \quad \xi_{i} \geq 0, \quad i=1,\dots,m$$
Dual problem:
$$\max_{\alpha} \; \sum_{i=1}^{m} \alpha_{i}-\frac{1}{2} \sum_{i=1}^{m} \sum_{j=1}^{m} \alpha_{i} \alpha_{j} y^{(i)} y^{(j)} K\left(x^{(i)}, x^{(j)}\right) \quad \text{s.t.} \quad 0 \leq \alpha_{i} \leq C, \quad \sum_{i=1}^{m} \alpha_{i} y^{(i)}=0$$
In the dual problem, we have used the relations \(\omega=\sum_{i=1}^{m} \alpha_{i} y^{(i)} x^{(i)}\) and \(\sum_{i=1}^{m} \alpha_{i} y^{(i)}=0\) to eliminate \(\omega\) and \(b\) in favor of the dual coefficients \(\alpha_i\). To recover \(b\), we use the heuristic \(b^* = \frac{1}{m_{\Sigma}} \sum_{j=1}^{m_{\Sigma}}\left(y^{(j)}-\sum_{i=1}^{m_{\Sigma}} \alpha_{i}^{*} y^{(i)}K(x^{(i)}, x^{(j)})\right)\), averaged over the \(m_{\Sigma}\) support vectors. Note that only in the linear case does the kernel reduce to the plain inner product \(K(x^{(i)}, x^{(j)}) = \left\langle x^{(i)}, x^{(j)}\right\rangle\). Also note that the kernel only appears in the dual problem formulation, which is why the kernel trick can be applied there.
Note: for \(C \to \infty\) this problem ends up being the Hard Margin SVM.
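As a small illustration of the bias heuristic above, here is a minimal NumPy sketch (the function name and the 1e-8 support-vector threshold are our own choices; it assumes a kernel function that returns the full pairwise matrix, like the ones defined further below):
def bias_from_support_vectors(alphas, Y, X, kernel):
    """Average b* over the support vectors (alpha_i > 0)."""
    sv = alphas > 1e-8                       # boolean mask of support vectors
    K = kernel(X[sv], X[sv])                 # pairwise kernel matrix over the support vectors
    # y^(j) - sum_i alpha_i y^(i) K(x^(i), x^(j)), averaged over all support vectors j
    return np.mean(Y[sv] - (alphas[sv] * Y[sv]) @ K)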
4.1. Artificial Dataset#
We use scikit-learn and make_blobs
to generate a binary dataset with input features \(x\in \mathbb{R}^2\) and labels \(y\in \{-1, +1\}\).
# X as features and Y as labels
X, Y = make_blobs(n_samples=500, centers=2, random_state=0, cluster_std=0.6)
# by default the labels are {0, 1}, so we change them to {-1,1}
Y = np.where(Y==0, -1, 1)
# we also center the input data (per dimension) and scale it to unit variance to make training more efficient
X = (X - X.mean(axis=0))/X.std(axis=0)
plt.figure(figsize=(6, 6))
plt.scatter(x=X[:, 0], y=X[:, 1], c=Y)
4.2. Sequential Minimal Optimization (SMO)#
This algorithm was originally developed by John Platt in 1998 and is tailored to the SVM optimization problem. It solves the dual problem in a gradient-free manner: it selects two multipliers \(\alpha_i\) and \(\alpha_j\), optimizes them while keeping all other \(\alpha\)'s fixed, and then iteratively repeats this procedure over all \(\alpha\)'s. The efficiency lies in the heuristic used for selecting the two \(\alpha\) values, which is based on information from previous iterations. In the end we obtain a vector of \(m\) values of \(\alpha\), one per training data point; most of them are \(0\), and only the non-zero values contribute to the predictions made by the model.
We adapt the implementation of the SMO algorithm from this reference code by Jon Charest.
Visualization Utils
def plot_decision_boundary(model, ax, resolution=100, colors=('b', 'k', 'r'), levels=(-1, 0, 1)):
"""Plots the model's decision boundary on the input axes object.
Range of decision boundary grid is determined by the training data.
Returns decision boundary grid and axes object (`grid`, `ax`)."""
# Generate coordinate grid of shape [resolution x resolution]
# and evaluate the model over the entire space
xrange = np.linspace(model.X[:, 0].min(), model.X[:, 0].max(), resolution)
yrange = np.linspace(model.X[:, 1].min(),
model.X[:, 1].max(), resolution)
grid = [[decision_function(model.alphas, model.Y,
model.kernel, model.X,
np.array([xr, yr]), model.b) for xr in xrange] for yr in yrange]
grid = np.array(grid).reshape(len(xrange), len(yrange))
# Plot decision contours using grid and
# make a scatter plot of training data
ax.contour(xrange, yrange, grid, levels=levels, linewidths=(1, 1, 1),
linestyles=('--', '-', '--'), colors=colors)
ax.scatter(model.X[:, 0], model.X[:, 1],
c=model.Y, cmap=plt.cm.viridis, lw=0, alpha=0.25)
# Plot support vectors (non-zero alphas)
# as circled points (linewidth > 0)
mask = np.round(model.alphas, decimals=2) != 0.0
ax.scatter(model.X[mask, 0], model.X[mask, 1],
c=model.Y[mask], cmap=plt.cm.viridis, lw=1, edgecolors='k')
return grid, ax
As a first step, we define a generic SMO model
class SMOModel:
"""Container object for the model used for sequential minimal optimization."""
def __init__(self, X, Y, C, kernel, alphas, b, errors):
self.X = X # training data vector
self.Y = Y # class label vector
self.C = C # regularization parameter
self.kernel = kernel # kernel function
self.alphas = alphas # lagrange multiplier vector
self.b = b # scalar bias term
self.errors = errors # error cache used for selection of alphas
self._obj = [] # record of objective function value
self.m = len(self.X) # store size of training set
The next thing we need to define is the kernel. We start with the simplest linear kernel; the implementation of the radial basis function (Gaussian) kernel is also provided for comparison.
def linear_kernel(x, y, b=1):
"""Returns the linear combination of arrays `x` and `y` with
the optional bias term `b` (set to 1 by default)."""
return x @ y.T + b # Note the @ operator for matrix multiplication
def gaussian_kernel(x, y, gamma=1):
"""Returns the gaussian similarity of arrays `x` and `y` with
kernel inverse width parameter `gamma` (set to 1 by default)."""
######################
# TODO: you might find this helpful: https://jonchar.net/notebooks/SVM/
if np.ndim(x) == 1 and np.ndim(y) == 1:
result = np.exp(- gamma * (np.linalg.norm(x - y, 2)) ** 2)
elif (np.ndim(x) > 1 and np.ndim(y) == 1) or (np.ndim(x) == 1 and np.ndim(y) > 1):
result = np.exp(- gamma * (np.linalg.norm(x - y, 2, axis=1) ** 2))
elif np.ndim(x) > 1 and np.ndim(y) > 1:
result = np.exp(- gamma * (np.linalg.norm(x[:, np.newaxis] -
y[np.newaxis, :], 2, axis=2) ** 2))
return result
#######################
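A quick shape check (with illustrative random inputs of our own choosing): both kernels return the full pairwise matrix when given two sets of points, which is what the decision function below relies on.
# illustrative shape check for the kernels
a, b = np.random.rand(5, 2), np.random.rand(3, 2)
print(linear_kernel(a, b).shape)    # (5, 3)
print(gaussian_kernel(a, b).shape)  # (5, 3)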
Now, using the dual problem formulation and a kernel, we define the objective and decision functions.
The decision function simply implements the raw classifier output \(\omega x - b\) (note the minus sign convention for the bias used in this implementation) by using the kernel trick and the relation \(\omega=\sum_{i=1}^{m} \alpha_{i} y^{(i)} x^{(i)}\).
# Objective function to optimize, i.e. loss function
def objective_function(alphas, target, kernel, X_train):
"""Returns the SVM objective function based in the input model defined by:
`alphas`: vector of Lagrange multipliers
`target`: vector of class labels (-1 or 1) for training data
`kernel`: kernel function
`X_train`: training data for model."""
return np.sum(alphas) - 0.5 * np.sum(
(target[:, None] * target[None, :]) * kernel(X_train, X_train) * (
alphas[:, None] * alphas[None, :]
)
)
# Decision function, i.e. forward model evaluation
def decision_function(alphas, target, kernel, X_train, x_test, b):
"""Applies the SVM decision function to the input feature vectors in `x_test`."""
result = (alphas * target) @ kernel(X_train, x_test) - b
return result
The SMO algorithm
We are now ready to implement the SMO algorithm as given in Platt's paper. The implementation is split into three functions: take_step, examine_example, and train.

- train is the main training loop and also implements the selection of the first of the two \(\alpha\) values.
- examine_example implements the selection of the second \(\alpha\) value.
- take_step optimizes the two chosen \(\alpha\) values, the bias \(b\), and the error cache.
def take_step(i1, i2, model):
# Skip if chosen alphas are the same
if i1 == i2:
return 0, model
alph1 = model.alphas[i1]
alph2 = model.alphas[i2]
y1 = model.Y[i1]
y2 = model.Y[i2]
E1 = model.errors[i1]
E2 = model.errors[i2]
s = y1 * y2
# Compute L & H, the bounds on new possible alpha values
if (y1 != y2):
L = max(0, alph2 - alph1)
H = min(model.C, model.C + alph2 - alph1)
elif (y1 == y2):
L = max(0, alph1 + alph2 - model.C)
H = min(model.C, alph1 + alph2)
if (L == H):
return 0, model
# Compute kernel & 2nd derivative eta
k11 = model.kernel(model.X[i1], model.X[i1])
k12 = model.kernel(model.X[i1], model.X[i2])
k22 = model.kernel(model.X[i2], model.X[i2])
eta = 2 * k12 - k11 - k22
# Compute new alpha 2 (a2) if eta is negative
if (eta < 0):
a2 = alph2 - y2 * (E1 - E2) / eta
# Clip a2 based on bounds L & H
if L < a2 < H:
a2 = a2
elif (a2 <= L):
a2 = L
elif (a2 >= H):
a2 = H
# If eta is non-negative, move new a2 to bound with greater objective function value
else:
alphas_adj = model.alphas.copy()
alphas_adj[i2] = L
# objective function output with a2 = L
Lobj = objective_function(alphas_adj, model.Y, model.kernel, model.X)
alphas_adj[i2] = H
# objective function output with a2 = H
Hobj = objective_function(alphas_adj, model.Y, model.kernel, model.X)
if Lobj > (Hobj + eps):
a2 = L
elif Lobj < (Hobj - eps):
a2 = H
else:
a2 = alph2
# Push a2 to 0 or C if very close
if a2 < 1e-8:
a2 = 0.0
elif a2 > (model.C - 1e-8):
a2 = model.C
# If examples can't be optimized within epsilon (eps), skip this pair
if (np.abs(a2 - alph2) < eps * (a2 + alph2 + eps)):
return 0, model
# Calculate new alpha 1 (a1)
a1 = alph1 + s * (alph2 - a2)
# Update threshold b to reflect newly calculated alphas
# Calculate both possible thresholds
b1 = E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12 + model.b
b2 = E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22 + model.b
    # Set new threshold based on whether a1 or a2 is unbound (strictly between 0 and C)
    if 0 < a1 < model.C:
        b_new = b1
    elif 0 < a2 < model.C:
        b_new = b2
# Average thresholds if both are bound
else:
b_new = (b1 + b2) * 0.5
# Update model object with new alphas & threshold
model.alphas[i1] = a1
model.alphas[i2] = a2
# Update error cache
# Error cache for optimized alphas is set to 0 if they're unbound
for index, alph in zip([i1, i2], [a1, a2]):
if 0.0 < alph < model.C:
model.errors[index] = 0.0
# Set non-optimized errors based on equation 12.11 in Platt's book
non_opt = [n for n in range(model.m) if (n != i1 and n != i2)]
model.errors[non_opt] = model.errors[non_opt] + \
y1*(a1 - alph1)*model.kernel(model.X[i1], model.X[non_opt]) + \
y2*(a2 - alph2) * \
model.kernel(model.X[i2], model.X[non_opt]) + model.b - b_new
# Update model threshold
model.b = b_new
return 1, model
def examine_example(i2, model):
y2 = model.Y[i2]
alph2 = model.alphas[i2]
E2 = model.errors[i2]
r2 = E2 * y2
# Proceed if error is within specified tolerance (tol)
if ((r2 < -tol and alph2 < model.C) or (r2 > tol and alph2 > 0)):
if len(model.alphas[(model.alphas != 0) & (model.alphas != model.C)]) > 1:
            # Use 2nd-choice heuristic: pick the alpha with the maximum difference in error
if model.errors[i2] > 0:
i1 = np.argmin(model.errors)
elif model.errors[i2] <= 0:
i1 = np.argmax(model.errors)
step_result, model = take_step(i1, i2, model)
if step_result:
return 1, model
# Loop through non-zero and non-C alphas, starting at a random point
for i1 in np.roll(np.where((model.alphas != 0) & (model.alphas != model.C))[0],
np.random.choice(np.arange(model.m))):
step_result, model = take_step(i1, i2, model)
if step_result:
return 1, model
# loop through all alphas, starting at a random point
for i1 in np.roll(np.arange(model.m), np.random.choice(np.arange(model.m))):
step_result, model = take_step(i1, i2, model)
if step_result:
return 1, model
return 0, model
def train(model):
numChanged = 0
examineAll = 1 # loop over each alpha in first round
while (numChanged > 0) or (examineAll):
numChanged = 0
if examineAll:
# loop over all training examples
for i in range(model.alphas.shape[0]):
examine_result, model = examine_example(i, model)
numChanged += examine_result
if examine_result:
obj_result = objective_function(
model.alphas, model.Y, model.kernel, model.X)
model._obj.append(obj_result)
else:
# loop over examples where alphas are not already at their limits
for i in np.where((model.alphas != 0) & (model.alphas != model.C))[0]:
examine_result, model = examine_example(i, model)
numChanged += examine_result
if examine_result:
obj_result = objective_function(
model.alphas, model.Y, model.kernel, model.X)
model._obj.append(obj_result)
if examineAll == 1:
examineAll = 0
elif numChanged == 0:
examineAll = 1
return model
We are now ready to define the model (after defining some hyperparameters).
# Set model parameters and initial values
C = 1
m = len(X)
initial_alphas = np.zeros(m)
initial_b = 0.0
# Set tolerances
tol = 0.01 # error tolerance
eps = 0.01 # alpha tolerance
# Instantiate model
model = SMOModel(
X, Y, C,
kernel=linear_kernel, # TODO: try linear_kernel and gaussian_kernel
alphas=initial_alphas,
b=initial_b,
errors= np.zeros(m)
)
# Initialize error cache
initial_error = decision_function(model.alphas, model.Y, model.kernel,
model.X, model.X, model.b) - model.Y
model.errors = initial_error
np.random.seed(0)
output = train(model)
fig, ax = plt.subplots()
grid, ax = plot_decision_boundary(output, ax)
# loss curve
# note: we started with all alphas = 0 and turned some of them on one by one, and then refined.
plt.plot(model._obj)
4.3. Multiclass Classification with SVM and SMO#
We look at a problem we have seen before: the classification of the iris dataset. The task is to use two of the measured input features ("petal_length" and "petal_width") and to build a classifier capable of distinguishing among the three possible flowers, which we index by [0, 1, 2].
Getting the data follows the same process as in the exercise on Linear and Logistic Regression, in the Logistic Regression section.
# get iris dataset
from urllib.request import urlretrieve
iris = 'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
urlretrieve(iris)
df0 = pd.read_csv(iris, sep=',')
# name columns
attributes = ["sepal_length", "sepal_width",
"petal_length", "petal_width", "class"]
df0.columns = attributes
# add species index
species = list(df0["class"].unique())
df0["class_idx"] = df0["class"].apply(species.index)
print("Count occurence of each class:")
print(df0["class"].value_counts())
# let's extract two of the features, and the indexed classes [0,1,2]
df = df0[["petal_length", "petal_width", "class_idx"]]
X_train = df[['petal_length', 'petal_width']].to_numpy()
Y_train = df['class_idx'].to_numpy()
print("Training data:")
print(df)
Count occurrence of each class:
class
Iris-versicolor 50
Iris-virginica 50
Iris-setosa 49
Name: count, dtype: int64
Training data:
petal_length petal_width class_idx
0 1.4 0.2 0
1 1.3 0.2 0
2 1.5 0.2 0
3 1.4 0.2 0
4 1.7 0.4 0
.. ... ... ...
144 5.2 2.3 2
145 5.0 1.9 2
146 5.2 2.0 2
147 5.4 2.3 2
148 5.1 1.8 2
[149 rows x 3 columns]
Exercise
Now, implement an SVM-based multi-class classifier that can be trained using the SMO algorithm. Compare it with the solution presented below.
Hint: this repository and the one-vs-all (aka one-vs-rest) classifier.
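The key ingredient is relabeling: for each class index k, one trains a binary SVM on labels remapped to \(\{-1, +1\}\). A minimal sketch of this remapping (the names k and Y_binary are ours, using the Y_train array defined above):
# one-vs-all relabeling (illustrative): class k becomes +1, every other class -1
k = 0
Y_binary = np.where(Y_train == k, 1.0, -1.0)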
####################
class OneVsAll:
def __init__(self, solver, num_classes, **kwargs):
self._binary_clf = [solver(i, **kwargs) for i in range(num_classes)]
self._num_classes = num_classes
def predict(self, x):
n = x.shape[0]
scores = np.zeros((n, self._num_classes))
for idx in range(self._num_classes):
model = self._binary_clf[idx]
scores[:, idx] = decision_function(
model.alphas,
model.Y,
model.kernel,
model.X,
x,
model.b)
pred = np.argmax(scores, axis=1)
return pred
def fit(self):
np.random.seed(0)
for idx in range(self._num_classes):
self._binary_clf[idx] = train(
self._binary_clf[idx]) # fit(x_train, y_tmp)
def create_binary_clf(class_idx, C, X, Y, kernel):
    # Set model parameters and initial values
    C = 10.  # note: this overrides the `C` argument passed in above
    m = len(X)
    initial_alphas = np.zeros(m)
    initial_b = 0.0
    # Set tolerances (note: take_step and examine_example read the global tol and eps)
    tol = 0.01  # error tolerance
    eps = 0.01  # alpha tolerance
Y_tmp = 1. * (Y == class_idx) - 1. * (Y != class_idx)
# Instantiate model
model = SMOModel(
X, Y_tmp, C,
kernel=kernel,
alphas=initial_alphas,
b=initial_b,
errors=np.zeros(m)
)
# Initialize error cache
initial_error = decision_function(model.alphas, model.Y, model.kernel,
model.X, model.X, model.b) - model.Y
model.errors = initial_error
return model
def plot_decision_boundary_multiclass(solver, ax, Y, resolution=100):
"""Plots the model's decision boundary on the input axes object.
Range of decision boundary grid is determined by the training data.
Returns decision boundary grid and axes object (`grid`, `ax`)."""
# Generate coordinate grid of shape [resolution x resolution]
# and evaluate the model over the entire space
model0 = solver._binary_clf[0]
xrange = np.linspace(model0.X[:, 0].min(),
model0.X[:, 0].max(), resolution)
yrange = np.linspace(model0.X[:, 1].min(),
model0.X[:, 1].max(), resolution)
x, y = np.meshgrid(xrange, yrange)
xy = np.array(list(map(np.ravel, [x, y]))).T # shape=(num_samples, dim)
grid = solver.predict(xy)
grid = np.array(grid).reshape(len(xrange), len(yrange))
# Plot decision contours using grid and
# make a scatter plot of training data
ax.contourf(xrange, yrange, grid, alpha=0.5)
ax.scatter(model0.X[:, 0], model0.X[:, 1], c=Y)
return grid, ax
solver = OneVsAll(
solver=create_binary_clf,
num_classes=3,
C=1.,
X=X_train,
Y=Y_train,
kernel=gaussian_kernel # linear_kernel vs gaussian_kernel
)
solver.fit()
fig, ax = plt.subplots()
grid, ax = plot_decision_boundary_multiclass(solver, ax, Y_train)
plt.show()
####################
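As a quick check of the fitted one-vs-all classifier (a small sketch; we do not report the resulting number here), one could compare its predictions against the training labels:
# illustrative: training accuracy of the one-vs-all SVM
pred = solver.predict(X_train)
print("training accuracy:", np.mean(pred == Y_train))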
4.4. Gradient Descent Optimization of Soft Margin Classifier#
We can also directly solve the primal problem with gradient-based optimization if we slightly reformulate it. This reformulation uses the hinge loss
$$\ell_{\text{hinge}}(z)=\max\left(0,\; 1-z\right).$$
What we feed into the hinge loss is the "raw" output of the classifier, e.g. for linear SVMs \(out = \omega x + b\), multiplied with the correct label \(y\). Thus, the hinge loss of a single sample \(i\) for a linear SVM becomes
$$\ell_i=\max\left(0,\; 1-y^{(i)}\left(\omega^{\top} x^{(i)}+b\right)\right).$$
To fully recover the Soft Margin Classifier, we simply add the squared L2 regularization to this loss, so the total loss becomes
$$L(\omega, b)=\frac{1}{m} \sum_{i=1}^{m} \max\left(0,\; 1-y^{(i)}\left(\omega^{\top} x^{(i)}+b\right)\right)+\lambda\|\omega\|_{2}^{2}.$$
With \(\lambda=\frac{1}{C}\), the above loss is (up to an overall scaling) equivalent to the primal problem of the Soft Margin Classifier (at the top of this tutorial). The regularization parameter trades off between classifying more points correctly (low \(\lambda\) / high \(C\) -> MLE solution) and smoother decision boundaries (high \(\lambda\) / low \(C\)).
If we want to do something like kernels (although there is no dot product here), the best we can do is to directly apply an input feature transformation \(x \to \varphi(x)\), which leads to the loss
$$L(\omega, b)=\frac{1}{m} \sum_{i=1}^{m} \max\left(0,\; 1-y^{(i)}\left(\omega^{\top} \varphi(x^{(i)})+b\right)\right)+\lambda\|\omega\|_{2}^{2}.$$
However, as you might remember, there is no practically useful \(\varphi\) corresponding to the RBF kernel - it exists, but it is infinite-dimensional. Thus, the hinge loss approach is somewhat restrictive. For examples of building kernel approximations, i.e. crafted feature maps \(\varphi\), visit scikit-learn's 6.7. Kernel Approximations.
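As one concrete option (a sketch, assuming scikit-learn is installed; the parameter values here are arbitrary), random Fourier features give a finite-dimensional feature map whose inner products approximate the RBF kernel:
from sklearn.kernel_approximation import RBFSampler
# map the 2D inputs to a 100-dimensional random-Fourier-feature space approximating an RBF kernel
rbf_feature = RBFSampler(gamma=1.0, n_components=100, random_state=0)
X_feat = rbf_feature.fit_transform(X)  # shape: (num_samples, 100)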
Note: there is a similar loss function corresponding to the logistic regression. See this for more details.
Visualization Utils
def visualize_torch(X, Y, model, linear=False):
"""
based on
https://scikit-learn.org/stable/auto_examples/svm/plot_svm_margin.html#sphx-glr-auto-examples-svm-plot-svm-margin-py
"""
plt.figure(figsize=(6, 6))
plt.scatter(x=X[:, 0], y=X[:, 1], c=Y, s=10)
w = model.linear.weight.squeeze().detach().numpy()
b = model.linear.bias.squeeze().detach().numpy()
delta = 0.02
if linear:
# extend bounds by "delta" to improve the plot
x_min = X[:, 0].min() - delta
x_max = X[:, 0].max() + delta
        # solving $w0*x1 + w1*x2 + b = 0$ for $x2$ leads to $x2 = -(w0/w1)*x1 - b/w1$
a = -w[0] / w[1]
xx = np.linspace(x_min, x_max, 50)
yy = a * xx - b / w[1]
# $margin = 1 / ||w||_2$
# Why? Recall that the distance between a point (x_p, y_P) and a line
# $ax+by+c=0$ is given by $|ax_p+by_p+c|/\sqrt{a^2+b^2}$. As we set the
# functional margin to 1, i.e. $|ax_i+by_i+c|=1$ for a support vector
# point, then the total margin becomes $1 / ||w||_2$.
margin = 1 / np.sqrt(np.sum(w**2))
yy_up = yy + np.sqrt(1+a**2) * margin
yy_down = yy - np.sqrt(1+a**2) * margin
plt.plot(xx, yy, "r-")
plt.plot(xx, yy_up, "r--")
plt.plot(xx, yy_down, "r--")
else:
x = np.arange(X[:, 0].min(), X[:, 0].max(), delta)
y = np.arange(X[:, 1].min(), X[:, 1].max(), delta)
x, y = np.meshgrid(x, y)
        # stack into shape (num_samples, dim); build a single ndarray before creating the tensor
        xy = np.array(list(map(np.ravel, [x, y]))).T
        xy = torch.tensor(xy, dtype=torch.float32)
z = model(xy)
z = z.detach().numpy().reshape(x.shape)
cs0 = plt.contourf(x, y, z, alpha=0.6)
        plt.contour(x, y, z, levels=[0], colors='r', linewidths=2)
plt.plot(np.nan, label='decision boundary', color='r')
plt.legend()
plt.grid()
plt.xlim([X[:, 0].min() + delta, X[:, 0].max() - delta])
plt.ylim([X[:, 1].min() + delta, X[:, 1].max() - delta])
plt.tight_layout()
plt.show()
Let's first define a base SVM class, an identity feature map (which gives a plain linear SVM), the hinge loss, and the squared L2 regularization over the weights.
class SupportVectorMachine(nn.Module):
def __init__(self, input_size, phi):
super().__init__()
# X_train.shape should be (num_samples, dim x)
self.input_size = input_size
self.phi = phi
self.linear = nn.Linear(self.input_size, 1)
def forward(self, x):
phi_x = self.phi(x)
out = self.linear(phi_x)
return out
class PhiIdentity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
def hinge_loss(y, out):
"""Hinge loss"""
return torch.mean(torch.clamp(1 - y * out, min=0))
def sq_l2_reg(model):
"""Squared L2 regularization of weights"""
return model.linear.weight.square().sum()
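As a quick sanity check (with hand-picked illustrative values of our own), the hinge loss is zero for points classified correctly beyond the margin and grows linearly otherwise:
# illustrative check of the hinge loss
y = torch.tensor([1.0, -1.0, 1.0])      # true labels
out = torch.tensor([2.0, -0.5, -1.0])   # raw classifier outputs
# per-sample terms: max(0, 1 - y*out) = [0.0, 0.5, 2.0], mean = 0.8333...
print(hinge_loss(y, out))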
Exercise
Implement the Radial Basis Function feature map PhiRBF below. After that, use your new feature map to run a training process and visualize the results.
Hint: this reference.
class PhiRBF(nn.Module):
"""Something like a Radial Basis Function feature map
Lifts the dimension from X.shape[-1] to X.shape[0]"""
def __init__(self, X_train, gamma):
super().__init__()
self.X_train = X_train
self.gamma = gamma
def forward(self, x):
##############################
# TODO: implement the forward methods
# The choice of this specific phi is arbitrary and is not the true RBF.
# We lift the input space from the space of `x` to a space of
# dimension equal to the number of training data points.
out = self.X_train.repeat(x.size(0), 1, 1)
out = torch.exp(-self.gamma * ((x[:, None] - out) ** 2).sum(dim=2))
return out
##############################
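A quick shape check (the random inputs are just illustrative placeholders): the feature map lifts each 2D input to a vector with one entry per training point.
# illustrative shape check of PhiRBF
phi_demo = PhiRBF(X_train=torch.rand(500, 2), gamma=1.0)
print(phi_demo(torch.rand(7, 2)).shape)  # torch.Size([7, 500])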
Some preparation before we train the model
# prepare the data
X = torch.Tensor(X)
Y = torch.Tensor(Y)
N = len(Y)
torch.manual_seed(42)
# set hyperparameters
learning_rate = 0.1
epochs = 1000
batch_size = 100
reg_lambda = 0.001
phi_type = 'lin' # TODO: try both "lin" and "rbf"
if phi_type == 'lin':
phi = PhiIdentity()
input_size = X.shape[1]
elif phi_type == 'rbf':
phi = PhiRBF(X_train=X, gamma=1.0)
input_size = X.shape[0]
# initialize model
model = SupportVectorMachine(input_size, phi)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model.train() # model.eval() for evaluation
# print initial parameters
for name, param in model.named_parameters():
print(name, ": ", param.data)
linear.weight : tensor([[0.5406, 0.5869]])
linear.bias : tensor([-0.1657])
Now, we can train the model.
We iterate over the data by following these two steps in each epoch:

- randomly permute all indices up to the number of training data points;
- iterate over all batches by picking the samples corresponding to the current subset of indices.
for epoch in range(epochs):
random_nums = torch.randperm(N)
# Iterate over the individual batches
for i in range(0, N, batch_size):
x = X[random_nums[i:i + batch_size]]
y = Y[random_nums[i:i + batch_size]]
optimizer.zero_grad()
output = model(x)
loss = hinge_loss(y.unsqueeze(1), output) + \
reg_lambda * sq_l2_reg(model)
loss.backward()
optimizer.step()
print('epoch {}, loss {}'.format(epoch, loss.item()))
epoch 0, loss 0.6529232263565063
epoch 1, loss 0.32932916283607483
epoch 2, loss 0.21102982759475708
epoch 3, loss 0.16052725911140442
epoch 4, loss 0.06884933263063431
...
epoch 997, loss 0.010802632197737694
epoch 998, loss 0.016903884708881378
epoch 999, loss 0.009909210726618767
# print final parameters
for name, param in model.named_parameters():
print(name, ": ", param.data)
linear.weight : tensor([[ 0.4861, -2.7325]])
linear.bias : tensor([-0.0111])
In this formulation, all samples are used to fit the parameters. In contrast to the SMO solution, this might take longer to optimize, but it is more stable because we do not select individual samples while ignoring all the others.
visualize_torch(X, Y, model=model)
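To complement the plot, one could also compute the training accuracy of the fitted classifier (a small sketch; we do not report the resulting number here):
# illustrative: fraction of training points on the correct side of the decision boundary
with torch.no_grad():
    preds = torch.sign(model(X)).squeeze()
print("training accuracy:", (preds == Y).float().mean().item())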