#!/home/esrf/favre/dev/pynx-py38-power9-env/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2022-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

# TODO: This is a tentative multiprocessing-based implementation of the holotomo script.
#  Still needs:
#  - a LOT of cleanup and OO organisation [https://xkcd.com/2054/]
#  - use a parallel process to monitor CPU, I/O and GPU usage !
#  - handle parallel process failures
#  - etc...

"""
Script for single or multi-distance holo-tomography reconstruction,
tuned for the ESRF ID16A and B beamlines.
"""
import timeit
import time
import sys
import os
import platform
import argparse
import logging
from io import StringIO, TextIOBase
import gc
from multiprocessing import Pool, Process, Queue, JoinableQueue, Array, set_start_method
import ctypes
import psutil
import numpy as np
import h5py
from scipy.ndimage import median_filter
import fabio
from tifffile import imwrite
from ...utils.array import rebin, pad2
from ...utils.math import primes, test_smaller_primes
from pynx.version import get_git_version

_pynx_version = get_git_version()

# Number of worker processes used by the multiprocessing pools in this module.
try:
    # Get the real number of processor cores available: the affinity mask counts
    # logical CPUs, so scale it by the physical/logical ratio.
    # os.sched_getaffinity is only available on some *nix platforms
    nproc = len(os.sched_getaffinity(0)) * psutil.cpu_count(logical=False) // psutil.cpu_count(logical=True)
except (AttributeError, TypeError):
    # AttributeError: os.sched_getaffinity does not exist on this platform.
    # TypeError: psutil.cpu_count() returned None (count could not be determined).
    nproc = os.cpu_count()
# Pool() needs at least one process: the integer division above can give 0
# (e.g. affinity restricted to one logical CPU with hyper-threading), and
# os.cpu_count() may return None.
nproc = max(1, nproc or 1)


class DeltaTimeFormatter(logging.Formatter):
    """Log formatter providing an extra ``delta`` field: the time elapsed
    (in seconds) since this class was defined, formatted with an explicit sign.
    """
    # Reference time, evaluated once when the class body is executed
    start_time = time.time()

    def format(self, record):
        """Set ``record.delta`` to the elapsed time, then format the record."""
        elapsed = time.time() - self.start_time
        record.delta = "{:+8.2f}s".format(elapsed)
        return super(DeltaTimeFormatter, self).format(record)


def new_logger_stream(logger_name: str, level: int = logging.INFO, flush_to=None, stdout=True):
    """Create / access a logger writing to a fresh StringIO, so the log can be
    shared between multiple processes using stream.getvalue().

    All handlers previously attached to the logger are removed. Before that,
    if ``flush_to`` is a text file object and the logger already has handlers,
    the content of the last handler's stream is written to ``flush_to``.

    :param logger_name: the logger name
    :param level: the logging level
    :param flush_to: an opened file object to flush the previous stream to, if any
    :param stdout: if True, also output log to stdout
    :return: a tuple (logger, stream)
    """
    logger = logging.getLogger(logger_name)
    previous = logger.handlers
    if previous and isinstance(flush_to, TextIOBase):
        # The StringIO handler is always the last one added by this function
        last = previous[-1]
        last.flush()
        flush_to.write(last.stream.getvalue())
        flush_to.flush()
    logger.handlers.clear()

    formatter = DeltaTimeFormatter('%(asctime)s - %(delta)s - %(levelname)8s - %(name)16s - %(message)s')
    stream = StringIO()
    # stdout handler first (if requested), in-memory stream handler last
    targets = [sys.stdout, stream] if stdout else [stream]
    for target in targets:
        handler = logging.StreamHandler(target)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logger.setLevel(level)
    return logger, stream


def load_align_zoom(ichunk, data_src, prefix, proj_idx, planes, binning, magnification,
                    padding, pad_method, rhapp, reference_plane, spikes_threshold,
                    align_method, align_interval, align_fourier_low_cutoff, prefix_output, nb_chunk,
                    motion, logger, iobs_array, params=None):
    """
    Load a chunk of projections (plus dark and empty-beam reference images)
    into the shared ``iobs_array``, then - for multi-distance data - zoom/pad
    the images and determine the shifts between planes.

    The projections are written in-place in ``iobs_array``. When images are
    aligned (nz > 1 and align_method != 'rhapp'), the shifts are also saved
    (appended) to the hdf5 file ``prefix_output + '_shifts.h5'``.

    :param ichunk: index of this chunk (only used for logging, as ichunk + 1)
    :param data_src: path to the data; its last two characters are replaced
        by "<plane>_" to address the directory of each plane
    :param prefix: file name prefix; projections are read from
        data_src[:-2] + "%d_/" + prefix + "%d_%04d.edf"
    :param proj_idx: indices of the projections to load for this chunk
    :param planes: list of plane (distance) numbers
    :param binning: rebin factor applied to the images. None or 1 = no binning
    :param magnification: array of the magnification for each plane
    :param padding: tuple (pady, padx) of the padding on each side
    :param pad_method: padding method, passed on to the zoom/pad functions
    :param rhapp: name of a holoCT (rhapp) file to import shifts from, or None
    :param reference_plane: plane number (a value listed in planes) relative
        to which the shifts are expressed
    :param spikes_threshold: threshold passed to remove_spikes() for the
        reference images and to the loading function. Ignored if false.
    :param align_method: alignment method. 'rhapp' uses the imported shifts,
        otherwise images are registered, with the shifts smoothed by a
        polynomial fit if the name includes 'fit', or by a median filter
        if it includes 'median'.
    :param align_interval: if > 1, only one projection every align_interval
        is registered, and the shifts are interpolated for the others.
    :param align_fourier_low_cutoff: passed as 'low_cutoff' to the alignment
    :param prefix_output: prefix for the name of the hdf5 shifts file
    :param nb_chunk: total number of chunks (only used for logging)
    :param motion: if True, load the motion correction from the 'correct.txt'
        file of each plane and take it into account for the shifts
    :param logger: logger object
    :param iobs_array: shared multiprocessing Array of floats used to store
        the projections, viewed as (nproj, nz, ny + 2 * pady, nx + 2 * padx)
    :param params: optional object with the script parameters, which are
        saved (using vars(params)) in the hdf5 shifts file
    :return: a tuple (ref_zoom, dx, dy) with the reference (empty beam)
        frames and the shifts. dx and dy are None for a single distance.
    """
    from multiprocessing import TimeoutError as PoolTimeoutError
    from scipy.interpolate import interp1d
    from tomoscan.esrf.edfscan import EDFTomoScan
    from pynx.holotomo.utils import load_data_pool_init, load_data_kw, zoom_pad_images, zoom_pad_images_init, \
        zoom_pad_images_kw, align_images_kw, align_images_pool_init, remove_spikes
    nz = len(planes)
    # binning=None is accepted (no rebinning): use a factor of 1 for all the arithmetic below
    bin_factor = 1 if binning is None else binning
    t0 = timeit.default_timer()
    logger.info("#" * 48)
    logger.info("# Loading data & align / zoom images for chunk %d/%d:" % (ichunk + 1, nb_chunk))
    # Guard against a chunk with a single projection
    proj_idx_second = proj_idx[1] if len(proj_idx) > 1 else proj_idx[0]
    logger.info("# [%d,%d...%d] - %d projections" % (proj_idx[0], proj_idx_second,
                                                     proj_idx[-1], len(proj_idx)))
    logger.info("#" * 48)
    dark = None
    for iz in range(nz):
        dark_url = EDFTomoScan.get_darks_url("%s%d_" % (data_src[:-2], planes[iz]))
        dark_name = dark_url[0].file_path()
        logger.info("Loading dark: %s" % dark_name)
        d = fabio.open(dark_name).data
        if bin_factor > 1:
            d = rebin(d, bin_factor)
        ny, nx = d.shape
        if dark is None:
            # Allocate once the frame size is known
            dark = np.empty((nz, ny, nx), dtype=np.float32)
        dark[iz] = d

    ny, nx = dark.shape[-2:]
    logger.info("Frame size: %d x %d" % (ny, nx))

    nb_proj = len(proj_idx)  # Number of projections to load
    pady, padx = padding
    nb_proj_sharray = len(iobs_array) // (nz * (ny + 2 * pady) * (nx + 2 * padx))  # Actual len of shared array

    ref = np.empty_like(dark)
    for iz in range(nz):
        for k, v in EDFTomoScan.get_flats_url("%s%d_" % (data_src[:-2], planes[iz])).items():
            logger.info("Loading empty reference image [iz=%d, idx=%d]: %s" % (planes[iz], k, v.file_path()))
            d = fabio.open(v.file_path()).data
            if bin_factor > 1:
                d = rebin(d, bin_factor)
            ref[iz] = d - dark[iz]
            # TODO: load multiple empty beam images (begin/end)
            break
    if spikes_threshold:
        # Also remove spikes on reference/empty beam images
        for r in ref:
            remove_spikes(r, threshold=spikes_threshold)
    # Prepare shared arrays for loading & zoom/pad
    dark_old = dark
    dark_array = Array(ctypes.c_float, dark.size)
    dark = np.frombuffer(dark_array.get_obj(), dtype=ctypes.c_float).reshape(dark_old.shape)
    iobs = np.frombuffer(iobs_array.get_obj(), dtype=ctypes.c_float).reshape((nb_proj_sharray, nz,
                                                                              ny + 2 * pady, nx + 2 * padx))
    iobs = iobs[:nb_proj]

    # EDFTomoScan scans the directory, but we use a pattern
    img_name = data_src[:-2] + "%d_/" + prefix + "%d_%04d.edf"

    logger.info("Loading projections images: %s [%d processes]" % (img_name, nproc))
    vkw = [{'i': i, 'idx': proj_idx[i], 'nproj': nb_proj_sharray, 'nx': nx, 'ny': ny, 'planes': planes,
            'img_name': img_name, 'padding': padding, 'binning': binning, 'spikes_threshold': spikes_threshold}
           for i in range(len(proj_idx))]
    with Pool(nproc, initializer=load_data_pool_init, initargs=(iobs_array, dark_array)) as pool:
        pool.map(load_data_kw, vkw)  # , chunksize=1

    dt = timeit.default_timer() - t0
    nbytes = nz * nx * ny * (nb_proj + 2) * 4 * bin_factor ** 2  # number of bytes actually read from disk
    logger.info("Time to load & uncompress data: %4.1fs [%6.2f Mbytes/s]" % (dt, nbytes / dt / 1024 ** 2))
    ################################################################
    # Zoom & register images, keep first distance pixel size and size
    # Do this using multiple process to speedup
    ################################################################
    if nz > 1:
        logger.info("Magnification relative to iz=0: %s" % str(magnification / magnification[0]))
        t0 = timeit.default_timer()
        ref_zoom = zoom_pad_images(ref, magnification, padding, nz)
        vkw = [{'magnification': magnification, 'padding': padding, 'nb_proj_sharray': nb_proj_sharray,
                'pad_method': pad_method, 'nx': nx, 'ny': ny, 'nz': nz, 'i': i} for i in range(nb_proj)]
        with Pool(nproc, initializer=zoom_pad_images_init, initargs=(iobs_array,)) as pool:
            pool.map(zoom_pad_images_kw, vkw)

        logger.info("Zoom & pad images: dt = %6.2fs [padding: (%d,%d), %s]" %
                    (timeit.default_timer() - t0, pady, padx, pad_method))

        # Align images
        t0 = timeit.default_timer()
        if rhapp is not None:
            logger.info("Aligning images: loading shifts from holoCT (rhapp)")
            nb = np.loadtxt(rhapp, skiprows=4, max_rows=1, dtype=int)[2]
            m = np.loadtxt(rhapp, skiprows=5, max_rows=nb * 8, dtype=np.float32).reshape((nb, 4, 2))
            dx_holoct = m[..., 1] / bin_factor
            dy_holoct = m[..., 0] / bin_factor

            dx_holoct = np.take(dx_holoct, proj_idx, axis=0)
            dy_holoct = np.take(dy_holoct, proj_idx, axis=0)
        else:
            dx_holoct, dy_holoct = None, None

        if align_method != 'rhapp':
            logger.info("Aligning images, method: %s" % align_method)
            if align_interval > 1:
                # Only align one projection every align_interval, always including the last one
                idx = list(range(0, nb_proj, align_interval))
                proj_idx_interval = list(proj_idx[::align_interval])
                if proj_idx_interval[-1] != proj_idx[-1]:
                    proj_idx_interval.append(proj_idx[-1])
                    idx.append(nb_proj - 1)
                proj_idx_interval = np.array(proj_idx_interval)
            else:
                idx = range(nb_proj)
                proj_idx_interval = proj_idx
            # A plain pool.map(align_images, range(nb_proj)) can sometimes (<1 in 10)
            # hang indefinitely, hence the imap with a timeout and retries below.

            # Supply 3 projections so averaging/differentiating can be applied
            # We bin large images for alignment, with upsampling no loss of accuracy.
            vkw = [{'i': i, 'nproj': nb_proj_sharray, 'nz': nz, 'ny': ny, 'nx': nx, 'padding': padding,
                    'upsample_factor': 10, 'low_cutoff': align_fourier_low_cutoff,
                    'low_width': 0.03, 'high_cutoff': None, 'high_width': 0.03} for i in idx]
            align_ok = False
            nb_nok = 0
            while not align_ok:
                if nb_nok >= 4:
                    logger.critical("Alignment: 4 alignment failures, bailing out")
                    sys.exit(1)
                try:
                    res = []
                    with Pool(nproc, initializer=align_images_pool_init, initargs=(iobs_array,)) as pool:
                        results = pool.imap(align_images_kw, vkw, chunksize=1)
                        for i in range(len(vkw)):
                            r = results.next(timeout=20)
                            res.append(r)
                    align_ok = True
                    logger.info("align OK")
                except PoolTimeoutError:
                    logger.info("Timeout, re-trying")
                    nb_nok += 1
                except Exception as ex:
                    # Still retry (best effort), but report the actual error rather than a
                    # misleading timeout. KeyboardInterrupt/SystemExit are not caught.
                    logger.warning("Alignment failed (%s: %s), re-trying" % (type(ex).__name__, str(ex)))
                    nb_nok += 1

            dx = np.zeros((len(vkw), nz), dtype=np.float32)
            dy = np.zeros((len(vkw), nz), dtype=np.float32)
            for i in range(len(vkw)):
                dx[i] = res[i][0]
                dy[i] = res[i][1]

            # Take into account random motion displacements if available
            # When present, the smooth displacements can be obtained
            # by comparing corrected displacements dmx[iz]=(dx+mx)[iz] between distances.
            # Once we have dmx_fit[iz], the real positions to use
            # are dmx_fit[iz] - mx[iz]
            mx = np.zeros((nb_proj, nz), dtype=np.float32)
            my = np.zeros((nb_proj, nz), dtype=np.float32)
            if motion:
                logger.info("Loading motion correction files: %s" % data_src[:-2] + "%d_/correct.txt")
                for iz in range(nz):
                    m = np.loadtxt("%s%d_/correct.txt" % (data_src[:-2], iz + 1), dtype=np.float32)
                    my[:, iz] = m[:, 1][proj_idx] / bin_factor
                    mx[:, iz] = m[:, 0][proj_idx] / bin_factor
                dx += np.take(mx, idx, axis=0)
                dy += np.take(my, idx, axis=0)

            # Switch reference plane
            izref = list(planes).index(reference_plane)
            for iz in range(nz):
                if iz != izref:
                    dx[:, iz] -= dx[:, izref]
                    dy[:, iz] -= dy[:, izref]
            dx[:, izref] = 0
            dy[:, izref] = 0

            # Keep copy of raw positions before fit/filtering
            dxraw = dx.copy()
            dyraw = dy.copy()

            if 'fit' in align_method:
                # Use polyfit to smooth shifts ? #####################################
                # TODO: use shift corrections common to all parallel optimisations with a prior determination ?
                # First smooth using a median filter to remove outliers
                dx = median_filter(dx, (3, 1))
                dy = median_filter(dy, (3, 1))
                for iz in range(nz):
                    if iz != izref:
                        polx = np.polynomial.polynomial.polyfit(proj_idx_interval, dx[:, iz], 6)
                        poly = np.polynomial.polynomial.polyfit(proj_idx_interval, dy[:, iz], 6)
                        dx[:, iz] = np.polynomial.polynomial.polyval(proj_idx_interval, polx)
                        dy[:, iz] = np.polynomial.polynomial.polyval(proj_idx_interval, poly)
            elif 'median' in align_method:
                dx = median_filter(dx, (3, 1))
                dy = median_filter(dy, (3, 1))

            # interpolate missing points
            if align_interval > 1:
                dx0, dy0 = dx, dy
                dx = np.empty((nb_proj, nz), dtype=np.float32)
                dy = np.empty((nb_proj, nz), dtype=np.float32)
                for iz in range(nz):
                    dx[:, iz] = interp1d(proj_idx_interval, dx0[:, iz], kind='quadratic')(proj_idx)
                    dy[:, iz] = interp1d(proj_idx_interval, dy0[:, iz], kind='quadratic')(proj_idx)

            # Save alignments to an hdf5 file
            # Assume this is sequential, so no access conflict
            # TODO: use a separate worker to avoid any conflict
            with h5py.File(prefix_output + '_shifts.h5', 'a') as h:
                # Save parameters used for the script.
                if 'parameters' not in h and params is not None:
                    config = h.create_group("parameters")
                    config.attrs['NX_class'] = 'NXcollection'
                    for k, v in vars(params).items():
                        if v is not None:
                            if type(v) is dict:
                                # This can happen if complex configuration is passed on
                                if len(v):
                                    kd = config.create_group(k)
                                    kd.attrs['NX_class'] = 'NXcollection'
                                    for kk, vv in v.items():
                                        kd.create_dataset(kk, data=vv)
                            else:
                                config.create_dataset(k, data=v)
                    # Also save magnification
                    config.create_dataset("magnification", data=magnification)

                proj_idx1 = np.array(proj_idx, dtype=int)
                dx1 = dx.copy()
                dy1 = dy.copy()
                for iz in range(1, nz):
                    # Always save the displacements relative to the first distance,
                    # so we can see if the curves looks nice (after motion correction)
                    dx1[:, iz] -= dx1[:, 0]
                    dy1[:, iz] -= dy1[:, 0]
                    dxraw[:, iz] -= dxraw[:, 0]
                    dyraw[:, iz] -= dyraw[:, 0]
                if motion:
                    mx1 = mx.copy()
                    my1 = my.copy()
                if 'dx' in h:
                    # Merge with the shifts already saved (previous chunks), sorted by projection index
                    dx1 = np.append(dx1, h['dx'][()], axis=0)
                    dy1 = np.append(dy1, h['dy'][()], axis=0)
                    proj_idx1 = np.append(proj_idx1, h['idx'][()])
                    idx = np.argsort(proj_idx1)
                    dx1 = np.take(dx1, idx, axis=0)
                    dy1 = np.take(dy1, idx, axis=0)
                    proj_idx1 = proj_idx1[idx]
                    del h['dx'], h['dy'], h['idx']
                    if motion:
                        mx1 = np.append(mx1, h['mx'][()], axis=0)
                        my1 = np.append(my1, h['my'][()], axis=0)
                        mx1 = np.take(mx1, idx, axis=0)
                        my1 = np.take(my1, idx, axis=0)
                        del h['mx'], h['my']
                    dxraw = np.append(dxraw, h['dx_raw'][()], axis=0)
                    dyraw = np.append(dyraw, h['dy_raw'][()], axis=0)
                    proj_idxraw = np.append(proj_idx_interval, h['idx_raw'][()])
                    idx = np.argsort(proj_idxraw)
                    dxraw = np.take(dxraw, idx, axis=0)
                    dyraw = np.take(dyraw, idx, axis=0)
                    proj_idxraw = proj_idxraw[idx]
                    del h['dx_raw'], h['dy_raw'], h['idx_raw']
                else:
                    proj_idxraw = proj_idx_interval
                h.create_dataset('dx', data=dx1)
                h.create_dataset('dy', data=dy1)
                h.create_dataset('idx', data=proj_idx1)
                if motion:
                    h.create_dataset('mx', data=mx1)
                    h.create_dataset('my', data=my1)
                h.create_dataset('dx_raw', data=dxraw)
                h.create_dataset('dy_raw', data=dyraw)
                h.create_dataset('idx_raw', data=proj_idxraw)
        else:
            logger.info("Aligning images: using shifts imported from holoCT (rhapp)")
            if dx_holoct is None:
                raise RuntimeError("Image alignment: method is rhapp but align_rhapp was not supplied")
            dx = dx_holoct
            dy = dy_holoct
            mx = np.zeros_like(dx)
            my = np.zeros_like(dx)

        # Now the fit/filtering & export is done, add back the random motion
        dx -= mx
        dy -= my

        # # Now the fit/filtering & export is done, add back the random motion, corrected for the reference plane %-)
        # for iz in range(nz):
        #     if iz != reference_plane:
        #         dx[:, iz] -= mx[:, iz] - mx[:, reference_plane]
        #         dy[:, iz] -= my[:, iz] - my[:, reference_plane]

        logger.info("Align images: dt = %6.2fs" % (timeit.default_timer() - t0))
    else:
        if padding:
            # Move data to centre of frame & pad
            for i in range(len(iobs)):
                iobs[i, 0] = pad2(iobs[i, 0, :ny, :nx], pad_width=padding, mode=pad_method, mask2neg=True)
            ref_zoom = np.empty((nz, ny + 2 * pady, nx + 2 * padx), dtype=np.float32)
            ref_zoom[0] = pad2(ref[0], pad_width=padding, mode=pad_method, mask2neg=True)
        else:
            ref_zoom = ref
        dx, dy = None, None
    logger.info("Returning Iobs and reference frames")
    return ref_zoom, dx, dy


def find_cor_iobs_sliding_window(data_src, reference_plane, motion, logger, q: Queue):
    """
    Find the centre of rotation using Nabu's CORFinder with the
    "sliding-window" method, and put the result in a queue (this function is
    meant to be the target of a separate process).

    :param data_src: path to the data; its last two characters are replaced
        by "<reference_plane>_" to get the analysed dataset
    :param reference_plane: number of the plane used for the analysis
    :param motion: if True, correct the result using the first column of the
        'correct.txt' file of the reference plane
    :param logger: logger object
    :param q: the queue in which the centre of rotation is put
    """
    from tomoscan.esrf.edfscan import EDFTomoScan
    from nabu.resources.dataset_analyzer import analyze_dataset
    from nabu.pipeline.estimators import CORFinder
    src_root = data_src[:-2]
    dataset_info = analyze_dataset(src_root + "%d_" % reference_plane)
    n_angles = dataset_info.n_angles
    step_deg = EDFTomoScan.get_scan_range(data_src) / (n_angles - 1)
    angles_deg = np.linspace(0, step_deg * n_angles, n_angles, endpoint=True)
    dataset_info.rotation_angles = np.deg2rad(angles_deg)
    logger.info("Launching Nabu's CORFinder/sliding-window (motion=%s)" % str(motion))
    finder = CORFinder(dataset_info)
    cor_sliding = finder.find_cor(method="sliding-window")

    if not motion:
        q.put(cor_sliding)
        return

    logger.info("Taking int account motion correction for CoR: %s" % data_src[:-2] + "%d_/correct.txt")
    motion_table = np.loadtxt("%s%d_/correct.txt" % (src_root, reference_plane), dtype=np.float32)
    # Horizontal motion for the radios used by the CoR finder
    mx = motion_table[:, 0].take(finder._radios_indices)
    corrected = cor_sliding - 0.5 * (mx[1] + mx[0])
    logger.info("Sliding window: CoR: %6.2f , motion correction: %5.2f  %5.2f => %6.2f" %
                (cor_sliding, mx[0], mx[1], corrected))
    q.put(corrected)


def find_cor_phase_registration(d0, d180, i0, i180):
    """
    Register two images using phase cross-correlation, in two passes: a first
    registration, then - after shifting each image by half the shift found
    and cropping the wrapped borders - a second, finer one.

    The input arrays are not modified.

    :param d0: the first 2D image
    :param d180: the second 2D image (the caller in this file supplies the
        left-right flipped projection at +180°)
    :param i0: index of the first projection. Unused, accepted so the same
        dictionary of keywords can keep track of the indices.
    :param i180: index of the second projection. Unused, as for i0.
    :return: a tuple (dx, dy, err) with the total shift along x and y (sum of
        the two passes) and the error returned by the second registration.
    """
    from ...utils.registration import phase_cross_correlation
    from scipy.fft import fft2, ifft2
    from scipy.ndimage import fourier_shift
    # Subtract the mean on new arrays rather than in-place, to avoid
    # silently modifying the arrays supplied by the caller
    d0 = d0 - d0.mean()
    d180 = d180 - d180.mean()
    dxy, err, _ = phase_cross_correlation(d0, d180, upsample_factor=10)
    dy, dx = dxy
    # Shift image and re-correlate after cropping wrapped borders
    d0 = ifft2(fourier_shift(fft2(d0), [-dy / 2, -dx / 2])).real
    d180 = ifft2(fourier_shift(fft2(d180), [dy / 2, dx / 2])).real
    adx = int(abs(np.round(dx / 2)))
    if adx >= 1:
        d0 = d0[:, adx:-adx]
        d180 = d180[:, adx:-adx]
    ady = int(abs(np.round(dy / 2)))
    if ady >= 1:
        d0 = d0[ady:-ady]
        d180 = d180[ady:-ady]
    dxy, err, _ = phase_cross_correlation(d0, d180, upsample_factor=20)
    dy1, dx1 = dxy
    return dx + dx1, dy + dy1, err


def find_cor_phase_registration_kw(kwargs):
    """
    Call find_cor_phase_registration() with the keyword arguments supplied
    as a single dictionary, so it can be mapped over a list of dictionaries.

    :param kwargs: dictionary of the arguments of find_cor_phase_registration()
    :return: the result of find_cor_phase_registration()
    """
    return find_cor_phase_registration(**kwargs)


def find_cor_sino(s, side, window_width, neighborhood, shift_value, logger):
    """
    Estimate the centre of rotation from a sinogram using Nabu's SinoCor
    (coarse estimation followed by a fine one). The sinogram is first
    converted to float32, median-filtered and its mean along the first axis
    is subtracted.

    :param s: the sinogram array
    :param side: passed to SinoCor.estimate_cor_coarse()
    :param window_width: passed to SinoCor.estimate_cor_coarse()
    :param neighborhood: passed to SinoCor.estimate_cor_fine()
    :param shift_value: passed to SinoCor.estimate_cor_fine()
    :param logger: logger object given to SinoCor
    :return: the result of SinoCor.estimate_cor_fine()
    """
    from nabu.estimation.cor_sino import SinoCor
    # Perform some filtering ?
    filtered = median_filter(s.astype(np.float32, copy=False), 3)
    filtered -= filtered.mean(axis=0)
    finder = SinoCor(sinogram=filtered, logger=logger)
    finder.estimate_cor_coarse(side=side, window_width=window_width)
    fine_cor = finder.estimate_cor_fine(neighborhood=neighborhood, shift_value=shift_value)
    return fine_cor


def find_cor_sino_kw(kwargs):
    """
    Call find_cor_sino() with the keyword arguments supplied as a single
    dictionary, so it can be mapped over a list of dictionaries.

    :param kwargs: dictionary of the arguments of find_cor_sino()
    :return: the result of find_cor_sino()
    """
    return find_cor_sino(**kwargs)


def find_cor(data_src, prefix, proj_idx, binning, reference_plane, method,
             motion, ph_shape, ph_array_dtype, logger, ph_array):
    """
    Find the Center of Rotation using tomotools & Nabu, with the original iobs or the
    phased projections or sinograms.

    Up to three estimations are computed and logged:

    * Nabu's CORFinder/sliding-window on the original data, which is run in
      a separate process ("iobs_sliding_window")
    * registration of couples of phased projections at +180° ("phase_registration")
    * Nabu's SinoCor on sinograms extracted from the even phased projections,
      only if the scan range is close to 360° ("sino")

    :param data_src: path to the data, used to get the scan range and number
        of projections, and given to the sliding-window process
    :param prefix: unused, kept for backward compatibility
    :param proj_idx: array of the indices of the phased projections
    :param binning: the binning factor; the returned value is multiplied by it
    :param reference_plane: plane number given to the sliding-window process
    :param method: the method selecting the returned value, either
        "phase_registration", "sino" or "iobs_sliding_window". If "sino" is
        selected but the scan range is not 360°, "iobs_sliding_window" is used.
    :param motion: if True, the sliding-window process takes into account the
        motion correction file
    :param ph_shape: the (nproj, ny, nx) shape of the phased projections
    :param ph_array_dtype: the dtype of the phased projections
    :param logger: logger object
    :param ph_array: the shared multiprocessing Array with the phased projections
    :return: the centre of rotation for the selected method
    :raises ValueError: if method is not one of the three known methods
    """
    # Fail early on an unknown method, rather than with an unbound 'cor'
    # once all the estimations have been computed.
    known_methods = ("phase_registration", "sino", "iobs_sliding_window")
    if method not in known_methods:
        raise ValueError("Unknown CoR method: %s (expected one of: %s)" % (str(method), ", ".join(known_methods)))

    from tomoscan.esrf.edfscan import EDFTomoScan
    logger.info("Searching for Centre of rotation [NB: reported values are not binned]")

    # Default CoR using Nabu's CORFinder ##################################
    # This runs in a parallel process, the result is collected at the end.
    logger.info("Launching Nabu's CORFinder/sliding-window")
    q_iobs_sliding_window = Queue()
    p_iobs_sliding_window = Process(target=find_cor_iobs_sliding_window,
                                    args=(data_src, reference_plane, motion, logger, q_iobs_sliding_window))
    p_iobs_sliding_window.start()

    # CoR using +180 phased projections (even ones) ##################################
    # Try to get ~20 couple of projections if we have more than 180°
    scan_range, tomo_n = EDFTomoScan.get_scan_range(data_src), EDFTomoScan.get_tomo_n(data_src)
    tomo_angular_step_orig = scan_range / (tomo_n - 1)  # Keep in degrees to match user input
    tomo_angular_step = tomo_angular_step_orig * (proj_idx.max() - proj_idx.min()) / (len(proj_idx) - 1)
    nproj, ny, nx = ph_shape
    angles = np.deg2rad(np.arange(nproj) * tomo_angular_step)
    ph = np.frombuffer(ph_array.get_obj(), dtype=ph_array_dtype).reshape(ph_shape)

    angles -= angles.min()
    nb_above_180 = (angles > np.pi).sum()
    step = 2  # We only want even frames
    if nb_above_180 > 20:
        step = nb_above_180 // 20
        step += step % 2
    idx0 = np.arange(0, nb_above_180 + 1, step)
    vcor = []
    vkw = []
    for i0 in idx0:
        # Even projection closest to +180° (modulo 360°) from projection i0
        da = (angles[::2] - angles[i0] - np.pi) % (2 * np.pi)
        i180 = np.argmin(np.minimum(da, 2 * np.pi - da)) * 2
        # Projections, without borders
        d0 = ph[i0, 100:-100, 100:-100].astype(np.float32)
        d180 = np.fliplr(ph[i180, 100:-100, 100:-100]).astype(np.float32)
        vkw.append({'d0': d0, 'd180': d180, 'i0': i0, 'i180': i180})
    nprocmax = int(np.ceil(len(vkw) / 2))
    with Pool(min(nproc, nprocmax)) as pool:
        for kw, (dx, dy, err) in zip(vkw, pool.imap(find_cor_phase_registration_kw, vkw)):
            cor_pair = (dx + nx) / 2 * binning
            logger.info("Cross-correlation between #%4d and #%4d: dx=%5.2f dy=%5.2f err=%6e => cor=%6.2f/binning" %
                        (kw['i0'], kw['i180'], dx, dy, err, cor_pair))
            vcor.append(cor_pair)

    vcor = np.array(vcor)
    with np.printoptions(precision=2, suppress=False, floatmode='fixed', linewidth=200):
        str_vcor = str(vcor)
    logger.info("CoR using registration of phased projections at +180°: " + str_vcor)
    cor_ph_reg = np.median(vcor)
    corstd = np.std(vcor)
    logger.info("CoR using registration of phased projections at +180°: <%6.2f> +/-<%6.3f>" % (cor_ph_reg, corstd))

    if method == "phase_registration":
        cor = cor_ph_reg

    # Evaluate CoR from half-sinogram (even projections)  ##################################
    if np.isclose(scan_range, 360, atol=0.2):
        # TODO: also use sinocor if there is enough range beyond 180° ?
        side = "right"
        window_width = None
        neighborhood = 7
        shift_value = 0.1

        # Extract sinograms with even projections #################
        vy = np.arange(ny // 8, 7 * ny // 8 + 1, (3 * ny) // 80, dtype=int)
        sino = ph[::2, ny // 8: 7 * ny // 8 + 1:(3 * ny) // 80]
        # We need an even shape
        if len(sino) % 2:
            sino = sino[:-1]
        # We use N layers and use the median value
        vcor_sino = []
        logger.info("Searching for CoR using %d sinograms along the sample height" % sino.shape[1])
        vkw = [{'s': sino[:, i], 'side': side, 'window_width': window_width,
                'neighborhood': neighborhood, 'shift_value': shift_value, 'logger': logger}
               for i in range(sino.shape[1])]

        with Pool(min(nproc, nprocmax)) as pool:
            for i, cor_layer in enumerate(pool.imap(find_cor_sino_kw, vkw)):
                vcor_sino.append(cor_layer * binning)
                logger.info("SinoCoR for layer y=%4d:  cor=%6.2f/binning" % (vy[i], vcor_sino[-1]))
        vcor_sino = np.array(vcor_sino)
        cor_sino = np.median(vcor_sino)
        corstd = np.std(vcor_sino)
        with np.printoptions(precision=2, suppress=False, floatmode='fixed', linewidth=200):
            str_vcor = str(vcor_sino)
        logger.info("CoR for the different layers: " + str_vcor)
        logger.info("CoR using Nabu's SinoCor: <%6.2f> +/-%6.3f" % (cor_sino, corstd))
        if method == "sino":
            cor = cor_sino
    else:
        logger.info("Scan range is %5.1f°, not performing Sino alignment" % abs(scan_range))
        if method == "sino":
            logger.warning("CoR method=sino but scan range too small - revert to default method")
            method = "iobs_sliding_window"

    # Get back result from iobs_sliding_window process
    cor_sliding = q_iobs_sliding_window.get()
    p_iobs_sliding_window.join()
    if method == "iobs_sliding_window":
        cor = cor_sliding
    logger.info("CORFinder/sliding-window result: cor=%6.2f/binning" % cor_sliding)

    logger.info("Returning CoR using selected method [%s]: %6.2f" % (method, cor))
    return cor


def find_cor_half(binning, logger, ph_array, ph_shape, ph_array_dtype, proj_idx, data_src):
    """
    Find the Center of Rotation using Nabu's SinoCor on sinograms extracted
    from half of the phased projections (all the even ones). Several layers
    are used, and the median of the values found is kept.

    :param binning: the binning factor, by which the result is multiplied
    :param logger: logger object
    :param ph_array: the shared multiprocessing Array with the phased projections
    :param ph_shape: the (nproj, ny, nx) shape of the phased projections
    :param ph_array_dtype: the dtype of the phased projections
    :param proj_idx: unused, kept for backward compatibility
    :param data_src: unused, kept for backward compatibility
    :return: the center of rotation (multiplied by the binning value)
    """
    logger.info("Finding rot center from phased projections.")
    from nabu.estimation.cor_sino import SinoCor

    nproj, ny, nx = ph_shape
    ph = np.frombuffer(ph_array.get_obj(), dtype=ph_array_dtype).reshape(ph_shape)
    # Evaluate CoR from half-sinogram (even projection)
    side = "right"
    window_width = None
    neighborhood = 7
    shift_value = 0.1

    # Extract sinograms with even projections #################
    sino = ph[::2, ny // 4: 3 * ny // 4 + 1:ny // 20]
    # We need an even shape
    if len(sino) % 2:
        sino = sino[:-1]
    # We use N layers and use the median value
    vcor_sino = []
    for i in range(sino.shape[1]):
        s = sino[:, i].astype(np.float32)
        # Perform some filtering
        s = median_filter(s, (3, 1))
        s -= s.mean(axis=0)
        cor_finder = SinoCor(sinogram=s, logger=logger)
        cor_finder.estimate_cor_coarse(side=side, window_width=window_width)
        vcor_sino.append(cor_finder.estimate_cor_fine(neighborhood=neighborhood, shift_value=shift_value))
    vcor_sino = np.array(vcor_sino)
    cor = np.median(vcor_sino)
    with np.printoptions(precision=2, suppress=False, floatmode='fixed'):
        str_vcor = str(vcor_sino)
    logger.info("CoR for different layers using sinograms: " + str_vcor)

    logger.info("Final rotation center using median value: %8.2f" % (cor * binning))
    return cor * binning


def prep_holotomodata(iobs_array, iobs_shape, ref, pixel_size, wavelength, detector_distance,
                      dx, dy, proj_idx, padding, algorithm, gpu_mem_f, pu, logger, data_old=None):
    """
    Wrap the observed intensities stored in the shared buffer into a HoloTomoData
    object, either by creating a new one or by refreshing an existing one.

    The stack size (number of projections handled together on the GPU) is estimated
    from the total GPU memory and a per-projection memory budget.

    :param iobs_array: shared array holding the observed intensities as C floats
    :param iobs_shape: (nproj, nz, ny, nx) shape of the data to use. The shared
        buffer may hold more projections than nproj, only the first nproj are used.
    :param ref: reference data, forwarded to HoloTomoData / replace_data
    :param pixel_size: detector pixel size, forwarded to HoloTomoData
    :param wavelength: wavelength, forwarded to HoloTomoData
    :param detector_distance: detector distance(s), forwarded to HoloTomoData
    :param dx: shifts along x, forwarded to HoloTomoData / replace_data
    :param dy: shifts along y, forwarded to HoloTomoData / replace_data
    :param proj_idx: projection indices, forwarded as 'idx'
    :param padding: padding, forwarded to HoloTomoData
    :param algorithm: algorithm string, only used to check if RAAR, DM or DRAP
        are requested (these need an extra copy of psi)
    :param gpu_mem_f: divisor applied to the computed stack size, used to lower
        the GPU memory requirement
    :param pu: the processing unit; pu.cu_device.total_memory() gives the GPU memory
    :param logger: the logger used for reporting
    :param data_old: an existing HoloTomoData object to re-use, or None
    :return: the HoloTomoData object (data_old if it was supplied)
    """
    # TODO: keep & reuse Holotomo and HolotomoData object
    from pynx.holotomo import HoloTomoData
    gib = 1024 ** 3
    nproj, nz, ny, nx = iobs_shape
    voxels = nz * ny * nx

    # View the shared memory buffer as a 4D array, keeping only the first nproj frames
    nproj_buffer = len(iobs_array) // voxels  # Actual number of frames in the shared array
    iobs = np.frombuffer(iobs_array.get_obj(), dtype=ctypes.c_float)
    iobs = iobs.reshape((nproj_buffer, nz, ny, nx))[:nproj]

    ################################################################
    # Estimate the stack size from the GPU memory
    ################################################################
    gpu_mem = pu.cu_device.total_memory()
    logger.info("Available GPU memory: %6.2fGB  data size: %6.2fGB" % (gpu_mem / gib, iobs.nbytes / gib))
    # Per projection: Iobs (4 bytes/pixel), Psi (8 bytes/pixel),
    # object projections & phase (12 bytes per pixel of a single plane)
    bytes_per_proj = 4 * voxels + 8 * voxels + 12 * nx * ny
    if any(name in algorithm for name in ('RAAR', 'DM', 'DRAP')):
        # One extra copy of psi for these algorithms
        bytes_per_proj += 8 * voxels
    # Two stacks are kept in memory (one for computing, the other to swap)
    bytes_per_proj *= 2
    if nx > 4096 or ny > 2048:
        # pyvkfft requires an extra buffer for inplace transforms, for large arrays
        bytes_per_proj += 8 * voxels
    logger.info("Estimated memory requirement: %8.4fGB/projection" % (bytes_per_proj / gib))
    # *2 to keep a margin of error, *gpu_mem_f as a workaround
    stack_size = int(np.ceil((gpu_mem - 0.5 * gib) / (bytes_per_proj * 2 * gpu_mem_f)))
    if nproj // stack_size == 2:
        # We need either n=1 or n>=3 stacks for swapping
        stack_size = int(np.ceil(nproj // 3))
    logger.info("Using stack size = %d" % stack_size)

    ################################################################
    # Create or refresh the HoloTomoData object
    ################################################################
    if data_old is None:
        logger.info("Creating HolotomoData object and allocating projections in pinned memory")
        return HoloTomoData(iobs, ref, pixel_size_detector=pixel_size, wavelength=wavelength,
                            detector_distance=detector_distance, stack_size=stack_size,
                            dx=dx, dy=dy, idx=proj_idx, padding=padding, pu=pu)

    logger.info("Re-using existing HolotomoData object with pinned memory")
    data_old.replace_data(iobs, ref, dx=dx, dy=dy, idx=proj_idx, pu=pu)
    return data_old


def run_algorithms(data, iobs_shape, padding, delta_beta, algorithm, ph_array, ph_shape,
                   i0, nb_chunk, obj_smooth, obj_inertia, obj_min, obj_max, nb_probe,
                   sino_filter, pu, logger, old_ht=None, return_probe=False):
    """
    Run the chain of phase-retrieval algorithms on one chunk of projections and copy
    the resulting unwrapped phases to the shared phase array.

    :param data: the HoloTomoData object; only used when a new HoloTomo object
        is created (old_ht is None)
    :param iobs_shape: (nproj, nz, ny, nx) shape of the observed intensities
    :param padding: not used in this function
    :param delta_beta: the delta/beta ratio, used for the Paganin/CTF back-propagation
        and passed to the iterative operators
    :param algorithm: comma-separated chain of steps, applied from right to left.
        A step can be:

        * 'name=value' to change a parameter (delta_beta, verbose, beta, obj_smooth,
          obj_inertia, obj_min, obj_max, probe_inertia, probe, obj). For delta_beta,
          0 disables it for the iterative operators and 1 restores the current value.
        * a step containing 'paganin' or 'ctf' (case-insensitive): back-propagation
        * anything else: lower-cased and evaluated as a python expression combining
          the operators dm, ap, apn, raar and drap, e.g. 'AP**100'
    :param ph_array: the shared array receiving the unwrapped phases (float16)
    :param ph_shape: the shape of the shared phase array
    :param i0: index offset of this chunk; phases are written to ph[i0::nb_chunk]
    :param nb_chunk: total number of chunks (if 1, the entire array is written)
    :param obj_smooth: passed to the iterative operators
    :param obj_inertia: passed to the iterative operators
    :param obj_min: passed to the iterative operators
    :param obj_max: passed to the iterative operators
    :param nb_probe: number of probe modes. If >1, coherent probe modes with
        linear ramp coefficients along the projections are used
    :param sino_filter: if not None, a SinoFilter(sino_filter) is applied after the algorithms
    :param pu: the processing unit, passed to the HoloTomo object
    :param logger: the logger used for reporting
    :param old_ht: an existing HoloTomo object to re-use, or None to create one
    :param return_probe: if True, also return the probe cropped from its padding
    :return: a tuple (idx, probe, p) with the indices returned along with the unwrapped
        phases, the cropped probe (None unless return_probe is True) and the HoloTomo object
    """
    from pynx.holotomo import HoloTomo
    from pynx.holotomo.operator import AP, DRAP, DM, RAAR, ScaleObjProbe, \
        BackPropagatePaganin, BackPropagateCTF, FreePU, MemUsage, SinoFilter

    # Iobs from shared memory buffer
    nproj, nz, ny, nx = iobs_shape

    ################################################################
    # Use coherent probe modes ?
    ################################################################
    probe = np.ones((nz, nb_probe, ny, nx), dtype=np.complex64)
    if nb_probe > 1:
        logger.info("Using %d coherent probe modes with linear ramp coefficients" % nb_probe)
        # Use linear ramps for the probe mode coefficients: between two consecutive
        # modes, one coefficient goes from 1 to 0 while the next goes from 0 to 1
        coherent_probe_modes = np.zeros((nproj, nz, nb_probe))
        dn = nproj // (nb_probe - 1)
        for iz in range(nz):
            for i in range(nb_probe - 1):
                if i < (nb_probe - 2):
                    coherent_probe_modes[i * dn:(i + 1) * dn, iz, i] = np.linspace(1, 0, dn)
                    coherent_probe_modes[i * dn:(i + 1) * dn, iz, i + 1] = np.linspace(0, 1, dn)
                else:
                    # Last interval: extend to the last projection
                    n = nproj - i * dn
                    coherent_probe_modes[i * dn:, iz, i] = np.linspace(1, 0, n)
                    coherent_probe_modes[i * dn:, iz, i + 1] = np.linspace(0, 1, n)
    else:
        coherent_probe_modes = False

    ################################################################
    # Create HoloTomo object
    ################################################################
    if old_ht is None:
        logger.info("Creating Holotomo object")
        p = HoloTomo(data=data, obj=None, probe=probe, coherent_probe_modes=coherent_probe_modes, pu=pu)
    else:
        # NB: the existing object keeps its own data object, 'data' is not used here
        logger.info("Re-using Holotomo object")
        p = old_ht
        p._cu_timestamp_counter, p._cl_timestamp_counter = -1, -1
        p.cycle = 0

    ################################################################
    # Algorithms
    ################################################################
    logger.info("Starting algorithms")
    t0 = timeit.default_timer()
    p = ScaleObjProbe() * p
    beta = 0.75
    show_obj_probe = False
    probe_inertia = 0.01

    db = delta_beta
    update_obj = True
    update_probe = True
    verbose = 10
    # The chain is applied from right to left, like operators
    for algo in algorithm.split(",")[::-1]:
        if "=" in algo:
            logger.info("Changing parameter? %s" % algo)
            k, v = algo.split("=")
            # NOTE: eval() is used on values from the algorithm string (here and for
            # the algorithm steps below), so that string must come from a trusted source.
            if k == "delta_beta":
                db = eval(v)
                if db == 0:
                    db = None
                elif db == 1:
                    db = delta_beta
                else:
                    delta_beta = db
            elif k == "verbose":
                verbose = int(eval(v))
            elif k == "beta":
                beta = eval(v)
            elif k == "obj_smooth":
                obj_smooth = eval(v)
            elif k == "obj_inertia":
                obj_inertia = eval(v)
            elif k == "obj_min":
                obj_min = eval(v)
            elif k == "obj_max":
                obj_max = eval(v)
            elif k == "probe_inertia":
                probe_inertia = eval(v)
            elif k == "probe":
                update_probe = eval(v)
            elif k == "obj":
                update_obj = eval(v)
            else:
                logger.info("WARNING: did not understand algorithm step: %s" % algo)
        elif "paganin" in algo.lower():
            logger.info("Paganin back-projection")
            p = BackPropagatePaganin(delta_beta=delta_beta) * p
            p = ScaleObjProbe() * p
        elif "ctf" in algo.lower():
            logger.info("CTF back-projection")
            p = BackPropagateCTF(delta_beta=delta_beta) * p
            p = ScaleObjProbe() * p
        else:
            logger.info("Algorithm step: %s" % algo)
            # Keyword arguments shared by all iterative operators
            kw = dict(update_object=update_obj, update_probe=update_probe,
                      calc_llk=verbose, delta_beta=db, obj_min=obj_min, obj_max=obj_max,
                      obj_smooth=obj_smooth, obj_inertia=obj_inertia, weight_empty=1,
                      show_obj_probe=show_obj_probe, probe_inertia=probe_inertia)
            # These lower-case names must be kept: they are the operators which
            # can be used in the algorithm step evaluated below.
            dm = DM(**kw)
            ap = AP(**kw)
            apn = AP(iobs_normalise=True, **kw)
            raar = RAAR(beta=beta, **kw)
            drap = DRAP(beta=beta, **kw)
            p = eval(algo.lower()) * p

    if sino_filter is not None:
        p = SinoFilter(sino_filter) * p

    logger.info("Algorithms: dt = %6.2fs" % (timeit.default_timer() - t0))
    # We use float32 because float16 operations are extremely slow
    idx, ph0 = p.get_obj_phase_unwrapped(crop_padding=True, dtype=np.float32, sino_filter=None)
    logger.info("Got unwrapped phases (sino_filter:%s): dt = %6.2fs" % (str(sino_filter), timeit.default_timer() - t0))
    # Move phases to the shared memory array (converted to float16 on assignment)
    ph = np.frombuffer(ph_array.get_obj(), dtype=np.float16).reshape(ph_shape)
    if nb_chunk == 1:
        ph[:] = ph0
    else:
        # Chunks are interleaved
        ph[i0::nb_chunk] = ph0
    logger.info("Copied back phases for the main process")

    if not return_probe:
        return idx, None, p

    # Crop the probe from its padding. The y and x axes (the last two) are cropped
    # independently, so this is also correct when only one of the two is padded.
    probe = p.get_probe()
    pady, padx = p.data.padding
    if pady:
        probe = probe[..., pady:-pady, :]
    if padx:
        probe = probe[..., padx:-padx]

    # only for debugging - save object and probe
    # p.save_obj_probe_chunk("obj_probe_%04d", save_obj_complex=True, save_probe=True, dtype=np.float16)
    return idx, probe, p


# Shared arrays dictionary for the FBP pool: the 'ph_array' and 'vol_array' entries
# are set in each pool process by fbp_pool_init(), and read back in fbp().
shared_arrays_fbp = {}


def fbp(vy, nproj, ny, nx, tomo_rot_center, tomo_angular_step, idx_in, idx_out, profile, igpu, sino_filter="ram-lak"):
    """
    Reconstruct a set of slices by (filtered) back-projection using Nabu. The
    sinograms are read from the shared 'ph_array' and the slices written to the
    shared 'vol_array', both taken from the shared_arrays_fbp dictionary.

    :param vy: the indices of the slices to reconstruct (iterable)
    :param nproj: number of projections. The phase array is viewed as (nproj, ny, nx)
    :param ny: number of slices. The volume array is viewed as (ny, nx, nx)
    :param nx: width of the projections
    :param tomo_rot_center: rotation centre, passed to the Nabu Backprojector
    :param tomo_angular_step: angular step between projections, in degrees
    :param idx_in: index (mask) of the pixels of a slice used to compute the min/max
    :param idx_out: index (mask) of the pixels of a slice which are set to NaN
    :param profile: passed to the processing unit enable_profiling()
    :param igpu: the GPU index. If CUDA_VISIBLE_DEVICES is set, this is the index
        in that comma-separated list, otherwise it is used directly.
    :param sino_filter: name of the filter passed to the Backprojector. If None,
        assume that the sinogram filtering before the FBP has already been
        done, and only perform the back-projection.
    :return: a tuple (volmin, volmax) of lists with the min and max value
        (over idx_in) of each reconstructed slice
    """
    # Restrict the visible devices to the selected GPU, before the GPU-related imports
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is None:
        os.environ['CUDA_VISIBLE_DEVICES'] = "%d" % igpu
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = visible.split(',')[igpu]
    from nabu.reconstruction.fbp import Backprojector
    from ...processing_unit.cu_resources import cu_resources
    from ...processing_unit import default_processing_unit as pu
    pu.select_gpu(language='cuda')
    pu.enable_profiling(profile)
    cu_ctx = cu_resources.get_context(pu.cu_device)

    # Numpy views of the shared arrays
    ph = np.frombuffer(shared_arrays_fbp['ph_array'].get_obj(), dtype=np.float16).reshape((nproj, ny, nx))
    vol = np.frombuffer(shared_arrays_fbp['vol_array'].get_obj(), dtype=np.float16).reshape((ny, nx, nx))

    angles = np.deg2rad(np.arange(nproj) * tomo_angular_step)
    B = Backprojector((nproj, nx), rot_center=tomo_rot_center, angles=angles,
                      filter_name=sino_filter, cuda_options={'ctx': cu_ctx})

    prefiltered = sino_filter is None
    volmin, volmax = [], []
    for i in vy:
        sino = ph[:, i, :].astype(np.float32)
        if prefiltered:
            # Filtering has already been done when getting the unwrapped phases,
            # using the padded areas.
            # Need to normalise to keep the same scale as through nabu filtering
            sino *= np.pi / len(sino)
            vol[i] = B.backproj(sino)
        else:
            vol[i] = B.fbp(sino)
        layer = vol[i]
        layer[idx_out] = np.nan
        # Get min/max, avoid slow nanmin and nanmax later
        volmin.append(layer[idx_in].min())
        volmax.append(layer[idx_in].max())

    if profile:
        pu.enable_profiling(False)
    return volmin, volmax


def fbp_pool_init(ph_array, vol_array):
    """
    Pool initializer: record the shared phase and volume arrays in the
    shared_arrays_fbp dictionary of the current process.

    :param ph_array: the shared array with the phases
    :param vol_array: the shared array for the reconstructed volume
    """
    shared_arrays_fbp.update(ph_array=ph_array, vol_array=vol_array)


def fbp_kw(kwargs):
    """
    Call fbp() with its arguments supplied as a single dictionary, so it can
    be used with Pool.imap().

    :param kwargs: dictionary of keyword arguments for fbp()
    :return: the value returned by fbp()
    """
    slice_minmax = fbp(**kwargs)
    return slice_minmax


def vol2tiff_slices(filename, volslice, pxum, vmin, vmax):
    """
    Save a volume slice to a tiff file, after a linear rescaling to uint16
    (vmin -> 0, vmax -> 65535).

    The scaling is computed in float32 whatever the type of the input (e.g. float16),
    to avoid losing precision or overflowing. NaN values are saved as 0, and values
    outside [vmin, vmax] are clipped.

    :param filename: the filename e.g. sample_vol/sample_vol00053.tiff
    :param volslice: the volume slice as a float array
    :param pxum: pixel/voxel spacing in micron
    :param vmin: min value to scale the slice to uint16
    :param vmax: max value to scale the slice to uint16
    :return: nothing
    """
    vmin, vmax = float(vmin), float(vmax)
    # Flat data (vmax == vmin): avoid a division by zero, the slice is saved as zeros
    scale = (2 ** 16 - 1) / (vmax - vmin) if vmax != vmin else 0.0
    scaled = (np.asarray(volslice, dtype=np.float32) - vmin) * scale
    # Casting NaN or out-of-range values to an integer type is undefined, so
    # these must be replaced/clipped before the conversion.
    scaled = np.nan_to_num(scaled, nan=0.0, posinf=2 ** 16 - 1, neginf=0.0)
    volint = np.clip(scaled, 0, 2 ** 16 - 1).astype(np.uint16)
    imwrite(filename, volint, imagej=True, resolution=(1 / pxum, 1 / pxum),
            metadata={'spacing': pxum, 'unit': 'um', 'hyperstack': False})


def run_fbp(data_src, prefix_output, ph_array, ph_shape, idx, tomo_rot_center, pixel_size,
            save_fbp_vol='tiff', save_3sino=True, logger=None, profile=False, ngpu=1,
            sino_filter='ram-lak', params=None):
    """
    Perform the 3D tomography reconstruction (filtered back-projection with Nabu)
    from the phases in the shared array, then save a figure with the three central
    cuts of the volume, optionally three sinograms, and the volume as tiff and/or hdf5.

    :param data_src: the data source, passed to EDFTomoScan.get_scan_range()
        and EDFTomoScan.get_tomo_n() to get the scan range and number of projections
    :param prefix_output: the prefix for all the output files
    :param ph_array: the shared array with the phases (float16)
    :param ph_shape: the (nproj, ny, nx) shape of the phase array
    :param idx: array of the indices of the projections, used to compute the
        angular step between the projections which were actually analysed
    :param tomo_rot_center: the rotation centre passed to the FBP
    :param pixel_size: the pixel size - presumably in metres, as it is multiplied
        by 1e6 to get microns (TODO: confirm)
    :param save_fbp_vol: either a boolean (True is the same as 'hdf5') or a string
        which should include 'tif' and/or 'hdf5' to select the saved format(s)
    :param save_3sino: if True, save the sinograms at 1/4, 1/2 and 3/4 of the height
    :param logger: the logger used for reporting
    :param profile: if True, enable profiling in the FBP processes
    :param ngpu: number of GPUs to use; 4 FBP processes are used per GPU
    :param sino_filter: name of the FBP filter, or None if the sinograms have
        already been filtered
    :param params: the parameters (namespace) used to annotate the figure, or None.
        Note that the host, directory and user entries are added to it.
    :return: nothing
    """
    if logger is None:  # Should not happen
        logger = logging.getLogger()
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from tomoscan.esrf.edfscan import EDFTomoScan
    logger.info("#" * 48)
    logger.info("# 3D Tomography reconstruction (FBP) with Nabu using %d GPU(s)" % ngpu)
    logger.info("#" * 48)
    t0 = timeit.default_timer()

    ph = np.frombuffer(ph_array.get_obj(), dtype=np.float16).reshape(ph_shape)

    # Tomo angular step
    scan_range, tomo_n = EDFTomoScan.get_scan_range(data_src), EDFTomoScan.get_tomo_n(data_src)
    tomo_angular_step_orig = scan_range / (tomo_n - 1)  # Keep in degrees to match user input
    # Take into account that only some of the projections may be used
    tomo_angular_step = tomo_angular_step_orig * (idx.max() - idx.min()) / (len(idx) - 1)
    logger.info("Calculated angular step: %6.2f° / %d = %6.4f°" % (scan_range, tomo_n, tomo_angular_step))

    nproj, ny, nx = ph.shape
    vol_shape = (ny, nx, nx)

    logger.info("Allocating volume array: %s [%6.2fGB]" % (str(vol_shape), np.prod(vol_shape) * 2 / 1024 ** 3))
    # The shared array is allocated with a 16-bit integer type, but used as float16
    vol_array = Array(ctypes.c_int16, int(np.prod(vol_shape)))
    vol = np.frombuffer(vol_array.get_obj(), dtype=np.float16).reshape(vol_shape)
    # Mask outside of reconstruction
    ix, iy = np.meshgrid(np.arange(nx) - nx / 2, np.arange(nx) - nx / 2)
    r2 = ix ** 2 + iy ** 2
    idx_out = r2 >= (nx * nx / 4)
    idx_in = r2 < (nx * nx / 4)

    logger.info('Starting Filtered Back-Projection using Nabu - %d slices (%d x %d)' % (ny, nproj, nx))
    if sino_filter is None:
        logger.info('sino_filter=None - assuming FBP filtering has already been performed')
    else:
        logger.info('sino_filter=%s' % sino_filter)
    volmin, volmax = [], []
    n_fbp_proc = 4 * ngpu
    # Each process handles interleaved slices (i, i + n_fbp_proc, ...)
    vkw = [{'vy': range(i, ny, n_fbp_proc), 'nproj': nproj, 'ny': ny, 'nx': nx, 'tomo_rot_center': tomo_rot_center,
            'tomo_angular_step': tomo_angular_step, 'idx_in': idx_in, 'idx_out': idx_out, "profile": profile,
            'igpu': i % ngpu, 'sino_filter': sino_filter} for i in range(n_fbp_proc)]
    with Pool(n_fbp_proc, initializer=fbp_pool_init, initargs=(ph_array, vol_array)) as pool:
        for r in pool.imap(fbp_kw, vkw):
            volmin += r[0]
            volmax += r[1]

    dt = timeit.default_timer() - t0
    logger.info("Time to perform FBP reconstruction:  %4.1fs" % dt)

    logger.info("#" * 48)
    logger.info(" Saving central XYZ cuts of volume to: %s_volXYZ.png" % prefix_output)
    logger.info("#" * 48)
    fig = Figure(figsize=(16, 16))
    ax1, ax2, ax3, ax4 = fig.subplots(2, 2).flat
    ax4.set_visible(False)
    c = vol[ny // 2].astype(np.float32)
    vmin, vmax = np.nanpercentile(c, (0.1, 99.9))
    ax1.imshow(c, vmin=vmin, vmax=vmax, cmap='gray')

    # We use a layout similar to what is used in imagej
    c = vol[..., nx // 2].astype(np.float32).transpose()
    vmin, vmax = np.nanpercentile(c, (0.1, 99.9))
    ax2.imshow(c, vmin=vmin, vmax=vmax, cmap='gray')

    c = vol[:, nx // 2].astype(np.float32)
    vmin, vmax = np.nanpercentile(c, (0.1, 99.9))
    ax3.imshow(c, vmin=vmin, vmax=vmax, cmap='gray')

    # Step between text lines, as a fraction of the figure height
    fontsize = 8
    dy = fontsize * 1.2 / 72 / fig.get_size_inches()[1]
    if params is not None:
        fig.text(0.5, 1, r"%s   binning=%d   nz=%d   %s   $\delta/\beta=%3.0f$" %
                 (os.path.split(prefix_output)[-1], params.binning, params.nz,
                  params.algorithm, params.delta_beta), fontsize=12, horizontalalignment='center',
                 verticalalignment='top', stretch='condensed')
        # Print parameters
        y0 = 0.45 - 1.5 * dy
        n = 1
        vpars = vars(params)
        vpars['host'] = "%s - %s" % (platform.node(), platform.platform())
        vpars['directory'] = os.getcwd()
        try:
            vpars['user'] = os.getlogin()
        except OSError:
            # When using slurm, this is not running in a shell with user info
            if 'USER' in os.environ:
                vpars['user'] = os.environ['USER']
        if params.tomo_rot_center is None:
            vpars['tomo_rot_center_real'] = tomo_rot_center
        for k in sorted(vpars.keys()):
            v = vpars[k]
            if v is not None and k not in ['slurm', 'step']:
                fig.text(0.55, y0 - n * dy, "%s = %s" % (k, str(v)), fontsize=fontsize, horizontalalignment='left',
                         stretch='condensed')
                n += 1
    else:
        fig.text(0.5, 1, "%s" % (os.path.split(prefix_output)[-1]), fontsize=12,
                 horizontalalignment='center', verticalalignment='top', stretch='condensed')
    fig.text(dy, dy, "PyNX v%s, finished at %s" % (_pynx_version, time.strftime("%Y/%m/%d %H:%M:%S")),
             fontsize=fontsize, horizontalalignment='left', stretch='condensed')

    fig.tight_layout()
    canvas = FigureCanvasAgg(fig)
    canvas.print_figure(prefix_output + "_volXYZ.png", dpi=120)

    if save_3sino:
        filename = prefix_output + "_3sino.npz"
        logger.info("#" * 48)
        logger.info(" Saving 3 sinograms to: %s" % filename)
        logger.info("#" * 48)
        np.savez_compressed(filename, sino=np.take(ph.astype(np.float32), (ny // 4, ny // 2, 3 * ny // 4), axis=1))

    if isinstance(save_fbp_vol, bool):
        if save_fbp_vol:
            save_fbp_vol = "hdf5"
        else:
            save_fbp_vol = "False"
    if "tif" in save_fbp_vol:
        logger.info("#" * 48)
        logger.info(" Exporting FBP volume as an uint16 TIFF file")
        logger.info("#" * 48)
        t1 = timeit.default_timer()
        logger.info("Computing min/max of volume")
        # Use the per-slice min/max computed during the FBP, as
        # min/max operations are dead slow using float16
        vmin, vmax = np.array(volmin).min(), np.array(volmax).max()
        pxum = pixel_size * 1e6
        if vol.size * 2 >= 4 * 1024 ** 3:  # 4GB tiff file limit
            logger.info("Volume is %6.2fGB > 4GB => saving volume as separate tiff slices" %
                        (vol.size * 2 / 1024 ** 3))
            # The file names inside the directory must only use the last part of the
            # prefix, otherwise the path is invalid when the prefix includes a directory.
            slice_dir = "%s_vol" % prefix_output
            slice_name = os.path.basename(prefix_output)
            logger.info("Saving volume to: %s/%s_volNNNNN.tiff" % (slice_dir, slice_name))
            os.makedirs(slice_dir, exist_ok=True)
            with Pool(nproc) as pool:
                vpar = [(os.path.join(slice_dir, "%s_vol%05d.tiff" % (slice_name, i)), vol[i], pxum, vmin, vmax)
                        for i in range(ny)]
                pool.starmap(vol2tiff_slices, vpar)
        else:
            logger.info("Converting volume to uint16")
            # Convert slice-by-slice using float32: float16 would lose precision (or
            # overflow), and NaN (outside the reconstruction) cannot be cast to an
            # integer, so it is replaced by 0.
            v0, v1 = float(vmin), float(vmax)
            scale = (2 ** 16 - 1) / (v1 - v0) if v1 != v0 else 0.0
            volint = np.empty(vol.shape, dtype=np.uint16)
            for i in range(len(vol)):
                tmp = (vol[i].astype(np.float32) - v0) * scale
                tmp = np.nan_to_num(tmp, nan=0.0, posinf=2 ** 16 - 1, neginf=0.0)
                volint[i] = np.clip(tmp, 0, 2 ** 16 - 1).astype(np.uint16)
            filename = prefix_output + "_vol.tiff"
            logger.info("Writing tiff file: %s" % filename)
            imwrite(filename, volint, imagej=True, resolution=(1 / pxum, 1 / pxum),  # bigtiff=True,
                    metadata={'spacing': pxum, 'unit': 'um', 'hyperstack': False}, maxworkers=nproc)
            logger.info("Finished saving volume to: %s" % filename)
        dt = timeit.default_timer() - t1
        logger.info("Time to export volume as tiff:  %4.1fs" % dt)
    if "hdf5" in save_fbp_vol:
        logger.info("#" * 48)
        logger.info(" Saving FBP volume as an hdf5/float16")
        logger.info("#" * 48)
        t1 = timeit.default_timer()
        import hdf5plugin
        import h5py as h5
        # NOTE(review): unlike the other outputs there is no '_' after the prefix,
        # kept as is since other tools may rely on this name - confirm.
        filename = prefix_output + "pynx_vol.h5"
        # Use a context manager so the file is closed even if writing fails
        with h5.File(filename, "w") as f:
            f.attrs['creator'] = 'PyNX'
            # f.attrs['NeXus_version'] = '2018.5'  # Should only be used when the NeXus API has written the file
            f.attrs['HDF5_Version'] = h5.version.hdf5_version
            f.attrs['h5py_version'] = h5.version.version
            f.attrs['default'] = 'entry_1'

            entry_1 = f.create_group("entry_1")
            entry_1.attrs['NX_class'] = 'NXentry'
            entry_1.attrs['default'] = 'data_1'
            data_1 = entry_1.create_group("data_1")
            data_1.attrs['NX_class'] = 'NXdata'
            data_1.attrs['signal'] = 'data'
            data_1.attrs['interpretation'] = 'image'
            data_1['title'] = 'PyNX (FBP:Nabu)'
            # One chunk per slice
            data_1.create_dataset("data", data=vol.astype(np.float16), chunks=(1,) + vol.shape[1:],
                                  shuffle=True, **hdf5plugin.LZ4())
        logger.info("Finished saving %s" % filename)

        dt = timeit.default_timer() - t1
        logger.info("Time to save hdf5 volume:  %4.1fs" % dt)

    dt = timeit.default_timer() - t0
    logger.info("Time for 3D FBP & save volume:  %4.1fs" % dt)


def worker_params_id16(qin: Queue, qout: Queue):
    """
    Worker getting the ID16 parameters. For each dictionary of keyword arguments
    read from the input queue, the result of get_params_id16a() or
    get_params_id16b() is put in the output queue along with the log,
    until 'STOP' is received.

    :param qin: the input queue, yielding dictionaries of keyword arguments or 'STOP'
    :param qout: the output queue, receiving (result, log) tuples
    """
    # The import is done here only to avoid imports triggering cuda init.
    # The id16a or id16b version is selected from the name of the script.
    if 'id16a' in sys.argv[0]:
        from pynx.holotomo.utils import get_params_id16a as get_params_id16
    else:
        from pynx.holotomo.utils import get_params_id16b as get_params_id16
    while True:
        task = qin.get()
        if task == 'STOP':
            break
        logger, stream = new_logger_stream("ID16Params")
        # The log is read after the parameters have been computed
        qout.put((get_params_id16(**task, logger=logger), stream.getvalue()))


def worker_load_align_zoom(qin, qout, iobs_array):
    """
    Worker loading & aligning the images. For each dictionary of keyword arguments
    read from the input queue, load_align_zoom() is called with the shared iobs
    array, and (ref, dx, dy, log) is put in the output queue, until 'STOP' is received.

    :param qin: the input queue, yielding dictionaries of keyword arguments or 'STOP'
    :param qout: the output queue, receiving (ref, dx, dy, log) tuples
    :param iobs_array: the shared array passed to load_align_zoom()
    """
    while True:
        task = qin.get()
        if task == 'STOP':
            break
        logger, stream = new_logger_stream("Load-Align-Zoom")
        # The shared iobs_array is passed to load_align_zoom, so the
        # observed intensities are not sent through the queue.
        ref, dx, dy = load_align_zoom(**task, iobs_array=iobs_array, logger=logger)
        log = stream.getvalue()
        qout.put((ref, dx, dy, log))


def worker_find_cor(qin, qout, ph_array):
    """
    Worker searching for the centre of rotation. For each dictionary of keyword
    arguments read from the input queue, find_cor() is called with the shared
    phase array, and (CoR, log) is put in the output queue, until 'STOP' is received.

    :param qin: the input queue, yielding dictionaries of keyword arguments or 'STOP'
    :param qout: the output queue, receiving (centre of rotation, log) tuples
    :param ph_array: the shared array with the phases, passed to find_cor()
    """
    while True:
        task = qin.get()
        if task == 'STOP':
            break
        logger, stream = new_logger_stream("CORFinder")
        logger.info("Searching for centre of rotation")
        cor = find_cor(**task, logger=logger, ph_array=ph_array)
        logger.info("Returning CoR = %8.2f" % cor)
        qout.put((cor, stream.getvalue()))


def worker_algorithms(qin, qout, iobs_array, ph_array, profile, igpu):
    """
    Worker running the phase retrieval algorithms on one GPU. For each dictionary
    read from the input queue (until 'STOP' is received), the HoloTomoData object is
    prepared from the shared iobs array, then the algorithms are run and the unwrapped
    phases are copied to the shared phase array.

    Two messages are put in the output queue for each chunk: first the string
    "Finished preparing HolotomoData", then a tuple (idx, probe, log).

    If a cuda MemoryError occurs, the chunk is tried again once with a lower
    GPU memory usage. A second MemoryError (for any chunk) is raised.

    :param qin: the input queue, yielding dictionaries or 'STOP'. The dictionary keys
        used are: ichunk, nb_chunk, iobs_shape, ref, pixel_size, wavelength,
        detector_distance, dx, dy, proj_idx, padding, algorithm, delta_beta, ph_shape,
        i0, obj_smooth, obj_inertia, obj_min, obj_max, nb_probe, sino_filter, return_probe
    :param qout: the output queue
    :param iobs_array: the shared array with the observed intensities
    :param ph_array: the shared array receiving the unwrapped phases
    :param profile: if True, enable profiling in the processing unit
    :param igpu: the GPU index. If CUDA_VISIBLE_DEVICES is set, this is the index
        in that comma-separated list, otherwise it is used directly.
    """
    # Restrict the visible devices to the selected GPU, before the GPU-related imports
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        os.environ['CUDA_VISIBLE_DEVICES'] = "%d" % igpu
    else:
        vigpu = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        os.environ['CUDA_VISIBLE_DEVICES'] = vigpu[igpu]
    from pynx.holotomo.operator import ScaleObjProbe
    from pycuda.driver import MemoryError as cuMemoryError
    pu = ScaleObjProbe().processing_unit
    pu.enable_profiling(profiling=profile)
    gpu_mem_f = 1  # use this to lower GPU memory requirements if we hit a MemoryError
    # These objects are kept and re-used from one chunk to the next
    data, ht = None, None
    for kwargs in iter(qin.get, 'STOP'):
        logger, stream = new_logger_stream("Algorithms%d" % igpu)
        logger.info("#" * 48)
        logger.info("# Algorithms for chunk %d/%d:  [GPU #%d]" % (kwargs["ichunk"] + 1, kwargs['nb_chunk'], igpu))
        logger.info("#" * 48)
        pu.set_logger(logger)
        todo = True
        while todo:
            try:
                logger.info("Preparing HolotomoData")
                data = prep_holotomodata(iobs_array=iobs_array, iobs_shape=kwargs['iobs_shape'], ref=kwargs['ref'],
                                         pixel_size=kwargs['pixel_size'], wavelength=kwargs['wavelength'],
                                         detector_distance=kwargs['detector_distance'],
                                         dx=kwargs['dx'], dy=kwargs['dy'], proj_idx=kwargs['proj_idx'],
                                         padding=kwargs['padding'], algorithm=kwargs['algorithm'],
                                         gpu_mem_f=gpu_mem_f, pu=pu, logger=logger, data_old=data)
                # NOTE(review): if the MemoryError occurs in run_algorithms() below, this
                # message is sent a second time for the same chunk during the new attempt,
                # and iobs may already have been replaced - confirm how the master handles this.
                qout.put("Finished preparing HolotomoData")  # This frees iobs to load the next dataset.
                logger.info("Beginning algorithms")
                idx, probe, ht = run_algorithms(data, iobs_shape=kwargs['iobs_shape'], padding=kwargs['padding'],
                                                delta_beta=kwargs['delta_beta'], algorithm=kwargs['algorithm'],
                                                ph_shape=kwargs['ph_shape'], i0=kwargs['i0'],
                                                nb_chunk=kwargs['nb_chunk'],
                                                obj_smooth=kwargs['obj_smooth'], obj_inertia=kwargs['obj_inertia'],
                                                obj_min=kwargs['obj_min'], obj_max=kwargs['obj_max'],
                                                nb_probe=kwargs['nb_probe'], ph_array=ph_array,
                                                sino_filter=kwargs['sino_filter'], pu=pu, logger=logger,
                                                old_ht=ht, return_probe=kwargs['return_probe'])
                todo = False
            except cuMemoryError:
                # Release both the data and the HoloTomo object: the latter keeps a
                # reference to the data, and run_algorithms() ignores its 'data' argument
                # when old_ht is supplied. Keeping 'ht' would neither free the memory
                # nor use the new data created with a smaller stack size.
                data, ht = None, None
                gc.collect()
                gc.collect()
                if gpu_mem_f > 1:
                    logger.warning("cuda MemoryError - again - giving up [GPU#%d]" % igpu)
                    raise
                logger.warning("cuda MemoryError - try again, lowering GPU memory usage [GPU#%d]" % igpu)
                os.system('nvidia-smi')
                gpu_mem_f = 2
        logger.info("Finished algorithms for chunk %d/%d: sending back unwrapped phases and probe" %
                    (kwargs['ichunk'] + 1, kwargs['nb_chunk']))
        qout.put((idx, probe, stream.getvalue()))
    if profile:
        pu.enable_profiling(False)
    del data, ht
    gc.collect()
    gc.collect()


def worker_save_projections_hdf5(qin: JoinableQueue, ph_array):
    """
    Worker saving the projections to hdf5. For each dictionary of keyword arguments
    read from the queue, save_obj_probe_chunk() is called with the shared phase
    array as obj_phase, until 'STOP' is received. Every item read from the
    queue (including 'STOP') is acknowledged with task_done().

    :param qin: the joinable input queue, yielding dictionaries of keyword arguments or 'STOP'
    :param ph_array: the shared array with the phases
    """
    from pynx.holotomo import save_obj_probe_chunk
    while True:
        task = qin.get()
        if task == 'STOP':
            # Also acknowledge the final 'STOP'
            qin.task_done()
            break
        save_obj_probe_chunk(**task, obj_phase=ph_array)
        qin.task_done()


def worker_save_projections_edf(qin: JoinableQueue, ph_array):
    """
    Worker saving the projections as edf files. For each dictionary read from the
    queue (until 'STOP' is received), all the projections are saved in parallel
    using save_phase_edf_kw(). Every item read from the queue (including 'STOP')
    is acknowledged with task_done().

    :param qin: the joinable input queue, yielding 'STOP' or dictionaries with the keys:
        'idx' (indices of the projections), 'prefix' (passed as prefix_output)
        and 'ph_shape' (shape of the shared phase array)
    :param ph_array: the shared array with the phases (float16)
    """
    from pynx.holotomo.utils import save_phase_edf_kw
    for kwargs in iter(qin.get, 'STOP'):
        idx, prefix = kwargs['idx'], kwargs['prefix']
        # The shared phase array stores float16 values (as in run_algorithms() and
        # run_fbp()): viewing it as float32 would not match ph_shape.
        ph = np.frombuffer(ph_array.get_obj(), dtype=np.float16).reshape(kwargs['ph_shape'])
        # Each projection is converted to float32 for saving
        vkw = [{'idx': k, 'ph': ph[i].astype(np.float32), 'prefix_output': prefix} for i, k in enumerate(idx)]
        with Pool(nproc) as pool:
            pool.map(save_phase_edf_kw, vkw)
        qin.task_done()
    # Also acknowledge the final 'STOP'
    qin.task_done()


def worker_fbp(qin, qout, ph_array, profile):
    """
    Worker for the filtered back-projection CT with Nabu. For each dictionary of
    keyword arguments read from the input queue, run_fbp() is called and its log
    is put in the output queue, until 'STOP' is received.

    :param qin: the input queue, yielding dictionaries of keyword arguments or 'STOP'
    :param qout: the output queue, receiving the log of each reconstruction
    :param ph_array: the shared array with the phases, passed to run_fbp()
    :param profile: passed to run_fbp()
    """
    while True:
        task = qin.get()
        if task == 'STOP':
            break
        logger, stream = new_logger_stream("Tomo (Nabu FBP)")
        run_fbp(**task, ph_array=ph_array, logger=logger, profile=profile)
        log = stream.getvalue()
        qout.put(log)


def master(params):
    """Run the holo-tomography reconstruction from the master process.

    This logs the environment and parameters, selects the projections, determines
    the padding and the number of chunks, allocates the shared memory arrays, starts
    the worker processes (ID16 parameters, loading/alignment, algorithms - one
    process per GPU -, centre of rotation, hdf5/edf saving, FBP) and exchanges jobs
    and results with them through queues. The logs returned by the workers are
    written to "<prefix_output>.log".

    If `params.holoct` is set, only the volume reconstruction from existing holoct
    projections is performed, and the process exits through `sys.exit(0)`.

    :param params: the object holding the processing parameters as attributes (they
        are listed using `vars(params)`). Its attributes `data`, `prefix_output`,
        `align` and `tomo_cor_method` may be modified, and `padding_real` is added.
    :raises RuntimeError: if CUDA_VISIBLE_DEVICES lists fewer devices than
        `params.ngpu`, or if the first holoct projection file cannot be found.
    """
    from tomoscan.esrf.edfscan import EDFTomoScan
    # Setup logging
    DeltaTimeFormatter.start_time = time.time()
    logger, stream = new_logger_stream("Master")
    logger.info("Host: %s - %s" % (platform.node(), platform.platform()))
    logger.info("Working directory: %s" % os.getcwd())
    try:
        logger.info("User: %s " % os.getlogin())
    except OSError:
        # When using slurm, this is not running in a shell with user info
        if 'USER' in os.environ:
            logger.info("User: %s " % os.environ['USER'])
    for k in ['SLURM_GPUS_ON_NODE', 'SLURM_JOB_ACCOUNT', 'SLURM_JOB_GPUS', 'SLURM_JOB_ID',
              'SLURM_JOB_NAME', 'SLURM_JOB_PARTITION', 'SLURM_CPUS_ON_NODE', 'SLURM_NTASKS',
              'SLURM_MEM_PER_NODE']:
        if k in os.environ:
            logger.info("%s: %s " % (k, os.environ[k]))
    logger.info("#" * 48)
    logger.info("# Parsing parameters")
    logger.info("#" * 48)

    logger.info("Command: %s" % " ".join(sys.argv))
    logger.info("Parameters:")
    vp = vars(params)
    params_dic = {}  # Later, for export
    for k in sorted(vp.keys()):
        v = vp[k]
        logger.info("     %s: %s" % (k, str(v)))
        params_dic[k] = v

    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        ngpu_env = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
        if ngpu_env < params.ngpu:
            raise RuntimeError("ngpu=%d have been required but CUDA_VISIBLE_DEVICES=%s" %
                               (params.ngpu, os.environ['CUDA_VISIBLE_DEVICES']))

    # Analysing data name
    if params.data[-1] == "/":
        params.data = params.data[:-1]
    data_dir, prefix = os.path.split(params.data)
    # iz is not used further in this function, but the conversion raises a ValueError
    # if the second-to-last character of the dataset name is not a digit.
    iz = int(prefix[-2])
    prefix = prefix[:-2]  # Remove the trailing '1_'

    if params.prefix_output is None:
        # Use the first 'resultNN' prefix for which no file exists in the working directory
        i = 1
        while True:
            if params.nz > 1:
                params.prefix_output = prefix + "result%02d" % i
            else:
                params.prefix_output = prefix + "%d_result%02d" % (params.planes[0], i)
            if len([fn for fn in os.listdir('.') if fn.startswith(params.prefix_output)]) == 0:
                break
            i += 1
        logger.info('No prefix_output given, will use: %s' % params.prefix_output)

    # Create logfile
    path = os.path.split(params.prefix_output)[0]
    if len(path):
        os.makedirs(path, exist_ok=True)
    logfile = open("%s.log" % params.prefix_output, 'w')

    # if params.step is None:
    #     params.step = params.binning
    if params.last is None:
        # TODO: why does nb_proj from get_tomo_n differ from the number of keys ?
        nb_proj = EDFTomoScan.get_tomo_n(params.data)
        logger.info("No projection_range given - using all %d projections" % nb_proj)
        proj_idx = list(EDFTomoScan.get_proj_urls(params.data, n_frames=1).keys())
        proj_idx.sort()
        proj_idx = proj_idx[:nb_proj]  # Remove extra projections for alignment.
        proj_idx = np.array(proj_idx[params.first::params.binning])
        logger.info("Projections: [%d, %d ...%d]" % (proj_idx[0], proj_idx[1], proj_idx[-1]))
    else:
        logger.info("projection_range: [%d-%d::%d]" % (params.first, params.last, params.binning))
        proj_idx = np.arange(params.first, params.last + 1, params.binning)
    nb_proj = len(proj_idx)  # number of images loaded (excluding dark and empty_beam images)
    logger.info("nb_proj=%d" % nb_proj)

    # From now on a new logger stream is requested with flush_to=logfile each time
    # a worker log is about to be appended to the log file.
    logger, stream = new_logger_stream("Master", flush_to=logfile)

    # Gather magnified pixel sizes and propagation distances
    # This is done in a separate process to avoid imports triggering cuda init
    # ...though this should not be necessary !
    kwargs = {"scan_path": params.data, "planes": params.planes, "sx0": params.sx0, "cx": params.cx, "sx": params.sx,
              "pixel_size_detector": params.pixel_size_detector}

    queue_params_id16_in = Queue()
    queue_params_id16_out = Queue()
    Process(target=worker_params_id16, args=(queue_params_id16_in, queue_params_id16_out)).start()
    queue_params_id16_in.put(kwargs)
    params_id16, log = queue_params_id16_out.get()
    queue_params_id16_in.put('STOP')

    logfile.write(log)
    logfile.flush()
    ny, nx = params_id16['ny'] // params.binning, params_id16['nx'] // params.binning
    nz = params.nz

    if params.nrj is None:
        wavelength = params_id16['wavelength']
    else:
        # The wavelength is handled in metres below (see the two log lines), so the
        # conversion constant must be 12.3984e-10 (assuming params.nrj is in keV).
        wavelength = 12.3984e-10 / params.nrj
    logger.info("Energy: %6.2f keV" % (12.3984e-10 / wavelength))
    logger.info("Wavelength: %8.6e m" % wavelength)

    if params.distance is None:
        detector_distance = params_id16['distance']
    else:
        detector_distance = np.array(params.distance)

    if params.pixel_size is None:
        pixel_size_z = params_id16['pixel_size']
    else:
        pixel_size_z = np.array(params.pixel_size)
    pixel_size_z *= params.binning
    pixel_size = pixel_size_z[0]

    magnification = pixel_size_z[0] / pixel_size_z

    logger.info("Detector distances [m]: %s" % str(detector_distance))
    logger.info("Pixel sizes [nm]: %s" % str(pixel_size_z * 1e9))

    # Padding:
    if params.padding == 2:
        pady, padx = ny // 2, nx // 2
        logger.info("Padding==2 => using pady=%d  padx=%d" % (pady, padx))
    else:
        pady = padx = params.padding

    ################################################################
    # Test for radix transforms
    ################################################################
    # Test if we have a radix transform: the prime factors reported are those
    # of the padded dimensions, which are the sizes actually transformed.
    primesy, primesx = primes(ny + 2 * pady), primes(nx + 2 * padx)
    logger.info("Testing dimensions prime decomposition: ny = %d = %s  nx= %d = %s" %
                (ny + 2 * pady, str(primesy),
                 nx + 2 * padx, str(primesx)))
    if max(primesy) > 13:
        padup = pady
        while not test_smaller_primes(ny + 2 * padup, required_dividers=[2]):
            padup += 1
        paddown = pady
        while not test_smaller_primes(ny + 2 * paddown, required_dividers=[2]):
            paddown -= 1
        s = "The Y dimension (with padding=%d,%d) is incompatible with a radix FFT:\n" \
            "  ny=%d primes=%s  (should be <=13)\n" \
            "  Closest acceptable padding values: %d or %d" % \
            (pady, padx, ny + 2 * pady, str(primesy), paddown, padup)
        logger.warning(s)
        # Use closest padding option, unless we pad to 2x the array size, in which case
        # we can only go down
        if (padup - pady < pady - paddown) and (params.padding != 2):
            pady = padup
        else:
            pady = paddown
        logger.info("Changing Y padding to closest radix-compatible value: %d" % pady)

    if max(primesx) > 13:
        padup = padx
        while not test_smaller_primes(nx + 2 * padup, required_dividers=[2]):
            padup += 1
        # The downward search must start from the X padding (not the Y one, which
        # may already have been changed above)
        paddown = padx
        while not test_smaller_primes(nx + 2 * paddown, required_dividers=[2]):
            paddown -= 1
        s = "The X dimension (with padding=%d) is incompatible with a radix FFT:\n" \
            "  nx=%d primes=%s  (should be <=13)\n" \
            "  Closest acceptable padding values: %d or %d" % \
            (padx, nx + 2 * padx, str(primesx), paddown, padup)
        logger.warning(s)
        # Use closest padding option, unless we pad to 2x the array size, in which case
        # we can only go down
        if (padup - padx < padx - paddown) and (params.padding != 2):
            padx = padup
        else:
            padx = paddown
        logger.info("Changing X padding to closest radix-compatible value: %d" % padx)
    # Add real padding to params
    setattr(params, 'padding_real', (pady, padx))
    # logger, stream = new_logger_stream("Master", flush_to=logfile)
    # raise RuntimeError(s)

    ################################################################
    # Estimate memory requirements & number of chunks
    ################################################################
    # Number of pixels of one padded frame
    npix_pad = (ny + 2 * pady) * (nx + 2 * padx)
    mem_req = nz * npix_pad * 4
    mem_req += 12 * npix_pad  # object projections & phase
    if True:  # 'RAAR' in algorithm or 'DM' in algorithm or 'DRAP' in algorithm:
        mem_req += 8 * nz * npix_pad
    logger.info("Estimated memory requirement: %8.4fGB/projection (total = %8.4fGB)" %
                (mem_req / 1024 ** 3, mem_req * nb_proj / 1024 ** 3))
    # Maximum memory used for a single chunk of projections (not counting final result)
    max_mem = params.max_mem_chunk * 1024 ** 3
    # The number of chunks is a multiple of the number of GPUs
    nb_chunk = int(np.ceil(nb_proj * mem_req / max_mem / params.ngpu)) * params.ngpu
    if nb_chunk == 1:
        nb_chunk = 2  # This allows for the CoR to be found while the phasing finishes
    nb_proj_chunk = len(proj_idx[::nb_chunk])
    logger.info("Using %d chunk(s) [%d projections/chunk]" % (nb_chunk, nb_proj_chunk))

    ################################################################
    # Create shared memory arrays before forking process
    # NB: we use multiprocessing.Array rather than shared_memory
    # as the latter runs into issues (Bus error) for large sizes (around 32GB)
    # on power9 (bug ?)
    ################################################################
    logger.info("#" * 48)
    logger.info("# Creating shared memory arrays")
    logger.info("#" * 48)
    t0 = timeit.default_timer()

    # Create shared memory array for phases
    ph_shape = (nb_proj, ny, nx)
    logger.info("Phase array shape: %s [%6.2fGB]" % (str(ph_shape), np.prod(ph_shape) * 2 / 1024 ** 3))
    ph_array_dtype = np.float16
    ph_array = Array(ctypes.c_int16, int(np.prod(ph_shape)))  # There is no ctypes.c_float16
    ph = np.frombuffer(ph_array.get_obj(), dtype=ph_array_dtype).reshape(ph_shape)

    ################################################################
    # Setup job queues & worker for center of rotation
    ################################################################
    # Center of rotation
    queue_cor_in = Queue()
    queue_cor_out = Queue()
    p_cor = Process(target=worker_find_cor, args=(queue_cor_in, queue_cor_out, ph_array))
    p_cor.start()
    tomo_rot_center = params.tomo_rot_center

    ################################################################
    # Optionally - reconstruct volume from holoct projections,
    # and exit
    ################################################################
    if params.holoct is not None:
        # Export same files for holoct files, assuming we can find them
        # FBP
        queue_fbp_in = Queue()
        queue_fbp_out = Queue()
        p_fbp = Process(target=worker_fbp, args=(queue_fbp_in, queue_fbp_out, ph_array, params.profile))
        p_fbp.start()
        logger.info("Processing volume from holoct reconstructed projections...")
        from ..utils import load_holoct_data_pool_init, load_holoct_data_kw
        import fabio
        # Check we can find holoct files and load them into the phase array
        if not os.path.exists(params.holoct % proj_idx[0]):
            raise RuntimeError("HoloCT volume calculation: could not find %s" % (params.holoct % proj_idx[0]))
        # Use a fake dark array to re-use loading functions
        vkw = [{'i': i, 'idx': proj_idx[i], 'nproj': nb_proj, 'nx': nx, 'ny': ny,
                'img_name': params.holoct, 'binning': params.binning,
                'ph_array_dtype': ph_array_dtype}
               for i in range(nb_proj)]
        logger.info("Loading holoct reconstructed projections...")
        t0 = timeit.default_timer()
        with Pool(nproc, initializer=load_holoct_data_pool_init, initargs=(ph_array,)) as pool:
            pool.map(load_holoct_data_kw, vkw)
        dt = timeit.default_timer() - t0
        logger.info("Loading holoct reconstructed projections... Finished [%6.1fMB/s]"
                    % (np.prod(ph_shape) * 2 * params.binning ** 2 / 1024 ** 2 / dt))
        # Check if a sino filter needs to be applied
        edf = fabio.open(params.holoct % proj_idx[0])
        sino_filter = 'ram-lak'
        if 'post_filter' in edf.header:
            if 'sino' in edf.header['post_filter']:
                sino_filter = None

        if tomo_rot_center is None:
            logger.info("Evaluating rotation centre using: %s (method=%s)" % (params.data, params.tomo_cor_method))
            queue_cor_in.put({"data_src": params.data, "prefix": prefix, "binning": params.binning,
                              "reference_plane": params.reference_plane, "method": params.tomo_cor_method,
                              "motion": False,
                              "ph_shape": ph_shape, "ph_array_dtype": ph_array_dtype, "proj_idx": proj_idx})

            tomo_rot_center, log = queue_cor_out.get()
            logger.info("Got back CoR: %6.1f" % tomo_rot_center)
            logger, stream = new_logger_stream("Master", flush_to=logfile)
            logfile.write(log)
            logfile.flush()
        else:
            params.tomo_cor_method = "manual"

        # Perform FBP exactly as before using holoCT data
        logger.info("Launching FBP reconstruction from holoct reconstructed projections...")
        kwargs = {"data_src": params.data, "prefix_output": params.prefix_output + "_holoct",
                  "ph_shape": ph_shape, "idx": proj_idx,
                  "tomo_rot_center": tomo_rot_center, "pixel_size": pixel_size,
                  "save_fbp_vol": params.save_fbp_vol, "save_3sino": params.save_3sino,
                  "ngpu": params.ngpu, "params": params, "sino_filter": sino_filter}
        queue_fbp_in.put(kwargs)
        log = queue_fbp_out.get()
        logger, stream = new_logger_stream("Master", flush_to=logfile)
        logfile.write(log)
        logfile.flush()
        # Flush logger
        logger, stream = new_logger_stream("Master", flush_to=logfile)
        logger.info("Finished FBP reconstruction from holoct reconstructed projections - exiting")
        for p, q in [(p_cor, queue_cor_in), (p_fbp, queue_fbp_in)]:
            q.put('STOP')
            p.join()
        sys.exit(0)

    # Create shared memory array for iobs. Its size is given by the first chunk
    # (proj_idx[0::nb_chunk]), which has at least as many projections as any other.
    proj_idx_chunk = proj_idx[::nb_chunk]
    iobs_shape = (len(proj_idx_chunk), nz, ny + 2 * pady, nx + 2 * padx)
    logger.info("Iobs shape: %s [%6.2fGB]" % (str(iobs_shape), np.prod(iobs_shape) * 4 / 1024 ** 3))
    iobs_array = Array(ctypes.c_float, int(np.prod(iobs_shape)))

    dt = timeit.default_timer() - t0
    logger.info("Finished creating shared memory arrays: dt = %6.2fs" % dt)
    ################################################################
    # Setup job other queues & workers
    ################################################################
    queue_load_align_in = Queue()
    queue_load_align_out = Queue()
    p_load_align = Process(target=worker_load_align_zoom,
                           args=(queue_load_align_in, queue_load_align_out, iobs_array))
    p_load_align.start()

    # Algorithms for multiple GPUs
    queue_algo_in, queue_algo_out, p_algo = [], [], []
    for igpu in range(params.ngpu):
        queue_algo_in.append(Queue())
        queue_algo_out.append(Queue())
        p_algo.append(Process(target=worker_algorithms, args=(queue_algo_in[-1], queue_algo_out[-1], iobs_array,
                                                              ph_array, params.profile, igpu)))
        p_algo[-1].start()

    # Save projections to hdf5 file
    queue_save_projections_hdf5_in = JoinableQueue()
    p_phases2hdf5 = Process(target=worker_save_projections_hdf5, args=(queue_save_projections_hdf5_in, ph_array))
    p_phases2hdf5.start()

    # Save projections to edf files
    queue_save_projections_edf_in = JoinableQueue()
    p_phases2edf = Process(target=worker_save_projections_edf, args=(queue_save_projections_edf_in, ph_array))
    p_phases2edf.start()

    # FBP
    queue_fbp_in = Queue()
    queue_fbp_out = Queue()
    p_fbp = Process(target=worker_fbp, args=(queue_fbp_in, queue_fbp_out, ph_array, params.profile))
    p_fbp.start()

    ################################################################
    # Distribute tasks for chunks
    ################################################################
    # Distribute chunks so that all even frames are done first, and a CoR search
    # can be performed using that half sinogram
    v_i0 = list(range(0, nb_chunk, 2))
    if nb_chunk > 1:
        v_i0 += list(range(1, nb_chunk, 2))
    # Launch first input job: proj_idx_chunk and iobs_shape computed above
    # already correspond to the first chunk (v_i0[0] == 0)

    # Choose default alignment method if not set
    if params.align is None:
        if params.rhapp is None:
            params.align = 'fft_fit'
        else:
            params.align = 'rhapp'

    # NOTE(review): this first request includes a "params" entry which the requests
    # for the following chunks (in the loop below) do not - confirm this is intended.
    kwargs = {"ichunk": 0, "data_src": params.data, "prefix": prefix, "proj_idx": proj_idx_chunk,
              "planes": params.planes, "binning": params.binning, "magnification": magnification,
              "padding": (pady, padx), "pad_method": params.pad_method, "rhapp": params.rhapp,
              "reference_plane": params.reference_plane, "spikes_threshold": params.remove_spikes,
              "align_method": params.align, "align_interval": params.align_interval,
              "align_fourier_low_cutoff": params.align_fourier_low_cutoff,
              "prefix_output": params.prefix_output, "nb_chunk": nb_chunk,
              "motion": params.motion, "params": params}
    logger.info("Requesting Iobs data for chunk #1/%d" % nb_chunk)
    queue_load_align_in.put(kwargs)

    # Request sent to the centre of rotation worker (only used if no centre was supplied)
    cor_kwargs = {"data_src": params.data, "prefix": prefix, "binning": params.binning,
                  "reference_plane": params.reference_plane,
                  "method": params.tomo_cor_method, "motion": params.motion,
                  "ph_shape": ph_shape, "ph_array_dtype": ph_array_dtype, "proj_idx": proj_idx}
    # Keep track of the request, so it can still be sent after the loops if needed
    cor_requested = False

    for ichunk in range(nb_chunk):
        i0 = v_i0[ichunk]
        # Projections for this chunk
        proj_idx_chunk = proj_idx[i0::nb_chunk]
        # Wait for data (iobs comes from shared memory)
        ref, dx, dy, log = queue_load_align_out.get()
        # GPU for this chunk
        igpu = ichunk % params.ngpu

        logger, stream = new_logger_stream("Master", flush_to=logfile)
        logfile.write(log)
        logfile.flush()
        logger.info("Received Iobs and reference frames for chunk #%d/%d" % (ichunk + 1, nb_chunk))
        # Start algorithms
        kwargs = {"ichunk": ichunk, "iobs_shape": iobs_shape, "ref": ref,
                  "pixel_size": pixel_size, "wavelength": wavelength,
                  "detector_distance": detector_distance, "dx": dx, "dy": dy, "proj_idx": proj_idx_chunk,
                  "padding": (pady, padx), "delta_beta": params.delta_beta, "algorithm": params.algorithm,
                  "ph_shape": ph_shape, "i0": i0, "nb_chunk": nb_chunk,
                  "obj_smooth": params.obj_smooth, "obj_inertia": params.obj_inertia,
                  "obj_min": params.obj_min, "obj_max": params.obj_max, "nb_probe": params.nb_probe,
                  "sino_filter": params.sino_filter if padx else None,
                  "return_probe": params.save_phase_chunks}
        logger.info("Launching algorithms for chunk #%d/%d" % (ichunk + 1, nb_chunk))
        queue_algo_in[igpu].put(kwargs)
        # Get first message that HolotomoData is ready and the next set of data can be read
        queue_algo_out[igpu].get()
        # Read next set of data now that iobs data is free
        if ichunk < (nb_chunk - 1):
            proj_idx_chunk_next = proj_idx[v_i0[ichunk + 1]::nb_chunk]
            iobs_shape = (len(proj_idx_chunk_next), nz, ny + 2 * pady, nx + 2 * padx)

            kwargs = {"ichunk": ichunk + 1, "data_src": params.data, "prefix": prefix, "proj_idx": proj_idx_chunk_next,
                      "planes": params.planes, "binning": params.binning, "magnification": magnification,
                      "padding": (pady, padx), "pad_method": params.pad_method, "rhapp": params.rhapp,
                      "reference_plane": params.reference_plane, "spikes_threshold": params.remove_spikes,
                      "align_method": params.align, "align_interval": params.align_interval,
                      "align_fourier_low_cutoff": params.align_fourier_low_cutoff,
                      "prefix_output": params.prefix_output, "nb_chunk": nb_chunk,
                      "motion": params.motion}
            logger.info("Requesting Iobs data for chunk #%d/%d" % (ichunk + 2, nb_chunk))
            queue_load_align_in.put(kwargs)
        if (ichunk + 1) >= params.ngpu:
            # Get results (phases are in shared memory) for chunk ichunk - ngpu +1
            idx, probe, log = queue_algo_out[(ichunk - params.ngpu + 1) % params.ngpu].get()

            expected_idx = proj_idx[v_i0[ichunk - params.ngpu + 1]::nb_chunk]
            if not np.allclose(idx, expected_idx):
                print("Whoops: idx != proj_idx_chunk:")
                print(idx)
                print(expected_idx)
            logger, stream = new_logger_stream("Master", flush_to=logfile)
            logfile.write(log)
            logfile.flush()
            logger.info("Received unwrapped phases and probe for chunk %d/%d" %
                        (ichunk + 1 - params.ngpu + 1, nb_chunk))
            if tomo_rot_center is None:
                # Launch the CoR search once half of the chunks have been received
                if (ichunk - params.ngpu + 2 == nb_chunk // 2) or (nb_chunk == 1):
                    logger.info("Evaluating rotation centre using: %s (method=%s)" %
                                (params.data, params.tomo_cor_method))
                    queue_cor_in.put(cor_kwargs)
                    cor_requested = True
            else:
                params.tomo_cor_method = "manual"

    # Get results for the last (ngpu-1) chunks
    if params.ngpu > 1:
        for ichunk in range(nb_chunk - params.ngpu + 1, nb_chunk):
            idx, probe, log = queue_algo_out[ichunk % params.ngpu].get()

            expected_idx = proj_idx[v_i0[ichunk]::nb_chunk]
            if not np.allclose(idx, expected_idx):
                print("Whoops: idx != proj_idx_chunk:")
                print(idx)
                print(expected_idx)
            logger, stream = new_logger_stream("Master", flush_to=logfile)
            logfile.write(log)
            logfile.flush()
            logger.info("Received unwrapped phases and probe for chunk %d/%d" % (ichunk + 1, nb_chunk))

    if tomo_rot_center is None and not cor_requested:
        # The half-way condition in the main loop is never met when the number of
        # chunks received inside that loop is smaller than nb_chunk // 2 (e.g.
        # nb_chunk == ngpu >= 4): send the request now, otherwise waiting for the
        # result below would block forever.
        logger.info("Evaluating rotation centre using: %s (method=%s)" %
                    (params.data, params.tomo_cor_method))
        queue_cor_in.put(cor_kwargs)
        cor_requested = True

    tmp = ",".join(["%d" % p.pid for p in p_algo])
    logger.info("Joining load[%d] & algorithm[%s] processes" % (p_load_align.pid, tmp))
    for p, q in [(p_load_align, queue_load_align_in)] + [(p_algo[i], queue_algo_in[i]) for i in range(params.ngpu)]:
        q.put('STOP')
        p.join()

    # Free some shared memory before FBP
    del iobs_array
    gc.collect()

    # Save projections
    if params.save_phase_chunks:
        fname = params.prefix_output + "_i0=%04d.h5" % proj_idx[0]
        logger.info("Saving holotomo projections to: %s" % fname)
        # NOTE(review): idx and probe are those returned for the last chunk received,
        # whereas obj_phase_shape describes the full phase array - confirm that
        # save_obj_probe_chunk expects this rather than proj_idx.
        kwargs = {"filename": fname, "idx": idx, "pixel_size": pixel_size,
                  "obj_phase_shape": ph_shape,
                  "probe": probe, "process_parameters": params_dic}
        queue_save_projections_hdf5_in.put(kwargs)

    if params.save_edf:
        logger.info("#" * 48)
        logger.info(" Saving phased images to edf files: " + params.prefix_output + "_%04d.edf")
        logger.info("#" * 48)
        # The edf worker pairs idx[i] with the i-th frame of the full phase array,
        # so all projection indices are passed, along with the dtype of the array.
        kwargs = {"idx": proj_idx, "ph_shape": ph_shape, "prefix": params.prefix_output,
                  "ph_array_dtype": ph_array_dtype}
        queue_save_projections_edf_in.put(kwargs)

    # Get back rotation centre if needed
    if params.tomo_rot_center is None:
        logger.info("Getting back CoR estimations")
        tomo_rot_center, log = queue_cor_out.get()
        logger.info("Got back CoR: %6.1f" % tomo_rot_center)
        logger, stream = new_logger_stream("Master", flush_to=logfile)
        logfile.write(log)
        logfile.flush()

    # Run FBP & export reconstructed volume
    if params.save_fbp_vol:
        sino_filter = params.sino_filter
        if sino_filter is not None:
            if padx:
                # If sino_filter was supplied and padding is used, then the pre-backprojection
                # filtering was performed when getting the unwrapped phases at the end of the algorithms
                sino_filter = None
        else:
            # No arguments were given, so perform a standard FBP filtering
            sino_filter = "ram-lak"

        kwargs = {"data_src": params.data, "prefix_output": params.prefix_output,
                  "ph_shape": ph_shape, "idx": proj_idx,
                  "tomo_rot_center": tomo_rot_center / params.binning, "pixel_size": pixel_size,
                  "save_fbp_vol": params.save_fbp_vol, "save_3sino": params.save_3sino,
                  "ngpu": params.ngpu, "params": params, "sino_filter": sino_filter}
        queue_fbp_in.put(kwargs)
        log = queue_fbp_out.get()
        logger, stream = new_logger_stream("Master", flush_to=logfile)
        logfile.write(log)
        logfile.flush()

    logger.info("Waiting for all parallel processes to stop...")
    for p, q in [(p_cor, queue_cor_in), (p_fbp, queue_fbp_in),
                 (p_phases2hdf5, queue_save_projections_hdf5_in),
                 (p_phases2edf, queue_save_projections_edf_in)]:
        q.put('STOP')
        p.join()
    # Make sure all saving jobs have been acknowledged by the workers
    for q in [queue_save_projections_hdf5_in, queue_save_projections_edf_in]:
        q.join()
    logger.info("            all parallel processes have returned")

    # Flush logger
    logger, stream = new_logger_stream("Master", flush_to=logfile)


def make_parser():
    """Build the command-line parser for the holotomo script.

    The program name shown in the help is chosen from ``sys.argv[0]``,
    which must contain either 'id16a' or 'id16b'. The module docstring is
    used as the description.

    :return: the configured argparse.ArgumentParser
    :raises RuntimeError: if sys.argv[0] contains neither 'id16a' nor 'id16b'
    """
    epilog = """Example usage:

* ID16B, quick (binned x2) 4-distance reconstruction, Paganin, save tiff volume:
   pynx-holotomo-id16b --data data/alcu_25nm_8000adu_1_ --delta_beta 530 --save_fbp_vol tiff \
--algorithm AP**5,AP**10,Paganin --nz 4 --padding 100 --binning 2 --slurm

* ID16B, 4-distance reconstruction, CTF, save tiff volume:
   pynx-holotomo-id16b --data data/alcu_25nm_8000adu_1_ --delta_beta 530 --save_fbp_vol tiff \
--algorithm AP**5,AP**20,CTF --nz 4 --padding 200 --slurm

* ID16B, single distance reconstruction, Paganin, save tiff volume:
   pynx-holotomo-id16b --data data/alcu_25nm_8000adu_1_ --delta_beta 530 --save_fbp_vol tiff \
--algorithm AP**5,AP**20,Paganin --nz 1 --padding 200 --slurm

* ID16A, 4-distance reconstruction, CTF, save tiff volume, random motion:
   pynx-holotomo-id16a --data data/alcu_25nm_8000adu_1_ --delta_beta 530 --save_fbp_vol tiff --motion \
--algorithm AP**5,AP**20,CTF --nz 1 --padding 200 --slurm

* ID16A, 4-distance reconstruction, CTF, save tiff volume, random motion, \
use phased images for center of rotation:
   pynx-holotomo-id16a --data data/alcu_25nm_8000adu_1_ --delta_beta 530 --save_fbp_vol tiff --motion \
--algorithm AP**5,AP**20,CTF --nz 1 --padding 200 --tomo_cor_method phase_registration --slurm
       """
    # The same script is installed under two names, one per beamline
    if 'id16b' in sys.argv[0]:
        prog = "pynx-holotomo-id16b"
    elif 'id16a' in sys.argv[0]:
        prog = "pynx-holotomo-id16a"
    else:
        raise RuntimeError("What program was launched ? Should be pynx-holotomo-id16a or pynx-holotomo-id16b ?")
    parser = argparse.ArgumentParser(prog=prog,
                                     description=__doc__,
                                     epilog=epilog,
                                     formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--data", action='store', required=True,
                        help="Path to the dataset directory, e.g. path/to/sample_1_. "
                             "For nz>1 this should be one of the distances directories.")
    parser.add_argument("--delta_beta", action='store', type=float, required=True,
                        help="delta/beta value for CTF/Paganin")

    group = parser.add_argument_group("Input parameters")
    group.add_argument("--nz", action='store', type=int, default=None,
                       help="Number of distances (planes) for the analysis")
    group.add_argument("--planes", action='store', nargs='+', type=int, default=None,
                       help="Planes to use for the analysis. This is normally set by --nz, "
                            "e.g. '--nz 4' is equivalent to '--planes 1 2 3 4'. "
                            "This option can be used to exclude specific planes e.g. "
                            "the first distance using '--planes 2 3 4'")
    group.add_argument("--binning", action='store', type=int, default=1,
                       help="Binning parameter.")
    # NOTE: 'store' and not 'append': with action='append' and default=0,
    # argparse would try to call (0).append(value) and crash when --first is given.
    group.add_argument("--first", action='store', type=int, default=0,
                       help="Index of the first projection to analyse")
    # 'store' so that the parsed value is a single index rather than a list
    group.add_argument("--last", action='store', type=int, default=None,
                       help="Index of the last (included) projection to analyse. Defaults "
                            "to the last recorded projection, automatically detected.")
    # group.add_argument("--step", action='append', type=int, default=None,
    #                    help="Step to pick the projections to analyse. Defaults to the "
    #                         "binning value. Can also be >1 to distribute projections to "
    #                         "independent jobs.")
    group.add_argument("--remove_spikes", action='store', type=float, default=0,
                       help="Replace any pixel which differs from the median-filtered (3x3) image "
                            "by more than this (relative) value. Typical value 0.04. "
                            "Filtering is done after dark subtraction.")
    group = parser.add_argument_group("Experimental parameters")
    group.add_argument("--nrj", action='store', type=float,
                       help="Energy (keV). By default will be read from the info or edf files.")
    group.add_argument("--distance", action='store', nargs='+', type=float,
                       help="Sample-detector distance(s) (m), after magnification. "
                            "Multiple values should be given if nz>1 using "
                            "'--distance 1e-3 1.2e-3 1.5e-3 1.9e-3' . By default "
                            "will be automatically calculated from .info and/or motor positions.")
    group.add_argument("--pixel_size", action='store', nargs='+', type=float,
                       help="Pixel size(s) (m), for each distance, after magnification. "
                            "Multiple values should be given if nz>1 using "
                            "'--pixel_size 1e-8 1.2e-8 1.6e-8 2.2e-8' . "
                            "By default this is computed from .info and/or edf files. "
                            "Will be multiplied by binning.")
    group.add_argument("--sx0", action='store', type=float,
                       help="focus offset parameter in mm: if given, the magnified pixel sizes "
                            "and sample-detector distances will be calculated from cx and sx.")
    group.add_argument("--pixel_size_detector", action='store', type=float, default=None,
                       help="Detector pixel size (m), before magnification. "
                            "This can be used only in combination with sx0, to compute "
                            "the magnified distances. If not given, this is read from the "
                            "'optic_used' field in the .info file.")
    group.add_argument("--cx", action='store', type=float,
                       help="detector motor position (mm), used if sx0 is also given. "
                            "By default the value is read from edf files")
    group.add_argument("--sx", action='store', nargs='+', type=float,
                       help="sample motor position (mm), used only if sx0 is also given. "
                            "Multiple values should be given (--sx 3 4 5 6) if nz>1. "
                            "By default the value is read from edf files")

    group = parser.add_argument_group("Alignment parameters")
    group.add_argument("--rhapp", action='store', type=str, default=None,
                       help="Filename for the rhapp alignment matrix written by holoCT. "
                            "If set, this automatically sets the alignment method to using "
                            "the rhapp matrix, unless --align is also used (the rhapp matrix "
                            "can then still be used for comparison).")
    group.add_argument("--align", action='store', type=str, default=None,
                       choices=['rhapp', 'fft', 'fft_fit', 'fft_median'],
                       help="Alignment method, either 'rhapp' (default choice if --rhapp is given), "
                            "'fft' (Fourier cross-correlation), 'fft_fit' (same as fft but "
                            "with a polynomial fit, the default), 'fft_median' (same as fft but with a "
                            "median filter)")
    group.add_argument("--align_interval", action='store', type=int, default=1,
                       help="if > 1, align only every N projection, and interpolate "
                            "for the other projections. N>1 can be useful if loading & "
                            "alignment is slower than algorithms when using multiple GPUs.")
    group.add_argument("--align_fourier_low_cutoff", action='store', type=float, default=None,
                       help="Cutoff (default: None), a value (e.g. 0.01) to cutoff "
                            "low frequencies during Fourier registration.")
    group.add_argument("--reference_plane", action='store', type=int, default=None,
                       choices=[1, 2, 3, 4],
                       help="Reference plane for the alignment, for nz>1. "
                            "Default is the first plane. Numbering begins at 1, same as "
                            "for the --planes argument")
    group.add_argument("--motion", action='store_true', default=False,
                       help="Take into account the random motion 'correct.txt' data during alignment")

    group = parser.add_argument_group("Reconstruction parameters")
    group.add_argument("--algorithm", action='store', type=str, default="AP**5,delta_beta=0,AP**10,CTF",
                       help="Algorithm string, executed from right to left.\n"
                            "Examples:\n"
                            "   AP**5,delta_beta=0,AP**10,Paganin\n"
                            "   AP**10,obj_smooth=0.5,AP**20,obj_smooth=8,CTF\n"
                            "   AP**10,obj_smooth=0.5,Paganin\n"
                            "   AP**10,obj_smooth=0.5,DRAP**10,obj_smooth=1,CTF")
    group.add_argument("--padding", action='store', type=int, default=0,
                       help="Number of pixels to pad on every side of the array. 20%% of the "
                            "size is normally enough. The value will be rounded to the closest "
                            "value which allows a radix FFT. 2 can be given as a special value "
                            "and will correspond to a doubling of the array size, i.e. a padding "
                            "of ny/2 and nx/2 along each dimension (not recommended, this wastes "
                            "memory and computing time). NOTE: this is only the padding used "
                            "during phasing - when computing the filtered back-projection, "
                            "the padding used always doubles the array size.")
    group.add_argument("--pad_method", action='store', type=str, default='reflect_linear',
                       choices=['edge', 'mean', 'median', 'reflect', 'reflect_linear', 'reflect_erf',
                                'reflect_sine', 'symmetric', 'wrap'],
                       help="Padding method. See methods description in numpy.pad() online doc, with "
                            "'reflect_' additional methods, using a linear combination of wrapped "
                            "reflected arrays across the border, to guarantee a continuity across "
                            "the edges of the padded array.")
    group.add_argument("--stack_size", action='store', type=int, default=None,
                       help="Stack size-how many projections are stored simultaneously in memory. "
                            "This is normally automatically adapted to the available GPU memory.")
    group.add_argument("--obj_smooth", action='store', type=float, default=0,
                       help="Object smoothing parameter in pixels. Can be overridden in the "
                            "algorithm string. Requires obj_inertia>0 to work. Note that it does "
                            "not smooth directly the object by a gaussian of the given sigma, but "
                            "rather regularises the updated object to be close to the previous "
                            "gaussian-blurred iteration.")
    group.add_argument("--obj_inertia", action='store', type=float, default=0.1,
                       help="Object inertia parameter. Can be overridden in the algorithm string. "
                            "Required > 0 to exploit smoothing, typical values 0.1 to 0.5.")
    group.add_argument("--obj_min", action='store', type=float, default=None,
                       help="Object minimum amplitude (e.g. 1)")
    group.add_argument("--obj_max", action='store', type=float, default=None,
                       help="Object maximum amplitude (e.g. 1)")
    group.add_argument("--beta", action='store', type=float, default=0.9,
                       help="Beta coefficient for the RAAR or DRAP algorithm. Not to be mistaken "
                            "with the refraction index beta")
    group.add_argument("--nb_probe", action='store', type=int, default=1, choices=[1, 2, 3, 4],
                       help="Number of coherent probe modes. If >1, the (constant) probe mode coefficients "
                            "will linearly vary as a function of the projection index. "
                            "EXPERIMENTAL")
    # group.add_argument("--liveplot", action='store', type=float, default=1,
    #                    help="Liveplot of phasing step. Currently unused")

    group = parser.add_argument_group("Tomography reconstruction parameters")
    # group.add_argument("--tomo_angular_step", action='store', type=float, default=None,
    #                    help="Angular step (degrees), not taking into account binning. "
    #                         "By default, is determined automatically from data files.")
    group.add_argument("--tomo_rot_center", action='store', type=float, default=None,
                       help="Rotation centre (pixels), not taking into account binning. "
                            "By default, this is determined automatically.")
    group.add_argument("--tomo_cor_method", action='store', type=str, default='iobs_sliding_window',
                       choices=["iobs_sliding_window", "sino", "phase_registration"],
                       help="Method used for the determination of the centre of rotation, either "
                            "'iobs_sliding_window' (standard sliding windows on iobs), "
                            "'sino' using Nabu's SinoCOR, or 'phase_registration' using "
                            "cross-correlation of phased projections at +180°")
    group.add_argument("--sino_filter", action='store', type=str, default='ram-lak',
                       choices=['ram-lak', 'shepp-logan', 'cosine', 'hamming', 'hann',
                                'tukey', 'lanczos'],
                       help="Name for the sinogram 1D filter before back-projection. "
                            "The filtering will be done using the padded array if "
                            "padding is used.")

    group = parser.add_argument_group("Output parameters")
    group.add_argument("--prefix_output", action='store', default=None,
                       help="Prefix for the output (can include a full directory path), "
                            "Defaults to data_prefix_resultNN")
    group.add_argument("--save_phase_chunks", action='store_true', default=False,
                       help="Save the phased projections to an hdf5 file")
    group.add_argument("--save_edf", action='store_true', default=False,
                       help="Save the phased projections individual edf files.")
    group.add_argument("--save_fbp_vol", action='store', default=False, type=str, choices=['hdf5', 'tiff', 'cuts'],
                       help="Save the reconstructed volume, either using the hdf5 format (float16), "
                            "uint16 tiff (single file if volume<4GB, individual slices otherwise), "
                            "or only with a png with the 3 cuts (mostly for testing)")
    group.add_argument("--save_3sino", action='store_true', default=False,
                       help="Save the 3 sinograms at 25, 50 and 75%% of height for testing.")
    # Hidden option to output volume from holoct reconstructed projections. The supplied string must
    # be the file pattern for holoct result files. The program will exit immediately after
    # reconstructing the 'holoct' volume.
    group.add_argument("--holoct", action='store', default=None, type=str,
                       help=argparse.SUPPRESS)

    group = parser.add_argument_group("Job parameters")
    group.add_argument("--ngpu", action='store', default=None, type=int,
                       help="number of GPUs the job should be distributed to (default=2 "
                            "for a slurm job, otherwise defaults to 1)")
    group.add_argument("--slurm", action='store_true', default=False,
                       help="If given, the job will be submitted as an ESRF slurm job (p9gpu)")
    group.add_argument("--slurm_partition", action='store', default='p9gpu',
                       choices=['p9gpu', 'p9gpu-long', 'low-p9gpu'],
                       help="Specify the slurm partition to use (for long jobs, default is p9gpu, 1 hour max)")
    group.add_argument("--slurm_nodelist", action='store', default=None,
                       help="Specify a nodelist for the slurm job submission (mostly for debugging or "
                            "very specific needs e.g. memory). The string will be passed directly to "
                            "sbatch --nodelist=...")
    group.add_argument("--slurm_time", action='store', default=None,
                       help="Specify a time limit for the slurm job (mostly for debugging). "
                            "The string will be passed directly to sbatch --time=... . "
                            "By default the time limit is automatically chosen.")
    # group.add_argument("--serial", action='store_true', default=False,
    #                    help="If given, the analysis will be run with multiprocessing (for debugging)")
    group.add_argument("--max_mem_chunk", action='store', default=64, type=int,
                       help="Maximum memory (GB) available per chunk (ngpu chunks are processed in //). "
                            "The actual memory used should be about 1.5x this value. This should be kept "
                            "reasonably low because the process uses a lot of shared memory, and the "
                            "analysis can slow down significantly if too much is used.")
    group.add_argument("--profile", action='store_true', default=False,
                       help="Enable GPU profiling (for development only)")
    return parser


def _slurm_forwarded_args(argv):
    """Return the command-line arguments to forward to the slurm job.

    All '--slurm...' options are removed, together with the value following
    a '--slurm_xxx' option (unless given as '--slurm_xxx=value'). Other
    arguments are kept untouched, even if they contain the text 'slurm'
    (e.g. in a data path).

    :param argv: the full argument list, including the program name (argv[0])
    :return: the list of arguments (program name excluded) to forward
    """
    forwarded = []
    skip_value = False
    for arg in argv[1:]:
        if skip_value:
            # This is the value of the preceding '--slurm_xxx' option
            skip_value = False
            continue
        if arg.startswith("--slurm"):
            # All '--slurm_xxx' options take exactly one value; '--slurm' is a flag
            skip_value = arg.startswith("--slurm_") and "=" not in arg
            continue
        forwarded.append(arg)
    return forwarded


def main():
    """Script entry point: parse the command-line and run or submit the analysis.

    The --nz / --planes / --reference_plane parameters are first reconciled.
    Then, if --slurm was given, a '<prefix_output>.slurm' batch script is
    written and submitted with sbatch, re-invoking this program without the
    slurm options. Otherwise master() is executed directly.

    :raises RuntimeError: if --nz and --planes are inconsistent, or if the
        reference plane is not among the selected planes.
    :raises ValueError: for a slurm job, if the dataset directory name does
        not end with '<digit>_'.
    """
    params = make_parser().parse_args()
    # Use planes or nz
    if params.planes is not None:
        if params.nz is not None and len(params.planes) != params.nz:
            raise RuntimeError("Parameters are inconsistent: --nz=%d  and --planes=%s"
                               % (params.nz, str(params.planes)))
        params.nz = len(params.planes)
    else:
        if params.nz is None:
            params.nz = 1  # Default value if neither is set
        params.planes = list(range(1, params.nz + 1))
    # Assign or check reference plane
    if params.reference_plane is None:
        params.reference_plane = params.planes[0]
    elif params.reference_plane not in params.planes:
        raise RuntimeError("Reference plane (%d) is not in the given list of planes: %s"
                           % (params.reference_plane, str(params.planes)))

    if not params.slurm:
        # Direct execution
        if params.ngpu is None:
            params.ngpu = 1
        master(params)
        return

    # Submit job to the ESRF slurm cluster
    import shlex  # only needed to quote the generated shell command

    if params.data.endswith("/"):
        params.data = params.data[:-1]
    data_dir, prefix = os.path.split(params.data)
    # The dataset directory is expected to end with the distance index
    # followed by an underscore, e.g. 'sample_1_'
    if len(prefix) < 2 or not prefix[-2].isdecimal():
        raise ValueError("The dataset directory name should end with '<digit>_' "
                         "(e.g. sample_1_), got: %s" % prefix)
    prefix = prefix[:-2]  # Remove the trailing '1_'

    if params.ngpu is None:
        params.ngpu = 2

    if params.prefix_output is None:
        # Find the first unused 'resultNN' prefix in the current directory
        existing = os.listdir('.')
        i = 1
        while True:
            if params.nz > 1:
                params.prefix_output = prefix + "result%02d" % i
            else:
                params.prefix_output = prefix + "%d_result%02d" % (params.planes[0], i)
            if not any(fn.startswith(params.prefix_output) for fn in existing):
                break
            i += 1

    # Create the output directory if needed, then the slurm batch script
    path = os.path.split(params.prefix_output)[0]
    if path:
        os.makedirs(path, exist_ok=True)

    # os.getlogin() raises OSError when there is no controlling terminal
    try:
        login = os.getlogin()
    except OSError:
        login = None

    # Rebuild the command-line for the job, without the slurm options.
    forwarded = _slurm_forwarded_args(sys.argv)
    given_options = {a.split("=", 1)[0] for a in forwarded if a.startswith("--")}
    if "--ngpu" not in given_options:
        forwarded += ["--ngpu", str(params.ngpu)]
    if "--prefix_output" not in given_options:
        forwarded += ["--prefix_output", params.prefix_output]
    if 'id16b' in sys.argv[0]:
        prog = "pynx-holotomo-id16b"
    else:
        prog = "pynx-holotomo-id16a"
    # Quote every argument: the algorithm string includes '**' which the shell
    # could otherwise try to expand, and paths may include spaces.
    command = " ".join([prog] + [shlex.quote(a) for a in forwarded])

    s = "%s.slurm" % params.prefix_output
    with open(s, 'w') as slurmfile:
        slurmfile.write("#!/bin/bash -l\n")
        slurmfile.write("#SBATCH --partition=%s\n" % params.slurm_partition)
        if params.ngpu > 1:
            slurmfile.write("#SBATCH --gres=gpu:%d\n" % params.ngpu)
            slurmfile.write("#SBATCH --mem=0\n")
            # slurmfile.write("#SBATCH --cpus-per-task=128\n")
            slurmfile.write("#SBATCH --exclusive\n")
        else:
            slurmfile.write("#SBATCH --gres=gpu:1\n")
            slurmfile.write("#SBATCH --mem=256g\n")
            slurmfile.write("#SBATCH --cpus-per-task=64\n")
        slurmfile.write("#SBATCH --ntasks=1\n")
        if params.slurm_time is not None:
            slurmfile.write("#SBATCH --time=%s\n" % params.slurm_time)
        elif params.slurm_partition == 'p9gpu':
            slurmfile.write("#SBATCH --time=60\n")
        else:
            slurmfile.write("#SBATCH --time=600\n")
        if params.slurm_nodelist is not None:
            slurmfile.write("#SBATCH --nodelist=%s\n" % params.slurm_nodelist)
        slurmfile.write("#SBATCH --output=%s.out\n" % params.prefix_output)
        slurmfile.write("env | grep CUDA\n")
        # KLUDGE to workaround possible bug in loading order for libgomp (with sklearn)
        # slurmfile.write("export LD_PRELOAD=/lib/powerpc64le-linux-gnu/libgomp.so.1\n")
        slurmfile.write("scontrol --details show jobs $SLURM_JOBID |grep RES\n")
        # Log the shared memory limits and usage on the node
        slurmfile.write("cat /proc/sys/kernel/shmmax\n")
        slurmfile.write("cat /proc/sys/kernel/shmall\n")
        slurmfile.write("ipcs -m --human\n")
        slurmfile.write("ipcs -pm --human\n")
        if login == 'favre' and platform.node() == 'scisoft15':
            # KLUDGE - should be removed from production code
            slurmfile.write("source  /home/esrf/favre/dev/pynx-py38-power9-env/bin/activate\n")
        else:
            slurmfile.write("source  /sware/exp/pynx/activate_pynx.sh devel\n")
        slurmfile.write(command + "\n")

    status = os.system("sbatch %s" % shlex.quote(s))
    if status == 0:
        print("Submitted slurm job with: %s\n" % s)
    else:
        print("Failed to submit slurm job (sbatch status=%d) with: %s\n" % (status, s),
              file=sys.stderr)


if __name__ == '__main__':
    # Refuse to run with an interpreter older than python 3.8
    if sys.version_info[:2] < (3, 8):
        raise RuntimeError("The holotomo script requires python>=3.8")
    # set_start_method('spawn')
    main()
