Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions demo_tv_bounds_reconstruction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Total-variation penalization and bound constraints for tomography reconstruction
================================================================================

In this example, we reconstruct an image from its tomography projections
with an uncomplete set of projections (l/8 angles, where l is the linear
size of the image. For a correct reconstruction without a-priori information,
one would usually require l or more angles). In addition, noise is added to
the projections.

In order to reconstruct the original image, we minimize a function that
is the sum of (i) a L2 data fit term, and (ii) the total variation of the
image, and bound constraints on the pixel values. Proximal iterations
using the FISTA scheme are used.

We compare with and without the bounds

This example should take around 1mn to run and plot the results.
"""

print __doc__

import numpy as np
from reconstruction.forward_backward_tv import fista_tv
from reconstruction.projections import build_projection_operator
from reconstruction.util import generate_synthetic_data
from time import time
import matplotlib.pyplot as plt

# Synthetic data
l = 512
np.random.seed(0)
x = generate_synthetic_data(l)


# Projection operator and projections data, with noise
H = build_projection_operator(l, l / 32)
y = H * x.ravel()[:, np.newaxis]
y += 5 * np.random.randn(*y.shape)

# Display original data
plt.figure(figsize=(12, 5))
plt.subplot(2, 3, 1)
plt.imshow(x, cmap=plt.cm.gnuplot2, interpolation='nearest', vmin=-.1, vmax=1.2)
plt.title('original data (256x256)')
plt.axis('off')

for idx, (val_min, val_max, name) in enumerate([
(None, None, 'TV'),
(0, 1, 'TV + interval'),
]):
# Reconstruction
t1 = time()
res, energies = fista_tv(y, 50, 100, H, val_min=val_min,
val_max=val_max)
t2 = time()

# Fraction of errors of segmented image wrt ground truth
err = np.abs(x - (res[-1] > 0.5)).mean()
print "%s: reconstruction done in %f s, %.3f%% segmentation error" % (
name, t2 - t1, 100 * err)

plt.subplot(2, 3, 2 + idx)
plt.imshow(res[-1], cmap=plt.cm.gnuplot2, interpolation='nearest', vmin=-.1,
vmax=1.2)
plt.title('reconstruction with %s' % name)
plt.axis('off')
ax = plt.subplot(2, 3, 5 + idx)
ax.yaxis.set_scale('log')
plt.hist(res[-1].ravel(), bins=20, normed=True)
plt.yticks(())
plt.title('Histogram of pixel intensity')
plt.axis('tight')

plt.show()
15 changes: 13 additions & 2 deletions reconstruction/forward_backward_tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def tv_norm(im):
grad_x2 = np.diff(im, axis=1)
return np.sqrt(grad_x1[:, :-1]**2 + grad_x2[:-1, :]**2).sum()


def tv_norm_anisotropic(im):
"""Compute the anisotropic TV norm of an image"""
grad_x1 = np.diff(im, axis=0)
Expand All @@ -19,7 +20,8 @@ def tv_norm_anisotropic(im):

# ------------------ Proximal iterators ----------------------------

def fista_tv(y, beta, niter, H, verbose=0, mask=None):
def fista_tv(y, beta, niter, H, verbose=0, mask=None,
val_min=None, val_max=None):
"""
TV regression using FISTA algorithm
(Fast Iterative Shrinkage/Thresholding Algorithm)
Expand All @@ -42,6 +44,12 @@ def fista_tv(y, beta, niter, H, verbose=0, mask=None):

mask : array of bools

val_min: None or float, optional
an optional lower bound constraint on the reconstructed image

val_max: None or float, optional
an optional upper bound constraint on the reconstructed image

Returns
-------

Expand Down Expand Up @@ -102,7 +110,8 @@ def fista_tv(y, beta, niter, H, verbose=0, mask=None):
else:
tmp2d = tmp.reshape((l, l))
u_n = tv_denoise_fista(tmp2d,
weight=beta*gamma, eps=eps)
weight=beta*gamma, eps=eps, val_min=val_min,
val_max=val_max)
t_new = (1 + np.sqrt(1 + 4 * t_old**2))/2.
t_old = t_new
x = u_n + (t_old - 1)/t_new * (u_n - u_old)
Expand Down Expand Up @@ -218,6 +227,7 @@ def ista_tv(y, beta, niter, H=None):
energies.append(energy)
return res, energies


def gfb_tv(y, beta, niter, H=None, val_min=0, val_max=1, x0=None,
stop_tol=1.e-4):
"""
Expand Down Expand Up @@ -332,6 +342,7 @@ def gfb_tv(y, beta, niter, H=None, val_min=0, val_max=1, x0=None,
break
return res, energies


def gfb_tv_local(y, beta, niter, mask_pix, mask_reg, H=None,
val_min=0, val_max=1, x0=None):
"""
Expand Down
144 changes: 120 additions & 24 deletions reconstruction/tv_denoising.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np


def div(grad):
""" Compute divergence of image gradient """
res = np.zeros(grad.shape[1:])
Expand All @@ -11,8 +12,9 @@ def div(grad):
this_res[-1] -= this_grad[-2]
return res


def gradient(img):
"""
"""
Compute gradient of an image

Parameters
Expand All @@ -26,12 +28,12 @@ def gradient(img):
Gradient of the image: the i-th component along the first
axis is the gradient along the i-th axis of the original
array img
"""
"""
shape = [img.ndim, ] + list(img.shape)
gradient = np.zeros(shape, dtype=img.dtype)
# 'Clever' code to have a view of the gradient with dimension i stop
# at -1
slice_all = [0, slice(None, -1),]
slice_all = [0, slice(None, -1), ]
for d in range(img.ndim):
gradient[slice_all] = np.diff(img, axis=d)
slice_all[0] = d + 1
Expand All @@ -44,7 +46,7 @@ def _projector_on_dual(grad):
modifies in place the gradient to project it
on the L2 unit ball
"""
norm = np.maximum(np.sqrt(np.sum(grad**2, 0)), 1.)
norm = np.maximum(np.sqrt(np.sum(grad ** 2, 0)), 1.)
for grad_comp in grad:
grad_comp /= norm
return grad
Expand All @@ -56,22 +58,23 @@ def dual_gap(im, new, gap, weight):
see "Total variation regularization for fMRI-based prediction of behavior",
by Michel et al. (2011) for a derivation of the dual gap
"""
im_norm = (im**2).sum()
im_norm = (im ** 2).sum()
gx, gy = np.zeros_like(new), np.zeros_like(new)
gx[:-1] = np.diff(new, axis=0)
gy[:, :-1] = np.diff(new, axis=1)
if im.ndim == 3:
gz = np.zeros_like(new)
gz[..., :-1] = np.diff(new, axis=2)
tv_new = 2 * weight * np.sqrt(gx**2 + gy**2 + gz**2).sum()
tv_new = 2 * weight * np.sqrt(gx ** 2 + gy ** 2 + gz ** 2).sum()
else:
tv_new = 2 * weight * np.sqrt(gx**2 + gy**2).sum()
dual_gap = (gap**2).sum() + tv_new - im_norm + (new**2).sum()
tv_new = 2 * weight * np.sqrt(gx ** 2 + gy ** 2).sum()
dual_gap = (gap ** 2).sum() + tv_new - im_norm + (new ** 2).sum()
return 0.5 / im_norm * dual_gap

def tv_denoise_fista(im, weight=50, eps=5.e-5, n_iter_max=200,
check_gap_frequency=3):

def tv_denoise_fista(im, weight=50, eps=5.e-5, n_iter_max=200,
check_gap_frequency=3, val_min=None, val_max=None,
verbose=False):
"""
Perform total-variation denoising on 2-d and 3-d images

Expand Down Expand Up @@ -99,6 +102,15 @@ def tv_denoise_fista(im, weight=50, eps=5.e-5, n_iter_max=200,
n_iter_max: int, optional
maximal number of iterations used for the optimization.

val_min: None or float, optional
an optional lower bound constraint on the reconstructed image

val_max: None or float, optional
an optional upper bound constraint on the reconstructed image

verbose: bool, optional
if True, plot the dual gap of the optimization

Returns
-------
out: ndarray
Expand All @@ -120,50 +132,134 @@ def tv_denoise_fista(im, weight=50, eps=5.e-5, n_iter_max=200,
total variation denoising in "Fast gradient-based algorithms for
constrained total variation image denoising and deblurring problems"
(2009).

For details on implementing the bound constraints, read the Beck and
Teboulle paper.
"""
if not im.dtype.kind == 'f':
im = im.astype(np.float)
shape = [im.ndim, ] + list(im.shape)
input_img = im
if not input_img.dtype.kind == 'f':
input_img = input_img.astype(np.float)
shape = [input_img.ndim, ] + list(input_img.shape)
grad_im = np.zeros(shape)
grad_aux = np.zeros(shape)
t = 1.
i = 0
if input_img.ndim == 2:
# Upper bound on the Lipschitz constant
lipschitz_constant = 9
elif input_img.ndim == 3:
lipschitz_constant = 12
else:
raise ValueError('Cannot compute TV for images that are not '
'2D or 3D')
# negated_output is the negated primal variable in the optimization
# loop
negated_output = -input_img
# Clipping values for the inner loop
negated_val_min = np.nan
negated_val_max = np.nan
if val_min is not None:
negated_val_min = -val_min
if val_max is not None:
negated_val_max = -val_max
if (val_min is not None or val_max is not None):
# With bound constraints, the stopping criterion is on the
# evolution of the output
negated_output_old = negated_output.copy()
while i < n_iter_max:
error = weight * div(grad_aux) - im
grad_tmp = gradient(error)
grad_tmp *= 1./ (8 * weight)
grad_tmp = gradient(negated_output)
grad_tmp *= 1. / (lipschitz_constant * weight)
grad_aux += grad_tmp
grad_tmp = _projector_on_dual(grad_aux)
t_new = 1. / 2 * (1 + np.sqrt(1 + 4 * t**2))
t_new = 1. / 2 * (1 + np.sqrt(1 + 4 * t ** 2))
t_factor = (t - 1) / t_new
grad_aux = (1 + t_factor) * grad_tmp - t_factor * grad_im
grad_im = grad_tmp
t = t_new
gap = weight * div(grad_im)
# Compute the primal variable
negated_output = gap - input_img
if (val_min is not None or val_max is not None):
negated_output = negated_output.clip(negated_val_max,
negated_val_min,
out=negated_output)
if (i % check_gap_frequency) == 0:
gap = weight * div(grad_im)
new = im - gap
dgap = dual_gap(im, new, gap, weight)
if dgap < eps:
break
if val_min is None and val_max is None:
# In the case of bound constraints, we don't have
# the dual gap
dgap = dual_gap(input_img, -negated_output, gap, weight)
if verbose:
print 'Iteration % 2i, dual gap: % 6.3e' % (i, dgap)
if dgap < eps:
break
else:
diff = np.max(np.abs(negated_output_old - negated_output))
diff /= np.max(np.abs(negated_output))
if verbose:
print 'Iteration % 2i, relative difference: % 6.3e' % (i,
diff)
if diff < eps:
break
negated_output_old = negated_output
i += 1
return new
# Compute the primal variable
output = input_img - gap
if (val_min is not None or val_max is not None):
output = output.clip(-negated_val_min, -negated_val_max, out=output)
return output


def test_grad_div_adjoint(size=12, random_state=42):
# We need to check that <D x, y> = <x, DT y> for x and y random vectors
random_state = np.random.RandomState(random_state)

x = np.random.normal(size=(size, size, size))
y = np.random.normal(size=(3, size, size, size))

np.testing.assert_almost_equal(np.sum(gradient(x) * y),
-np.sum(x * div(y)))


if __name__ == '__main__':
# First our test
test_grad_div_adjoint()
from scipy.misc import lena
import matplotlib.pyplot as plt
from time import time

# Smoke test on lena
l = lena().astype(np.float)
# normalize image between 0 and 1
l /= l.max()
l += 0.1 * l.std() * np.random.randn(*l.shape)
t0 = time()
res = tv_denoise_fista(l, weight=0.05, eps=5.e-5)
res = tv_denoise_fista(l, weight=2.5, eps=5.e-5, verbose=True)
t1 = time()
print t1 - t0
plt.figure()
plt.subplot(121)
plt.imshow(l, cmap='gray')
plt.subplot(122)
plt.imshow(res, cmap='gray')

# Smoke test on a 3D random image with hidden structure
np.random.seed(42)
img = np.random.normal(size=(12, 24, 24))
img[4:8, 8:16, 8:16] += 1.5
res = tv_denoise_fista(img, weight=.6, eps=5.e-5, verbose=True)
plt.figure(figsize=(9, 3))
plt.subplot(131)
plt.imshow(img[6], cmap='gist_earth')
plt.title('Original data')
plt.subplot(132)
plt.imshow(res[6], cmap='gist_earth', vmin=-.1, vmax=.3)
plt.title('TV')

# add constraints
res_cons = tv_denoise_fista(img, weight=.6, eps=5.e-5, verbose=True,
val_min=0, val_max=1.5)
plt.subplot(133)
plt.imshow(res_cons[6], cmap='gist_earth', vmin=-.1, vmax=.3)
plt.title('TV + interval')

plt.show()
Loading