From b09e498c593b8b4a89b8e12b4d5a35370d317c0e Mon Sep 17 00:00:00 2001 From: noakraicer Date: Thu, 29 Aug 2024 15:45:25 +0300 Subject: [PATCH 1/6] noa --- configs/benchmark_hyperspectral.yaml | 114 +++++++++++++++++++++++++++ lensless/__init__.py | 1 + lensless/recon/gd.py | 92 ++++++++++++++++++--- lensless/recon/recon.py | 14 ++-- lensless/recon/rfft_convolve.py | 6 +- lensless/utils/dataset.py | 13 ++- lensless/utils/image.py | 2 +- lensless/utils/io.py | 16 +++- scripts/eval/benchmark_recon.py | 13 ++- 9 files changed, 241 insertions(+), 30 deletions(-) create mode 100644 configs/benchmark_hyperspectral.yaml diff --git a/configs/benchmark_hyperspectral.yaml b/configs/benchmark_hyperspectral.yaml new file mode 100644 index 00000000..5f775875 --- /dev/null +++ b/configs/benchmark_hyperspectral.yaml @@ -0,0 +1,114 @@ +# python scripts/eval/benchmark_recon.py +#Hydra config +hydra: + run: + dir: "benchmark/${now:%Y-%m-%d}/${now:%H-%M-%S}" + job: + chdir: True + + +dataset: PolarLitis # DiffuserCam, DigiCamCelebA, HFDataset +seed: 0 +batchsize: 1 # must be 1 for iterative approaches + +huggingface: + repo: "noakraicer/polarlitis" + cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`. 
+ psf: psf.mat + mask: mask.npy # null for simulating PSF + image_res: [250, 250] # used during measurement + rotate: False # if measurement is upside-down + flipud: False + flip_lensed: False # if rotate or flipud is True, apply to lensed + + alignment: + top_left: null + height: null + + downsample: 1 + downsample_lensed: 2 + split_seed: null + single_channel_psf: True + +device: "cuda" +# numbers of iterations to benchmark +n_iter_range: [5, 10, 20, 50, 100, 200, 300] +# number of files to benchmark +n_files: null # null for all files +#How much should the image be downsampled +downsample: 2 +#algorithm to benchmark +algorithms: ["HyperSpectralFISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"] + +# baseline from Monakhova et al. 2019, https://arxiv.org/abs/1908.11502 +baseline: "MONAKHOVA 100iter" + +save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10] +gamma_psf: 1.5 # gamma factor for PSF + + +# Hyperparameters +nesterov: + p: 0 + mu: 0.9 +fista: + tk: 1 +admm: + mu1: 1e-6 + mu2: 1e-5 + mu3: 4e-5 + tau: 0.0001 + + +# for DigiCamCelebA +files: + test_size: 0.15 + downsample: 1 + celeba_root: /scratch/bezzam + + + # dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K + # psf: data/psf/adafruit_random_2mm_20231907.png + # vertical_shift: null + # horizontal_shift: null + # crop: null + + dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K + psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png + vertical_shift: -117 + horizontal_shift: -25 + crop: + vertical: [0, 525] + horizontal: [265, 695] + +# for prepping ground truth data +#for simulated dataset +simulation: + grayscale: False + output_dim: null # should be set if no PSF is used + # random variations + object_height: 0.33 # [m], range for random height or scalar + flip: True # change the orientation of the object (from vertical to horizontal) + random_shift: False + random_vflip: 0.5 + random_hflip: 0.5 + 
random_rotate: False + # these distance parameters are typically fixed for a given PSF + # for DiffuserCam psf # for tape_rgb psf + # scene2mask: 10e-2 # scene2mask: 40e-2 + # mask2sensor: 9e-3 # mask2sensor: 4e-3 + # -- for CelebA + scene2mask: 0.25 # [m] + mask2sensor: 0.002 # [m] + deadspace: True # whether to account for deadspace for programmable mask + # see waveprop.devices + use_waveprop: False # for PSF simulation + sensor: "rpi_hq" + snr_db: 10 + # simulate different sensor resolution + # output_dim: [24, 32] # [H, W] or null + # Downsampling for PSF + downsample: 8 + # max val in simulated measured (quantized 8 bits) + quantize: False # must be False for differentiability + max_val: 255 diff --git a/lensless/__init__.py b/lensless/__init__.py index 70990774..4d67f179 100644 --- a/lensless/__init__.py +++ b/lensless/__init__.py @@ -20,6 +20,7 @@ NesterovGradientDescent, FISTA, GradientDescentUpdate, + HyperSpectralFISTA ) from .recon.tikhonov import CodedApertureReconstruction from .hardware.sensor import VirtualSensor, SensorOptions diff --git a/lensless/recon/gd.py b/lensless/recon/gd.py index dc61e809..0b6946d8 100644 --- a/lensless/recon/gd.py +++ b/lensless/recon/gd.py @@ -64,7 +64,7 @@ class GradientDescent(ReconstructionAlgorithm): Object for applying projected gradient descent. 
""" - def __init__(self, psf, dtype=None, proj=non_neg, **kwargs): + def __init__(self, psf,mask, dtype=None, proj=non_neg, **kwargs): """ Parameters @@ -83,25 +83,25 @@ def __init__(self, psf, dtype=None, proj=non_neg, **kwargs): assert callable(proj) self._proj = proj - super(GradientDescent, self).__init__(psf, dtype, **kwargs) + super(GradientDescent, self).__init__(psf,mask, dtype, **kwargs) if self._denoiser is not None: print("Using denoiser in gradient descent.") # redefine projection function self._proj = self._denoiser - + self.mask=mask def reset(self): if self.is_torch: if self._initial_est is not None: self._image_est = self._initial_est else: # initial guess, half intensity image - psf_flat = self._psf.reshape(-1, self._psf_shape[3]) - pixel_start = ( - torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values - ) / 2 + # psf_flat = self._psf.reshape(-1, self._psf_shape[3]) + # pixel_start = ( + # torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values + # ) / 2 # initialize image estimate as [Batch, Depth, Height, Width, Channels] - self._image_est = torch.ones_like(self._psf[None, ...]) * pixel_start + self._image_est = torch.zeros((1,250,250,3)) # set step size as < 2 / lipschitz Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) @@ -123,8 +123,8 @@ def reset(self): self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0)) def _grad(self): - diff = self._convolver.convolve(self._image_est) - self._data - return self._convolver.deconvolve(diff) + diff = np.sum(self.mask * self._convolver.convolve(self._image_est), -1) - self._data # (H, W, 1) + return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels def _update(self, iter): self._image_est -= self._alpha * self._grad() @@ -238,6 +238,78 @@ def _update(self, iter): self._xk = xk +def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): + + # load data + psf, data = 
load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs) + + # create reconstruction object + recon = GradientDescent(psf, n_iter=n_iter, proj=proj) + + # set data + recon.set_data(data) + + # perform reconstruction + start_time = time.time() + res = recon.apply(plot=False) + proc_time = time.time() - start_time + + if verbose: + print(f"Reconstruction time : {proc_time} s") + print(f"Reconstruction shape: {res.shape}") + return res +class HyperSpectralFISTA(GradientDescent): + """ + Object for applying projected gradient descent with FISTA (Fast Iterative + Shrinkage-Thresholding Algorithm) for acceleration. + + Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA + + """ + + def __init__(self, psf,mask, dtype=None, proj=non_neg, tk=1.0, **kwargs): + """ + + Parameters + ---------- + psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Point spread function (PSF) that models forward propagation. + Must be of shape (depth, height, width, channels) even if + depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf` + to load a PSF from a file such that it is in the correct format. + dtype : float32 or float64 + Data type to use for optimization. Default is float32. + proj : :py:class:`function` + Projection function to apply at each iteration. Default is + non-negative. + tk : float + Initial step size parameter for FISTA. It is updated at each iteration + according to Eq. 4.2 of paper. By default, initialized to 1.0. 
+ + """ + self._initial_tk = tk + + super(HyperSpectralFISTA, self).__init__(psf,mask, dtype, proj, **kwargs) + + self._tk = tk + self._xk = self._image_est + + def reset(self, tk=None): + super(HyperSpectralFISTA, self).reset() + if tk: + self._tk = tk + else: + self._tk = self._initial_tk + self._xk = self._image_est + def _update(self, iter): + self._image_est -= self._alpha * self._grad() + xk = self._form_image() + tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2 + self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk) + self._tk = tk + self._xk = xk + + def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): # load data diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index ff1fc55c..7de40520 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -203,6 +203,7 @@ class ReconstructionAlgorithm(abc.ABC): def __init__( self, psf, + mask, dtype=None, pad=True, n_iter=100, @@ -369,12 +370,13 @@ def set_data(self, data): assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]." # assert same shapes - assert np.all( - self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1] - ), "PSF and data shape mismatch" - - if len(data.shape) == 3: - self._data = data[None, None, ...] + # assert np.all( + # self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1] + # ), "PSF and data shape mismatch" + if len(data.shape)==3: + self._data = data.unsqueeze(-1) + # if len(data.shape) == 3: + # self._data = data[None, None, ...] elif len(data.shape) == 4: self._data = data[None, ...] 
else: diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py index e7b9be74..a4848e45 100644 --- a/lensless/recon/rfft_convolve.py +++ b/lensless/recon/rfft_convolve.py @@ -24,7 +24,7 @@ class RealFFTConvolve2D: - def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs): + def __init__(self, psf, dtype=None, pad=True, norm=None, rgb=None, **kwargs): """ Linear operator that performs convolution in Fourier domain, and assumes real-valued signals. @@ -135,10 +135,10 @@ def convolve(self, x): Convolve with pre-computed FFT of provided PSF. """ if self.pad: - self._padded_data = self._pad(x) + self._padded_data = self._pad(x).to(self._psf.device) else: if self.is_torch: - self._padded_data = x # .type(self.dtype).to(self._psf.device) + self._padded_data = x else: self._padded_data[:] = x # .astype(self.dtype) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 5ce95f7a..d58eca05 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -17,7 +17,7 @@ from torchvision.transforms import functional as F from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD from lensless.utils.simulation import FarFieldSimulator -from lensless.utils.io import load_image, load_psf, save_image +from lensless.utils.io import load_image, load_psf, save_image,load_mask from lensless.utils.image import is_grayscale, resize, rgb2gray import re from lensless.hardware.utils import capture @@ -1271,6 +1271,7 @@ def __init__( split, n_files=None, psf=None, + mask=None, rotate=False, # just the lensless image flipud=False, flip_lensed=False, @@ -1409,11 +1410,11 @@ def __init__( if psf is not None: # download PSF from huggingface psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset") - psf, _ = load_psf( + psf = load_psf( psf_fp, shape=lensless.shape, return_float=True, - return_bg=True, + return_bg=False, flip=self.rotate, flip_ud=flipud, bg_pix=(0, 15), @@ 
-1424,6 +1425,10 @@ def __init__( if single_channel_psf: # replicate across three channels self.psf = self.psf.repeat(1, 1, 1, 3) + if mask is not None: + mask_fp = hf_hub_download(repo_id=huggingface_repo, filename=mask, repo_type="dataset") + mask = load_mask(mask_fp) + self.mask= torch.from_numpy(mask) elif "mask_label" in data_0: self.multimask = True @@ -1563,7 +1568,9 @@ def _get_images_pair(self, idx): # convert to float if lensless_np.dtype == np.uint8: lensless_np = lensless_np.astype(np.float32) / 255 + lensless_np = lensless_np / np.max(lensless_np) lensed_np = lensed_np.astype(np.float32) / 255 + lensed_np = lensed_np / np.max(lensed_np) else: # 16 bit lensless_np = lensless_np.astype(np.float32) / 65535 diff --git a/lensless/utils/image.py b/lensless/utils/image.py index eed00121..e3ebe0c2 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -264,7 +264,7 @@ def get_max_val(img, nbits=None): max_val : int Maximum pixel value. """ - assert img.dtype not in FLOAT_DTYPES + # assert img.dtype not in FLOAT_DTYPES if nbits is None: nbits = int(np.ceil(np.log2(img.max()))) diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 5596befd..1d5e4d1a 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -8,7 +8,7 @@ import os.path import warnings - +import scipy import cv2 import numpy as np from PIL import Image @@ -17,6 +17,10 @@ from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray, get_max_val from lensless.utils.plot import plot_image +def load_mask(fp): + mask = np.load(fp) + return np.expand_dims(mask, axis=0) + def load_image( fp, @@ -121,6 +125,9 @@ def load_image( black_level = np.array(raw.black_level_per_channel[:3]).astype(np.float32) elif "npy" in fp or "npz" in fp: img = np.load(fp) + elif "mat" in fp: + mat = scipy.io.loadmat(fp) + img = mat['psf'][:,:,0] else: img = cv2.imread(fp, cv2.IMREAD_UNCHANGED) @@ -202,7 +209,10 @@ def load_image( else: if dtype is None: dtype = original_dtype 
- img = img.astype(dtype) + img = img.astype(np.float64) + + img = img[10:260, 35:320-35] + img = img / np.linalg.norm(img) return img @@ -380,7 +390,7 @@ def load_psf( if return_bg: return psf, bg else: - return psf + return psf.astype(np.float64) def load_data( diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 76fbc367..b7571698 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -25,7 +25,7 @@ import pathlib as plib from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt -from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent +from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent,HyperSpectralFISTA from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset from lensless.utils.io import save_image from lensless.utils.image import gamma_correction @@ -35,7 +35,7 @@ from torch.utils.data import Subset -@hydra.main(version_base=None, config_path="../../configs", config_name="benchmark") +@hydra.main(version_base=None, config_path="../../configs", config_name="benchmark_hyperspectral") def benchmark_recon(config): # set seed @@ -86,7 +86,7 @@ def benchmark_recon(config): _, benchmark_dataset = torch.utils.data.random_split( dataset, [train_size, test_size], generator=generator ) - elif dataset == "HFDataset": + elif dataset == "PolarLitis": split_test = "test" if config.huggingface.split_seed is not None: @@ -120,6 +120,7 @@ def benchmark_recon(config): huggingface_repo=config.huggingface.repo, cache_dir=config.huggingface.cache_dir, psf=config.huggingface.psf, + mask = config.huggingface.mask, n_files=n_files, split=split_test, display_res=config.huggingface.image_res, @@ -138,6 +139,8 @@ def benchmark_recon(config): psf = benchmark_dataset.psf[first_psf_key].to(device) else: psf = benchmark_dataset.psf.to(device) + mask = benchmark_dataset.mask.to(device) + else: raise ValueError(f"Dataset {dataset} not 
supported") @@ -190,6 +193,8 @@ def benchmark_recon(config): ) if algo == "FISTA": model_list.append(("FISTA", FISTA(psf, tk=config.fista.tk))) + if algo == "HyperSpectralFISTA": + model_list.append(("HyperSpectralFISTA", HyperSpectralFISTA(psf,mask, tk=config.fista.tk))) if algo == "GradientDescent": model_list.append(("GradientDescent", GradientDescent(psf))) if algo == "NesterovGradientDescent": @@ -243,7 +248,7 @@ def benchmark_recon(config): :2 ] # take first two in case multimask dataset ground_truth_np = ground_truth.cpu().numpy()[0] - lensless_np = lensless.cpu().numpy()[0] + lensless_np = lensless.cpu().numpy() if crop is not None: ground_truth_np = ground_truth_np[ From 5999d7cf3218b58ec24d04379c87c00feff24b82 Mon Sep 17 00:00:00 2001 From: noakraicer Date: Thu, 29 Aug 2024 16:19:14 +0300 Subject: [PATCH 2/6] final changes --- configs/benchmark_hyperspectral.yaml | 2 +- lensless/recon/gd.py | 6 +++--- lensless/recon/recon.py | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/configs/benchmark_hyperspectral.yaml b/configs/benchmark_hyperspectral.yaml index 5f775875..53671cf4 100644 --- a/configs/benchmark_hyperspectral.yaml +++ b/configs/benchmark_hyperspectral.yaml @@ -32,7 +32,7 @@ huggingface: device: "cuda" # numbers of iterations to benchmark -n_iter_range: [5, 10, 20, 50, 100, 200, 300] +n_iter_range: [2000] # number of files to benchmark n_files: null # null for all files #How much should the image be downsampled diff --git a/lensless/recon/gd.py b/lensless/recon/gd.py index 0b6946d8..c5af193c 100644 --- a/lensless/recon/gd.py +++ b/lensless/recon/gd.py @@ -101,12 +101,12 @@ def reset(self): # torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values # ) / 2 # initialize image estimate as [Batch, Depth, Height, Width, Channels] - self._image_est = torch.zeros((1,250,250,3)) + self._image_est = torch.zeros((1,250,250,3)).to(self._psf.device) # set step size as < 2 / lipschitz Hadj_flat = 
self._convolver._Hadj.reshape(-1, self._psf_shape[3]) H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) - self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values) + self._alpha = 1/4770.13 else: if self._initial_est is not None: @@ -123,7 +123,7 @@ def reset(self): self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0)) def _grad(self): - diff = np.sum(self.mask * self._convolver.convolve(self._image_est), -1) - self._data # (H, W, 1) + diff = torch.sum(self.mask * self._convolver.convolve(self._image_est), axis=-1, keepdims=True) - self._data # (H, W, 1) return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels def _update(self, iter): diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index 7de40520..b17b165f 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -571,6 +571,9 @@ def apply( for i in range(n_iter): self._update(i) + if i%50==0: + img = self._form_image() + if self.compensation_branch is not None and i < self._n_iter - 1: self.compensation_branch_inputs.append(self._form_image()) From 68cc8d555ffd0f33761db63c5a8f55f5066db915 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Fri, 6 Sep 2024 16:19:35 +0000 Subject: [PATCH 3/6] Ensure backward compatability. --- lensless/recon/gd.py | 133 +++++++++++++------------ lensless/recon/recon.py | 17 ++-- lensless/recon/rfft_convolve.py | 16 ++- scripts/recon/hyperspectral.py | 166 ++++++++++++++++++++++++++++++++ 4 files changed, 249 insertions(+), 83 deletions(-) create mode 100644 scripts/recon/hyperspectral.py diff --git a/lensless/recon/gd.py b/lensless/recon/gd.py index c5af193c..5c386d08 100644 --- a/lensless/recon/gd.py +++ b/lensless/recon/gd.py @@ -64,7 +64,7 @@ class GradientDescent(ReconstructionAlgorithm): Object for applying projected gradient descent. 
""" - def __init__(self, psf,mask, dtype=None, proj=non_neg, **kwargs): + def __init__(self, psf, dtype=None, proj=non_neg, **kwargs): """ Parameters @@ -83,30 +83,30 @@ def __init__(self, psf,mask, dtype=None, proj=non_neg, **kwargs): assert callable(proj) self._proj = proj - super(GradientDescent, self).__init__(psf,mask, dtype, **kwargs) + super(GradientDescent, self).__init__(psf, dtype, **kwargs) if self._denoiser is not None: print("Using denoiser in gradient descent.") # redefine projection function self._proj = self._denoiser - self.mask=mask + def reset(self): if self.is_torch: if self._initial_est is not None: self._image_est = self._initial_est else: # initial guess, half intensity image - # psf_flat = self._psf.reshape(-1, self._psf_shape[3]) - # pixel_start = ( - # torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values - # ) / 2 + psf_flat = self._psf.reshape(-1, self._psf_shape[3]) + pixel_start = ( + torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values + ) / 2 # initialize image estimate as [Batch, Depth, Height, Width, Channels] - self._image_est = torch.zeros((1,250,250,3)).to(self._psf.device) + self._image_est = torch.ones_like(self._psf[None, ...]) * pixel_start # set step size as < 2 / lipschitz Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) - self._alpha = 1/4770.13 + self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values) else: if self._initial_est is not None: @@ -123,8 +123,8 @@ def reset(self): self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0)) def _grad(self): - diff = torch.sum(self.mask * self._convolver.convolve(self._image_est), axis=-1, keepdims=True) - self._data # (H, W, 1) - return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels + diff = self._convolver.convolve(self._image_est) - self._data + return 
self._convolver.deconvolve(diff) def _update(self, iter): self._image_est -= self._alpha * self._grad() @@ -238,76 +238,75 @@ def _update(self, iter): self._xk = xk -def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): - - # load data - psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs) - - # create reconstruction object - recon = GradientDescent(psf, n_iter=n_iter, proj=proj) - - # set data - recon.set_data(data) - - # perform reconstruction - start_time = time.time() - res = recon.apply(plot=False) - proc_time = time.time() - start_time - - if verbose: - print(f"Reconstruction time : {proc_time} s") - print(f"Reconstruction shape: {res.shape}") - return res -class HyperSpectralFISTA(GradientDescent): +class HyperSpectralFISTA(FISTA): """ - Object for applying projected gradient descent with FISTA (Fast Iterative - Shrinkage-Thresholding Algorithm) for acceleration. - - Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA + Applying HyperSpectral FISTA as in: https://github.com/Waller-Lab/SpectralDiffuserCam """ - def __init__(self, psf,mask, dtype=None, proj=non_neg, tk=1.0, **kwargs): + def __init__(self, psf, mask, **kwargs): """ Parameters ---------- - psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` - Point spread function (PSF) that models forward propagation. - Must be of shape (depth, height, width, channels) even if - depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf` - to load a PSF from a file such that it is in the correct format. - dtype : float32 or float64 - Data type to use for optimization. Default is float32. - proj : :py:class:`function` - Projection function to apply at each iteration. Default is - non-negative. - tk : float - Initial step size parameter for FISTA. It is updated at each iteration - according to Eq. 4.2 of paper. By default, initialized to 1.0. 
+ mask : + Hyperspectral mask """ - self._initial_tk = tk + # same PSF for all hyperspectral channels + assert psf.shape[-1] == 1 + assert mask.shape[-3:-1] == psf.shape[-3:-1] + self._mask = mask[None, ...] # adding batch dimension - super(HyperSpectralFISTA, self).__init__(psf,mask, dtype, proj, **kwargs) + super(HyperSpectralFISTA, self).__init__(psf, **kwargs) - self._tk = tk - self._xk = self._image_est + def reset(self): + + # TODO set lipschitz constant correctly/differently? + + if self.is_torch: + if self._initial_est is not None: + self._image_est = self._initial_est + else: + # initial guess, half intensity image + psf_flat = self._psf.reshape(-1, self._psf_shape[3]) + pixel_start = ( + torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values + ) / 2 + # initialize image estimate as [Batch, Depth, Height, Width, Channels] + self._image_est = torch.ones_like(self._mask) * pixel_start + + # set step size as < 2 / lipschitz + Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) + H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) + self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values) - def reset(self, tk=None): - super(HyperSpectralFISTA, self).reset() - if tk: - self._tk = tk else: - self._tk = self._initial_tk - self._xk = self._image_est - def _update(self, iter): - self._image_est -= self._alpha * self._grad() - xk = self._form_image() - tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2 - self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk) - self._tk = tk - self._xk = xk + if self._initial_est is not None: + self._image_est = self._initial_est + else: + psf_flat = self._psf.reshape(-1, self._psf_shape[3]) + pixel_start = (np.max(psf_flat, axis=0) + np.min(psf_flat, axis=0)) / 2 + # initialize image estimate as [Batch, Depth, Height, Width, Channels] + self._image_est = np.ones_like(self._mask) * pixel_start + + # set step size as < 2 / lipschitz + Hadj_flat = 
self._convolver._Hadj.reshape(-1, self._psf_shape[3]) + H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) + self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0)) + + # TODO how was his value determined? + self._alpha = 1 / 4770.13 + + def _grad(self): + # make sure to sum on correct axis, and apply mask on correct dimensions + diff = ( + np.sum(self._mask * self._convolver.convolve(self._image_est), -1, keepdims=True) + - self._data + ) # (B, D, H, W, 1) + return self._convolver.deconvolve( + diff * self._mask + ) # (H, W, C) where C is number of hyperspectral channels def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index b17b165f..5f5ce924 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -203,7 +203,6 @@ class ReconstructionAlgorithm(abc.ABC): def __init__( self, psf, - mask, dtype=None, pad=True, n_iter=100, @@ -370,13 +369,11 @@ def set_data(self, data): assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]." # assert same shapes - # assert np.all( - # self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1] - # ), "PSF and data shape mismatch" - if len(data.shape)==3: - self._data = data.unsqueeze(-1) - # if len(data.shape) == 3: - # self._data = data[None, None, ...] + assert np.all( + self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1] + ), "PSF and data shape mismatch" + if len(data.shape) == 3: + self._data = data[None, None, ...] elif len(data.shape) == 4: self._data = data[None, ...] 
else: @@ -571,9 +568,7 @@ def apply( for i in range(n_iter): self._update(i) - if i%50==0: - img = self._form_image() - + if self.compensation_branch is not None and i < self._n_iter - 1: self.compensation_branch_inputs.append(self._form_image()) diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py index a4848e45..ea057971 100644 --- a/lensless/recon/rfft_convolve.py +++ b/lensless/recon/rfft_convolve.py @@ -82,18 +82,24 @@ def _crop(self, x): ] def _pad(self, v): + + shape = self._padded_shape.copy() + if v.shape[-1] != self._padded_shape[-1]: + # different number of channels in PSF and data + assert v.shape[-1] == 1 or self._padded_shape[-1] == 1 + shape[-1] = v.shape[-1] + if len(v.shape) == 5: batch_size = v.shape[0] - shape = [batch_size] + self._padded_shape - elif len(v.shape) == 4: - shape = self._padded_shape - else: + shape = [batch_size] + shape + elif len(v.shape) != 4: raise ValueError("Expected 4D or 5D tensor") if self.is_torch: vpad = torch.zeros(size=shape, dtype=v.dtype, device=v.device) else: vpad = np.zeros(shape).astype(v.dtype) + vpad[ ..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], : ] = v @@ -135,7 +141,7 @@ def convolve(self, x): Convolve with pre-computed FFT of provided PSF. """ if self.pad: - self._padded_data = self._pad(x).to(self._psf.device) + self._padded_data = self._pad(x) else: if self.is_torch: self._padded_data = x diff --git a/scripts/recon/hyperspectral.py b/scripts/recon/hyperspectral.py new file mode 100644 index 00000000..a236c856 --- /dev/null +++ b/scripts/recon/hyperspectral.py @@ -0,0 +1,166 @@ +""" +Apply gradient descent. 
+ +``` +python scripts/recon/gradient_descent.py +``` + +""" + +import hydra +from hydra.utils import to_absolute_path +import os +import numpy as np +import time +import pathlib as plib +import matplotlib.pyplot as plt +from lensless.utils.io import load_data +from lensless import ( + GradientDescentUpdate, + GradientDescent, + NesterovGradientDescent, + FISTA, +) + + +@hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") +def gradient_descent( + config, +): + + # load mask and PSF + mask_fp = "/root/FORKS/LenslessPiCamNoa/data/mask.npy" + mask = np.load(mask_fp) + mask = np.expand_dims(mask, axis=0) + mask = mask.astype(np.float32) + + # load PSF + import scipy + + psf_fp = "/root/FORKS/LenslessPiCamNoa/data/psf.mat" + mat = scipy.io.loadmat(psf_fp) + psf = mat["psf"][:, :, 0] + psf = psf.astype(np.float32) + psf = psf[10:260, 35 : 320 - 35] + psf = psf / np.linalg.norm(psf) + psf = np.expand_dims(psf, axis=0) # add depth + psf = np.expand_dims(psf, axis=-1) # add channels + + # load data + from lensless.utils.io import load_image + + data_fp = "/root/FORKS/LenslessPiCamNoa/data/266_lensless.png" + data = load_image( + data_fp, + return_float=True, + normalize=False, + dtype=np.float32, + ) + # -- add depth and channels dimensions + data = np.expand_dims(data, axis=0) + data = np.expand_dims(data, axis=-1) + + # apply gradient descent + from lensless import HyperSpectralFISTA + + save = config["save"] + if save: + save = os.getcwd() + + recon = HyperSpectralFISTA( + psf, + mask, + norm=None, + # norm="ortho", + ) + recon.set_data(data) + res = recon.apply( + n_iter=500, + disp_iter=50, + save=save, + gamma=1.0, + plot=False, + ) + + if config.torch: + img = res[0].cpu().numpy() + else: + img = res[0] + + if config["display"]["plot"]: + plt.show() + if save: + np.save(plib.Path(save) / "final_reconstruction.npy", img) + print(f"Files saved to : {save}") + + raise ValueError + + psf, data = load_data( + 
psf_fp=to_absolute_path(config.input.psf), + data_fp=to_absolute_path(config.input.data), + dtype=config.input.dtype, + downsample=config["preprocess"]["downsample"], + bayer=config["preprocess"]["bayer"], + blue_gain=config["preprocess"]["blue_gain"], + red_gain=config["preprocess"]["red_gain"], + plot=config["display"]["plot"], + flip=config["preprocess"]["flip"], + gamma=config["display"]["gamma"], + gray=config["preprocess"]["gray"], + single_psf=config["preprocess"]["single_psf"], + shape=config["preprocess"]["shape"], + use_torch=config.torch, + torch_device=config.torch_device, + ) + + disp = config["display"]["disp"] + if disp < 0: + disp = None + + save = config["save"] + if save: + save = os.getcwd() + + start_time = time.time() + + if config["gradient_descent"]["method"] == GradientDescentUpdate.VANILLA: + recon = GradientDescent(psf) + elif config["gradient_descent"]["method"] == GradientDescentUpdate.NESTEROV: + recon = NesterovGradientDescent( + psf, + p=config["gradient_descent"]["nesterov"]["p"], + mu=config["gradient_descent"]["nesterov"]["mu"], + ) + else: + recon = FISTA( + psf, + tk=config["gradient_descent"]["fista"]["tk"], + ) + + recon.set_data(data) + print(f"Setup time : {time.time() - start_time} s") + + start_time = time.time() + res = recon.apply( + n_iter=config["gradient_descent"]["n_iter"], + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + ) + print(f"Processing time : {time.time() - start_time} s") + + if config.torch: + img = res[0].cpu().numpy() + else: + img = res[0] + + if config["display"]["plot"]: + plt.show() + if save: + np.save(plib.Path(save) / "final_reconstruction.npy", img) + print(f"Files saved to : {save}") + + +if __name__ == "__main__": + gradient_descent() From 189aacabf809877678458f164b78d4db01c98070 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Fri, 6 Sep 2024 16:22:25 +0000 Subject: [PATCH 4/6] Small edits. 
---
 lensless/recon/gd.py | 4 ++--
 scripts/recon/hyperspectral.py | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/lensless/recon/gd.py b/lensless/recon/gd.py
index 5c386d08..b7b69e98 100644
--- a/lensless/recon/gd.py
+++ b/lensless/recon/gd.py
@@ -295,8 +295,8 @@ def reset(self):
             H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
             self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0))
 
-        # TODO how was his value determined?
-        self._alpha = 1 / 4770.13
+        # # TODO how was this value determined?
+        # self._alpha = 1 / 4770.13
 
     def _grad(self):
         # make sure to sum on correct axis, and apply mask on correct dimensions
diff --git a/scripts/recon/hyperspectral.py b/scripts/recon/hyperspectral.py
index a236c856..a3766c89 100644
--- a/scripts/recon/hyperspectral.py
+++ b/scripts/recon/hyperspectral.py
@@ -1,8 +1,8 @@
 """
-Apply gradient descent.
+Apply FISTA for hyperspectral data recovery.
 
 ```
-python scripts/recon/gradient_descent.py
+python scripts/recon/hyperspectral.py
 ```
 
 """
@@ -70,8 +70,8 @@ def gradient_descent(
     recon = HyperSpectralFISTA(
         psf,
         mask,
-        norm=None,
-        # norm="ortho",
+        # norm=None,
+        norm="ortho",
     )
     recon.set_data(data)
     res = recon.apply(

From 7937bcec6422f238b8c349d68d90687e1f5805be Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Fri, 6 Sep 2024 16:25:24 +0000
Subject: [PATCH 5/6] Clean up example script.
--- scripts/recon/hyperspectral.py | 80 +++------------------------------- 1 file changed, 6 insertions(+), 74 deletions(-) diff --git a/scripts/recon/hyperspectral.py b/scripts/recon/hyperspectral.py index a3766c89..d5d2f6d3 100644 --- a/scripts/recon/hyperspectral.py +++ b/scripts/recon/hyperspectral.py @@ -16,10 +16,7 @@ import matplotlib.pyplot as plt from lensless.utils.io import load_data from lensless import ( - GradientDescentUpdate, - GradientDescent, - NesterovGradientDescent, - FISTA, + HyperSpectralFISTA, ) @@ -60,13 +57,12 @@ def gradient_descent( data = np.expand_dims(data, axis=0) data = np.expand_dims(data, axis=-1) - # apply gradient descent - from lensless import HyperSpectralFISTA - + # apply FISTA save = config["save"] if save: save = os.getcwd() + start_time = time.time() recon = HyperSpectralFISTA( psf, mask, @@ -74,6 +70,9 @@ def gradient_descent( norm="ortho", ) recon.set_data(data) + print(f"Setup time : {time.time() - start_time} s") + + start_time = time.time() res = recon.apply( n_iter=500, disp_iter=50, @@ -81,73 +80,6 @@ def gradient_descent( gamma=1.0, plot=False, ) - - if config.torch: - img = res[0].cpu().numpy() - else: - img = res[0] - - if config["display"]["plot"]: - plt.show() - if save: - np.save(plib.Path(save) / "final_reconstruction.npy", img) - print(f"Files saved to : {save}") - - raise ValueError - - psf, data = load_data( - psf_fp=to_absolute_path(config.input.psf), - data_fp=to_absolute_path(config.input.data), - dtype=config.input.dtype, - downsample=config["preprocess"]["downsample"], - bayer=config["preprocess"]["bayer"], - blue_gain=config["preprocess"]["blue_gain"], - red_gain=config["preprocess"]["red_gain"], - plot=config["display"]["plot"], - flip=config["preprocess"]["flip"], - gamma=config["display"]["gamma"], - gray=config["preprocess"]["gray"], - single_psf=config["preprocess"]["single_psf"], - shape=config["preprocess"]["shape"], - use_torch=config.torch, - torch_device=config.torch_device, - ) - - disp = 
config["display"]["disp"]
-    if disp < 0:
-        disp = None
-
-    save = config["save"]
-    if save:
-        save = os.getcwd()
-
-    start_time = time.time()
-
-    if config["gradient_descent"]["method"] == GradientDescentUpdate.VANILLA:
-        recon = GradientDescent(psf)
-    elif config["gradient_descent"]["method"] == GradientDescentUpdate.NESTEROV:
-        recon = NesterovGradientDescent(
-            psf,
-            p=config["gradient_descent"]["nesterov"]["p"],
-            mu=config["gradient_descent"]["nesterov"]["mu"],
-        )
-    else:
-        recon = FISTA(
-            psf,
-            tk=config["gradient_descent"]["fista"]["tk"],
-        )
-
-    recon.set_data(data)
-    print(f"Setup time : {time.time() - start_time} s")
-
-    start_time = time.time()
-    res = recon.apply(
-        n_iter=config["gradient_descent"]["n_iter"],
-        disp_iter=disp,
-        save=save,
-        gamma=config["display"]["gamma"],
-        plot=config["display"]["plot"],
-    )
     print(f"Processing time : {time.time() - start_time} s")
 
     if config.torch:

From 3d53857aa9e6a706c92bfb620b751424540e8767 Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Fri, 6 Sep 2024 16:36:51 +0000
Subject: [PATCH 6/6] Clean up script and backward compatibility.

---
 lensless/utils/image.py | 2 +-
 lensless/utils/io.py | 9 +++------
 scripts/recon/hyperspectral.py | 23 ++++++++++++++---------
 3 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/lensless/utils/image.py b/lensless/utils/image.py
index e3ebe0c2..eed00121 100644
--- a/lensless/utils/image.py
+++ b/lensless/utils/image.py
@@ -264,7 +264,7 @@ def get_max_val(img, nbits=None):
     max_val : int
         Maximum pixel value.
""" - # assert img.dtype not in FLOAT_DTYPES + assert img.dtype not in FLOAT_DTYPES if nbits is None: nbits = int(np.ceil(np.log2(img.max()))) diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 1d5e4d1a..1d505193 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -17,6 +17,7 @@ from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray, get_max_val from lensless.utils.plot import plot_image + def load_mask(fp): mask = np.load(fp) return np.expand_dims(mask, axis=0) @@ -127,7 +128,7 @@ def load_image( img = np.load(fp) elif "mat" in fp: mat = scipy.io.loadmat(fp) - img = mat['psf'][:,:,0] + img = mat["psf"][:, :, 0] else: img = cv2.imread(fp, cv2.IMREAD_UNCHANGED) @@ -209,10 +210,6 @@ def load_image( else: if dtype is None: dtype = original_dtype - img = img.astype(np.float64) - - img = img[10:260, 35:320-35] - img = img / np.linalg.norm(img) return img @@ -390,7 +387,7 @@ def load_psf( if return_bg: return psf, bg else: - return psf.astype(np.float64) + return psf def load_data( diff --git a/scripts/recon/hyperspectral.py b/scripts/recon/hyperspectral.py index d5d2f6d3..be2adafc 100644 --- a/scripts/recon/hyperspectral.py +++ b/scripts/recon/hyperspectral.py @@ -14,10 +14,11 @@ import time import pathlib as plib import matplotlib.pyplot as plt -from lensless.utils.io import load_data +from lensless.utils.io import load_image from lensless import ( HyperSpectralFISTA, ) +import scipy @hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") @@ -25,16 +26,23 @@ def gradient_descent( config, ): - # load mask and PSF + # set paths mask_fp = "/root/FORKS/LenslessPiCamNoa/data/mask.npy" + psf_fp = "/root/FORKS/LenslessPiCamNoa/data/psf.mat" + data_fp = "/root/FORKS/LenslessPiCamNoa/data/266_lensless.png" + + ### - put your paths + # mask_fp = None + # psf_fp = None + # data_fp = None + + # load mask and PSF mask = np.load(mask_fp) mask = np.expand_dims(mask, axis=0) mask = 
mask.astype(np.float32) # load PSF - import scipy - psf_fp = "/root/FORKS/LenslessPiCamNoa/data/psf.mat" mat = scipy.io.loadmat(psf_fp) psf = mat["psf"][:, :, 0] psf = psf.astype(np.float32) @@ -44,9 +52,6 @@ def gradient_descent( psf = np.expand_dims(psf, axis=-1) # add channels # load data - from lensless.utils.io import load_image - - data_fp = "/root/FORKS/LenslessPiCamNoa/data/266_lensless.png" data = load_image( data_fp, return_float=True, @@ -74,8 +79,8 @@ def gradient_descent( start_time = time.time() res = recon.apply( - n_iter=500, - disp_iter=50, + n_iter=100, + disp_iter=20, save=save, gamma=1.0, plot=False,