壓縮感知:使用 L1 先驗 (Lasso) 進行斷層掃描重建#

此範例顯示從一組沿不同角度獲取的平行投影重建影像。此類資料集是在電腦斷層掃描 (CT) 中獲取的。

在沒有任何樣本先驗資訊的情況下,重建影像所需的投影數量與影像的線性尺寸 l(以像素為單位)相當。為簡單起見,我們在此考慮稀疏影像,其中只有物件邊界上的像素具有非零值。此類資料可能對應於例如細胞材料。然而,請注意,大多數影像在不同的基底(例如 Haar 小波)中是稀疏的。只獲取 l/7 個投影,因此有必要使用樣本上可用的先驗資訊(其稀疏性):這是壓縮感知的一個範例。

斷層掃描投影操作是一種線性轉換。除了對應於線性迴歸的資料擬合項之外,我們還會懲罰影像的 L1 範數,以考慮其稀疏性。產生的最佳化問題稱為Lasso。我們使用類別 Lasso,它使用座標下降演算法。重要的是,此實作在稀疏矩陣上的計算效率比此處使用的投影運算子更高。

即使雜訊被加入到投影中,使用 L1 懲罰進行重建也能產生零誤差的結果(所有像素都成功標記為 0 或 1)。相比之下,L2 懲罰 (Ridge) 會產生大量像素標記錯誤。在重建影像上觀察到重要的偽影,與 L1 懲罰相反。請特別注意將角落像素分離的圓形偽影,這些像素對投影的貢獻少於中心圓盤。

original image, L2 penalization, L1 penalization
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage, sparse

from sklearn.linear_model import Lasso, Ridge


def _weights(x, dx=1, orig=0):
    x = np.ravel(x)
    floor_x = np.floor((x - orig) / dx).astype(np.int64)
    alpha = (x - orig - floor_x * dx) / dx
    return np.hstack((floor_x, floor_x + 1)), np.hstack((1 - alpha, alpha))


def _generate_center_coordinates(l_x):
    X, Y = np.mgrid[:l_x, :l_x].astype(np.float64)
    center = l_x / 2.0
    X += 0.5 - center
    Y += 0.5 - center
    return X, Y


def build_projection_operator(l_x, n_dir):
    """Compute the tomography design matrix.

    Parameters
    ----------

    l_x : int
        linear size of image array

    n_dir : int
        number of angles at which projections are acquired.

    Returns
    -------
    p : sparse matrix of shape (n_dir l_x, l_x**2)
    """
    X, Y = _generate_center_coordinates(l_x)
    angles = np.linspace(0, np.pi, n_dir, endpoint=False)
    data_inds, weights, camera_inds = [], [], []
    data_unravel_indices = np.arange(l_x**2)
    data_unravel_indices = np.hstack((data_unravel_indices, data_unravel_indices))
    for i, angle in enumerate(angles):
        Xrot = np.cos(angle) * X - np.sin(angle) * Y
        inds, w = _weights(Xrot, dx=1, orig=X.min())
        mask = np.logical_and(inds >= 0, inds < l_x)
        weights += list(w[mask])
        camera_inds += list(inds[mask] + i * l_x)
        data_inds += list(data_unravel_indices[mask])
    proj_operator = sparse.coo_matrix((weights, (camera_inds, data_inds)))
    return proj_operator


def generate_synthetic_data():
    """Synthetic binary data"""
    rs = np.random.RandomState(0)
    n_pts = 36
    x, y = np.ogrid[0:l, 0:l]
    mask_outer = (x - l / 2.0) ** 2 + (y - l / 2.0) ** 2 < (l / 2.0) ** 2
    mask = np.zeros((l, l))
    points = l * rs.rand(2, n_pts)
    mask[(points[0]).astype(int), (points[1]).astype(int)] = 1
    mask = ndimage.gaussian_filter(mask, sigma=l / n_pts)
    res = np.logical_and(mask > mask.mean(), mask_outer)
    return np.logical_xor(res, ndimage.binary_erosion(res))


# Generate synthetic images, and projections
l = 128
proj_operator = build_projection_operator(l, l // 7)
data = generate_synthetic_data()
proj = proj_operator @ data.ravel()[:, np.newaxis]
proj += 0.15 * np.random.randn(*proj.shape)

# Reconstruction with L2 (Ridge) penalization
rgr_ridge = Ridge(alpha=0.2)
rgr_ridge.fit(proj_operator, proj.ravel())
rec_l2 = rgr_ridge.coef_.reshape(l, l)

# Reconstruction with L1 (Lasso) penalization
# the best value of alpha was determined using cross validation
# with LassoCV
rgr_lasso = Lasso(alpha=0.001)
rgr_lasso.fit(proj_operator, proj.ravel())
rec_l1 = rgr_lasso.coef_.reshape(l, l)

plt.figure(figsize=(8, 3.3))
plt.subplot(131)
plt.imshow(data, cmap=plt.cm.gray, interpolation="nearest")
plt.axis("off")
plt.title("original image")
plt.subplot(132)
plt.imshow(rec_l2, cmap=plt.cm.gray, interpolation="nearest")
plt.title("L2 penalization")
plt.axis("off")
plt.subplot(133)
plt.imshow(rec_l1, cmap=plt.cm.gray, interpolation="nearest")
plt.title("L1 penalization")
plt.axis("off")

plt.subplots_adjust(hspace=0.01, wspace=0.01, top=1, bottom=0, left=0, right=1)

plt.show()

腳本總執行時間:(0 分鐘 10.068 秒)

相關範例

在硬幣圖片上進行結構化 Ward 階層式分群的演示

在硬幣圖片上進行結構化 Ward 階層式分群的演示

用於影像分割的譜分群

用於影像分割的譜分群

用於稀疏訊號的基於 L1 的模型

用於稀疏訊號的基於 L1 的模型

階層式分群:結構化與非結構化 Ward

階層式分群:結構化與非結構化 Ward

由 Sphinx-Gallery 產生的圖庫