This repository was archived by the owner on Dec 6, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 215
/
Copy pathplot_1d_total_variation.py
48 lines (38 loc) · 1.61 KB
/
plot_1d_total_variation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
=====================================
Signal recovery by 1D total variation
=====================================
In this example, we generate a signal that is piecewise constant. We then
observe some random, noisy binary measurements of that signal and
try to recover it using L1 and 1D total variation (TV1D) penalties.

Given a ground-truth vector ``ground_truth``, the labels that we observe are::

    y = sign(X ground_truth + noise)

where X is a random matrix. We estimate ``ground_truth`` from (X, y) by
solving a penalized classification problem with lightning's
:class:`lightning.classification.FistaClassifier`.

The 1D total variation penalty is also known as the fused lasso.
"""
# Author: Fabian Pedregosa <f@bianp.net>
import numpy as np
import matplotlib.pyplot as plt
from lightning.classification import FistaClassifier
from sklearn.model_selection import GridSearchCV
# Synthetic problem: a piecewise-constant ground truth made of a +1 block,
# a -1 block and a trailing run of zeros.
n_samples = 200
ground_truth = np.concatenate([
    np.full(20, 1.0),
    np.full(20, -1.0),
    np.zeros(40),
])
n_features = len(ground_truth)

# Seed NumPy's global RNG so that X, the noise, and hence y are reproducible.
np.random.seed(0)
X = np.random.rand(n_samples, n_features)

# Binary labels: the sign of a noisy linear response,
# y = sign(X . ground_truth + noise), with Gaussian noise scaled by 0.5.
noise = 0.5 * np.random.randn(n_samples)
y = np.sign(np.dot(X, ground_truth) + noise).astype(int)
# For each penalty, pick the regularization strength ``alpha`` by grid
# search over a log-spaced grid, then plot the coefficients of the
# selected (best) estimator.
alphas = np.logspace(-3, 3, 10)
for penalty in ('l1', 'tv1d'):
    clf = FistaClassifier(penalty=penalty)
    gs = GridSearchCV(clf, param_grid={'alpha': alphas})
    gs.fit(X, y)
    coefs = gs.best_estimator_.coef_
    # coef_ is flattened with ravel() so it can be plotted as a 1D curve.
    plt.plot(coefs.ravel(), lw=3, label='%s penalty' % penalty)
# Overlay the true signal, with a triangle marker every 5 points, so the
# recovered coefficient curves can be compared against it.
plt.plot(ground_truth, label="ground truth", lw=3,
         marker='^', markevery=5, markersize=10)
plt.grid()
plt.legend()
plt.ylim(-1.5, 1.5)
plt.show()