#!/home/esrf/favre/dev/pynx-py38-power9-env/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2020-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

# TODO: This is a very rough, unorganised script, intended for testing and early processing.
#  It needs to be object-oriented, with proper handling of parameters and parallel
#  processing, so that i/o, alignment, algorithms and export can run at the same
#  time, etc.

import os
import sys
import timeit
import multiprocessing

# os.environ['PYNX_PU'] = 'cuda.0'
import numpy as np
import psutil
import matplotlib.pyplot as plt

import fabio
from tomoscan.esrf.edfscan import EDFTomoScan
from nabu.resources.dataset_analyzer import analyze_dataset
from nabu.pipeline.estimators import CORFinder
from nabu.estimation.cor import CenterOfRotation, CenterOfRotationAdaptiveSearch, CenterOfRotationSlidingWindow, \
    CenterOfRotationGrowingWindow
from pynx.holotomo import *
from pynx.holotomo.operator import *
from pynx.holotomo.utils import load_data_kw, zoom_pad_images_kw, align_images_kw, save_phase_edf_kw, \
    get_params_id16b
from pynx.utils.array import rebin
from pynx.utils.math import primes, test_smaller_primes

try:
    # Get the real number of processor cores available (physical cores within
    # the CPU affinity of this process).
    # os.sched_getaffinity is only available on some *nix platforms, and
    # psutil.cpu_count() may return None if the count cannot be determined.
    nproc = len(os.sched_getaffinity(0)) * psutil.cpu_count(logical=False) // psutil.cpu_count(logical=True)
except (AttributeError, TypeError):
    nproc = os.cpu_count()
# multiprocessing.Pool() requires at least one process: the integer ratio above
# can round down to 0 with a small affinity set, and os.cpu_count() may return None.
nproc = max(1, nproc or 1)
print("Number of available processors: ", nproc)


def _rescale_to_uint16(vol):
    """
    Linearly rescale a float array to the full uint16 range.

    The minimum (non-NaN) value is mapped to 0 and the maximum to 65535.
    NaN values (used to mask voxels outside the reconstruction disk) are
    mapped to 0. The computation is done slice-by-slice (along the first axis)
    in float32: this avoids overflowing float16 (whose maximum is 65504) and
    limits the temporary memory used.

    :param vol: the float array to convert (at least 1D)
    :return: a uint16 array with the same shape as vol
    """
    vmin = float(np.nanmin(vol))
    vmax = float(np.nanmax(vol))
    # Avoid a division by zero for a constant volume
    scale = (2 ** 16 - 1) / (vmax - vmin) if vmax > vmin else 0.0
    volint = np.empty(vol.shape, dtype=np.uint16)
    for i in range(vol.shape[0]):
        s = (np.asarray(vol[i], dtype=np.float32) - vmin) * scale
        volint[i] = np.clip(np.rint(np.nan_to_num(s, nan=0.0)), 0, 2 ** 16 - 1)
    return volint


def main():
    """
    Run the complete ID16B holo-tomography processing.

    Parameters are read from the first '*.par' file given on the command line
    (executed as Python code), then overridden by key=value command-line
    arguments. The processing steps are:

    1. load dark, empty-beam (reference) and projection images for the `nz`
       distances, using a multiprocessing pool;
    2. if nz > 1, zoom & pad images to the first distance pixel size, then
       align them (or import shifts from a holoCT 'align_rhapp' file);
    3. run the chain of phase-retrieval algorithms given by `algorithm`
       (applied right-to-left) on the GPU, by stacks of projections;
    4. optionally save the phased projections (hdf5 chunks and/or edf files),
       a FBP volume reconstructed with Nabu (hdf5 or tiff) and 3 sinograms.

    The function calls sys.exit(1) on missing or invalid parameters.
    """
    import gc  # used to release memory after deleting large arrays

    t00 = timeit.default_timer()
    ################################################################
    # Experiment parameters - should be loaded from a file
    ################################################################
    # All these will be loaded from the parameters file
    params = {'i0': None,
              'data': None,
              'prefix_result': None,
              'nz': 1,
              'reference_plane': 0,
              'padding': 0,
              'projection_range': None,
              'delta_beta': None,
              'wavelength': None,
              'detector_distance': None,
              'algorithm': "AP**20,obj_smooth=0.5,CTF",
              'stack_size': None,
              'save_phase_chunks': False,
              'save_edf': False,
              'align_rhapp': None,
              'obj_smooth': 0,
              'obj_min': None,
              'obj_max': None,
              # Beta coefficient for the RAAR or DRAP algorithm. Not to be mistaken
              # with the refraction index beta
              'beta': 0.9,
              # Number of coherent probe modes. If >1, the (constant) probe mode coefficients
              # will linearly vary as a function of the projection index
              'nb_probe': 1,
              # Effective Pixel size in meters (it will be multiplied by binning)
              'pixel_size': None,
              'binning': 1,
              # If True, save the FBP volume computed with Nabu.
              # A string can also be given with "tiff" and the volume will be exported
              # as single tiff file with uint16 data type
              'save_fbp_vol': False,
              'save_3sino': False,
              'tomo_angular_step': None,  # in degrees
              'tomo_rot_center': None,  # in pixels
              'sx': None,  # Sample position (ID16) in mm
              'sx0': None,  # Focus position (ID16) in mm
              'cx': None,  # Detector position (ID16) in mm
              'show_obj_probe': 0  # live display ?
              }

    ################################################################
    # Read parameters - should be loaded from a file
    ################################################################
    # Look for the first "*.par" file given as argument.
    # NOTE: the parameter file is executed as Python code - only use trusted files.
    for arg in sys.argv:
        if arg.endswith('.par'):
            print("Loading parameters from: %s" % arg)
            dic = {}
            with open(arg) as par_file:
                exec(par_file.read(), globals(), dic)
            # Only keep the variables which are known parameters
            for k, v in dic.items():
                if k in params:
                    params[k] = v
            break

    # load parameters from command-line args, at least i0=N
    # NOTE: some values are parsed with eval() - only pass trusted arguments.
    for arg in sys.argv:
        if '=' in arg:
            # Split on the first '=' only, to handle strings like "algorithm=AP**10,obj_smooth=0.5,CTF"
            k, _, v = arg.partition('=')
            if k in ['i0', 'nz', 'binning', 'reference_plane', 'stack_size', 'show_obj_probe', 'save_phase_chunks']:
                params[k] = eval(v)
                print(arg)
            elif k in ['wavelength', 'obj_smooth', 'obj_min', 'obj_max', 'beta',
                       'tomo_angular_step', 'tomo_rot_center', 'sx0', 'sx', 'cx', 'delta_beta']:
                params[k] = float(v)
                print(arg)
            elif k in ['detector_distance', 'projection_range', 'pixel_size']:  # [istart, iend[ , step
                params[k] = eval(v)
                print(arg)
            elif k in params:
                params[k] = v
            else:
                print("Did not understand command-line parameter: ", arg)
                sys.exit(1)

    print("Parameters:")
    for k, v in params.items():
        print("     %s: " % k, v)

    i0 = params['i0']
    data_src = params['data']
    prefix_result = params['prefix_result']
    nz = params['nz']
    reference_plane = params['reference_plane']
    padding = params['padding']
    projection_range = params['projection_range']
    delta_beta = params['delta_beta']
    wavelength = params['wavelength']
    detector_distance = params['detector_distance']
    algorithm = params['algorithm']
    stack_size = params['stack_size']
    save_phase_chunks = params['save_phase_chunks']
    save_edf = params['save_edf']
    align_rhapp = params['align_rhapp']
    obj_smooth = params['obj_smooth']
    obj_min = params['obj_min']
    obj_max = params['obj_max']
    beta = params['beta']
    nb_probe = params['nb_probe']
    pixel_size = params['pixel_size']
    binning = params['binning']
    save_fbp_vol = params['save_fbp_vol']
    save_3sino = params['save_3sino']
    tomo_angular_step = params['tomo_angular_step']
    tomo_rot_center = params['tomo_rot_center']
    sx = params['sx']
    sx0 = params['sx0']
    cx = params['cx']
    show_obj_probe = params['show_obj_probe']

    if i0 is None:
        print("WARNING: i0 was not set => using i0=0. You must use i0=N from the command-line\n"
              "         when splitting the analysis with several jobs")
        i0 = 0
    else:
        print("i0=%d" % i0)

    if delta_beta is None:
        print("You must supply a delta_beta value")
        sys.exit(1)

    if data_src is None:
        print("You must supply a data path (data=...)")
        sys.exit(1)

    # Analysing data name
    if data_src.endswith("/"):
        data_src = data_src[:-1]
    data_dir, prefix = os.path.split(data_src)
    prefix = prefix[:-2]  # Remove the trailing '1_'

    if prefix_result is None:
        i = 1
        while True:
            prefix_result = prefix + "result%02d" % i
            if len([fn for fn in os.listdir('.') if fn.startswith(prefix_result)]) == 0:
                break
            i += 1
        print('No prefix_result given, will use: ', prefix_result)

    if projection_range is None:
        nb_proj = EDFTomoScan.get_tomo_n(data_src)
        print("No projection_range given - using all %d projections" % nb_proj)
        proj_idx = list(EDFTomoScan.get_proj_urls(data_src, n_frames=1).keys())
        proj_idx.sort()
        proj_idx = proj_idx[:nb_proj]
        if binning > 1:
            proj_idx = proj_idx[::binning]
    else:
        print("projection_range: ", projection_range)
        proj_idx = np.arange(i0 + projection_range[0], projection_range[1], projection_range[2])
    nb_proj = len(proj_idx)  # number of images loaded (excluding dark and empty_beam images)
    print("nb_proj=%d" % nb_proj)

    # Gather magnified pixel sizes and propagation distances.
    # Use a separate name so the user parameters (`params`) are kept intact,
    # as they are later saved along with the phased projections.
    params_id16b = get_params_id16b(data_src, nz=nz, sx0=sx0, cx=cx, sx=sx, verbose=True)

    if wavelength is None:
        wavelength = params_id16b['wavelength']

    if detector_distance is None:
        detector_distance = params_id16b['distance']
    else:
        if np.isscalar(detector_distance):
            detector_distance = [detector_distance]
        detector_distance = np.array(detector_distance)

    if pixel_size is None:
        pixel_size_z = params_id16b['pixel_size']
    else:
        if np.isscalar(pixel_size):
            pixel_size = [pixel_size]
        pixel_size_z = np.array(pixel_size)
    pixel_size_z *= binning
    pixel_size = pixel_size_z[0]

    magnification = pixel_size_z[0] / pixel_size_z

    print("Detector distance:", detector_distance)
    print("Pixel sizes:", pixel_size_z)

    ################################################################
    # Prepare for output
    ################################################################

    path = os.path.split(prefix_result)[0]
    if len(path):
        os.makedirs(path, exist_ok=True)

    ################################################################
    # Load data in //
    ################################################################
    t0 = timeit.default_timer()

    dark = None
    for iz in range(0, nz):
        dark_url = EDFTomoScan.get_darks_url("%s%d_" % (data_src[:-2], iz + 1))
        dark_name = dark_url[0].file_path()
        print("Loading dark: ", dark_name)
        d = fabio.open(dark_name).data
        if binning is not None:
            if binning > 1:
                d = rebin(d, binning)
        ny, nx = d.shape
        if dark is None:
            dark = np.empty((nz, ny, nx), dtype=np.float32)
        dark[iz] = d

    ny, nx = dark.shape[-2:]
    print("Frame size: %d x %d" % (ny, nx))

    # Test if we have a radix transform
    primesy, primesx = primes(ny + 2 * padding), primes(nx + 2 * padding)
    if max(primesx + primesy) > 13:
        padup = padding
        while not test_smaller_primes(ny + 2 * padup, required_dividers=[2]) or \
                not test_smaller_primes(nx + 2 * padup, required_dividers=[2]):
            padup += 1
        paddown = padding
        while not test_smaller_primes(ny + 2 * paddown, required_dividers=[2]) or \
                not test_smaller_primes(nx + 2 * paddown, required_dividers=[2]):
            paddown -= 1
        raise RuntimeError("The dimensions (with padding=%d) are incompatible with a radix FFT:\n"
                           "  ny=%d primes=%s  nx=%d primes=%s (should be <=13)\n"
                           "  Closest acceptable padding values: %d or %d" %
                           (padding, ny + 2 * padding, str(primesy),
                            nx + 2 * padding, str(primesx), paddown, padup))

    ref = np.empty_like(dark)
    for iz in range(0, nz):
        for k, v in EDFTomoScan.get_flats_url("%s%d_" % (data_src[:-2], iz + 1)).items():
            print("Loading empty reference image [iz=%d, idx=%d]: " % (iz, k), v.file_path())
            d = fabio.open(v.file_path()).data
            if binning is not None:
                if binning > 1:
                    d = rebin(d, binning)
            ref[iz] = d - dark[iz]
            # TODO: load multiple empty beam images (begin/end)
            break

    # EDFTomoScan scans the directory, but we use a pattern
    img_name = data_src[:-2] + "%d_/" + prefix + "%d_%04d.edf"
    print("Projections images: ", img_name)
    vkw = [{'i': i, 'dark': dark, 'nz': nz, 'img_name': img_name, 'binning': binning} for i in proj_idx]
    load_data_kw(vkw[0])
    iobs = np.empty((nb_proj, nz, ny, nx), dtype=np.float32)
    with multiprocessing.Pool(nproc) as pool:
        results = pool.imap(load_data_kw, vkw)  # , chunksize=1
        for i in range(len(vkw)):
            iobs[i] = results.next(timeout=20)

    dt = timeit.default_timer() - t0
    print("Time to load & uncompress data: %4.1fs [%6.2f Mbytes/s]" %
          (dt, (iobs.nbytes + dark.nbytes + ref.nbytes) * binning ** 2 / dt / 1024 ** 2))

    ################################################################
    # Zoom & register images, keep first distance pixel size and size
    # Do this using multiple process to speedup
    ################################################################
    if nz > 1:
        # TODO: zoom (linear interp) & registration could be done on the GPU once the images are loaded ?

        print("Magnification relative to iz=0: ", magnification / magnification[0])
        t0 = timeit.default_timer()

        vkw = [{'x': ref, 'magnification': magnification, 'padding': padding, 'nz': nz}] + \
              [{'x': iobs[i], 'magnification': magnification, 'padding': padding, 'nz': nz} for i in range(nb_proj)]
        res = []
        with multiprocessing.Pool(nproc) as pool:
            results = pool.imap(zoom_pad_images_kw, vkw)
            for i in range(nb_proj + 1):
                r = results.next(timeout=20)
                res.append(r)

        del iobs
        gc.collect()

        ref_zoom = res[0]
        iobs_align = np.array(res[1:], dtype=np.float32)
        ny, nx = iobs_align.shape[-2:]

        del res
        print("Zoom & pad images: dt = %6.2fs" % (timeit.default_timer() - t0))

        print("Pixel size after magnification: %6.3fnm" % (pixel_size * 1e9))

        # Align images
        t0 = timeit.default_timer()

        if align_rhapp is None:
            print("Aligning images")

            # This can sometimes (<1 in 10) fail (hang indefinitely). Why ?
            # res = pool.map(align_images, range(nb_proj))
            if padding:
                vkw = [{'x': iobs_align[i, :, padding:-padding, padding:-padding],
                        'x0': ref_zoom[:, padding:-padding, padding:-padding], 'nz': nz} for i in range(nb_proj)]
            else:
                vkw = [{'x': iobs_align[i], 'x0': ref_zoom, 'nz': nz} for i in range(nb_proj)]
            align_images_kw(vkw[0])
            align_ok = False
            nb_nok = 0
            while not align_ok:
                if nb_nok >= 4:
                    print("4 failures, bailing out")
                    sys.exit(1)
                try:
                    res = []
                    with multiprocessing.Pool(nproc) as pool:
                        results = pool.imap(align_images_kw, vkw, chunksize=1)
                        for i in range(nb_proj):
                            r = results.next(timeout=20)
                            res.append(r)
                    align_ok = True
                    print("align OK", len(res))
                except multiprocessing.TimeoutError:
                    # Only retry on a (hung) worker timeout - other errors or a
                    # KeyboardInterrupt must not be silently swallowed.
                    print("Timeout, re-trying")
                    nb_nok += 1

            # print(res)

            dx = np.zeros((nb_proj, nz))
            dy = np.zeros((nb_proj, nz))
            for i in range(len(iobs_align)):
                dx[i] = res[i][0]
                dy[i] = res[i][1]

            # Use polyfit to smooth shifts #####################################
            # TODO: use shift corrections common to all parallel optimisations with a prior determination ?
            dx, dy, dxraw, dyraw = dx.copy(), dy.copy(), dx, dy
            for iz in range(1, nz):
                polx = np.polynomial.polynomial.polyfit(proj_idx[:-1], dx[:-1, iz], 6)
                poly = np.polynomial.polynomial.polyfit(proj_idx[:-1], dy[:-1, iz], 6)
                dx[:, iz] = np.polynomial.polynomial.polyval(proj_idx, polx)
                dy[:, iz] = np.polynomial.polynomial.polyval(proj_idx, poly)

            if True:
                # Plot shift of images vs first distance
                plt.figure(figsize=(12.5, 4))
                for iz in range(1, nz):
                    plt.subplot(1, nz - 1, iz)
                    plt.plot(proj_idx[:-1], dxraw[:-1, iz], 'r.', label='x')
                    plt.plot(proj_idx[:-1], dyraw[:-1, iz], 'b.', label='y')
                    plt.plot(proj_idx[:-1], dx[:-1, iz], 'r', label='x')
                    plt.plot(proj_idx[:-1], dy[:-1, iz], 'b', label='y')

                    plt.title("Alignment iz=%d vs iz=0 [PyNX]" % (iz))
                    plt.legend()
                plt.tight_layout()
                plt.savefig(prefix_result + '_i0=%04d_shifts.png' % i0)
                np.savez_compressed(prefix_result + '_i0=%04d_shifts.npz' % i0, dx=dx, dy=dy, dxraw=dxraw, dyraw=dyraw)
        else:
            print("Aligning images: using shift imported from holoCT")
            # Load alignment shifts from rhapp (holoCT)
            nb = np.loadtxt(align_rhapp, skiprows=4, max_rows=1, dtype=int)[2]
            m = np.loadtxt(align_rhapp, skiprows=5, max_rows=nb * 8, dtype=np.float32).reshape((nb, 4, 2))
            dx = m[..., 1]
            dy = m[..., 0]
            # Work on a copy: proj_idx is used later (HoloTomoData, export) and
            # must not be modified. The last projection uses the shifts of index 0
            # (presumably an extra image at the starting angle - TODO confirm).
            tmp_idx = np.array(proj_idx)
            tmp_idx[-1] = 0

            dx = np.take(dx.copy(), tmp_idx, axis=0)
            dy = np.take(dy.copy(), tmp_idx, axis=0)

        # Switch reference plane:
        if reference_plane:
            for iz in range(nz):
                if iz != reference_plane:
                    dx[:, iz] -= dx[:, reference_plane]
                    dy[:, iz] -= dy[:, reference_plane]
            dx[:, reference_plane] = 0
            dy[:, reference_plane] = 0

        print("Align images: dt = %6.2fs" % (timeit.default_timer() - t0))
    else:
        ref_zoom = ref
        iobs_align = iobs
        dx, dy = None, None

    ################################################################
    # Use coherent probe modes ?
    ################################################################
    if nb_probe > 1:
        # Use linear ramps for the probe mode coefficients
        coherent_probe_modes = np.zeros((nb_proj, nz, nb_probe))
        dn = nb_proj // (nb_probe - 1)
        for iz in range(nz):
            for i in range(nb_probe - 1):
                coherent_probe_modes[i * dn:(i + 1) * dn, iz, i] = np.linspace(1, 0, dn)
                coherent_probe_modes[i * dn:(i + 1) * dn, iz, i + 1] = np.linspace(0, 1, dn)
    else:
        coherent_probe_modes = False

    print("################################################################")
    print(" Initialising GPU")
    print("################################################################")
    # Creating an holotomo operator will force the initialisation
    # of the ProcessingUnit, which can be passed to HolotomoData and
    # Holotomo to directly allocate pinned memory.
    pu = ScaleObj().processing_unit  # Also passed later to nabu

    if stack_size is None:
        # Estimate how much memory will be used
        mem = pu.cu_device.total_memory()
        print("Available memory: %6.2fGB  data size: %6.2fGB" % (mem / 1024 ** 3, iobs_align.nbytes / 1024 ** 3))
        # Required memory: data, iobs, object (complex + phase), 1 copy of psi per projection (DRAP/DM/RAAR)
        nproj, nz, ny, nx = iobs_align.shape
        mem_req = iobs_align[0].nbytes
        mem_req += 12 * nx * ny  # object projections & phase
        if True:  # 'RAAR' in algorithm or 'DM' in algorithm or 'DRAP' in algorithm:
            mem_req += 2 * iobs_align[0].nbytes
        print("Estimated memory requirement: %8.4fGB/projection" % (mem_req / 1024 ** 3))
        # NB: we need two stacks (one for computing, the other for swap)
        # At least one projection per stack (also avoids a division by zero below)
        stack_size = max(1, int(np.ceil((mem - 0.5 * 1024 ** 3) / (mem_req * 4))))
        if nproj // stack_size == 2:
            # Use 3 stacks of (nearly) equal size rather than 2 full stacks and a partial one
            stack_size = int(np.ceil(nproj / 3))
        print("Using stack size = %d" % stack_size)

    ################################################################
    # Create HoloTomoData and HoloTomo objects
    ################################################################
    data = HoloTomoData(iobs_align, ref_zoom, pixel_size_detector=pixel_size,
                        wavelength=wavelength, detector_distance=detector_distance,
                        stack_size=stack_size,
                        dx=dx, dy=dy, idx=proj_idx, padding=padding, pu=pu)
    # Create PCI object
    p = HoloTomo(data=data, obj=None, probe=None, coherent_probe_modes=coherent_probe_modes, pu=pu)
    dt = timeit.default_timer() - t00
    print("Elapsed time since beginning:  %4.1fs" % dt)

    ################################################################
    # Algorithms
    ################################################################
    t0 = timeit.default_timer()
    # print(nz, 1, ny, nx, p.data.nz, p.nb_probe, p.data.ny, p.data.nx)
    p.set_probe(np.ones((nz, 1, ny, nx)))
    p = ScaleObjProbe() * p

    db = delta_beta
    update_obj = True
    update_probe = True
    # Operators are applied right-to-left, like the operator products below
    for algo in algorithm.split(",")[::-1]:
        if "=" in algo:
            print("Changing parameter? %s" % algo)
            k, v = algo.split("=")
            if k == "delta_beta":
                db = eval(v)
                if db == 0:
                    db = None
                elif db == 1:
                    db = delta_beta
                else:
                    delta_beta = db
            elif k == "beta":
                beta = eval(v)
            elif k == "obj_smooth":
                obj_smooth = eval(v)
            elif k == "obj_min":
                obj_min = eval(v)
            elif k == "obj_max":
                obj_max = eval(v)
            elif k == "probe":
                update_probe = eval(v)
            elif k == "obj":
                update_obj = eval(v)
            else:
                print("WARNING: did not understand algorithm step: %s" % algo)
        elif "paganin" in algo.lower():
            print("Paganin back-projection")
            p = BackPropagatePaganin(delta_beta=delta_beta) * p
            p.set_probe(np.ones((nz, 1, ny, nx)))
            p = ScaleObjProbe() * p
        elif "ctf" in algo.lower():
            print("CTF back-projection")
            p = BackPropagateCTF(delta_beta=delta_beta) * p
            # p.set_probe(np.ones((nz, 1, ny, nx)))
            p = ScaleObjProbe() * p
        else:
            print("Algorithm step: %s" % algo)
            # The operators below are referenced by name in the eval()'d step, e.g. "AP**20" or "DRAP**10"
            dm = DM(update_object=update_obj, update_probe=update_probe,
                    calc_llk=10, delta_beta=db, obj_min=obj_min, obj_max=obj_max,
                    reg_obj_smooth=obj_smooth, weight_empty=1, show_obj_probe=show_obj_probe)
            ap = AP(update_object=update_obj, update_probe=update_probe,
                    calc_llk=10, delta_beta=db, obj_min=obj_min, obj_max=obj_max,
                    reg_obj_smooth=obj_smooth, weight_empty=1, show_obj_probe=show_obj_probe)
            apn = AP(update_object=update_obj, update_probe=update_probe,
                     calc_llk=10, delta_beta=db, obj_min=obj_min, obj_max=obj_max,
                     reg_obj_smooth=obj_smooth, weight_empty=1, show_obj_probe=show_obj_probe,
                     iobs_normalise=True)
            raar = RAAR(update_object=update_obj, update_probe=update_probe, beta=beta,
                        calc_llk=10, delta_beta=db, obj_min=obj_min, obj_max=obj_max,
                        reg_obj_smooth=obj_smooth, weight_empty=1, show_obj_probe=show_obj_probe)
            drap = DRAP(update_object=update_obj, update_probe=update_probe, beta=beta,
                        calc_llk=10, delta_beta=db, obj_min=obj_min, obj_max=obj_max,
                        reg_obj_smooth=obj_smooth, weight_empty=1, show_obj_probe=show_obj_probe)
            p = eval(algo.lower()) * p

    p = FreePU() * MemUsage() * p
    print("Algorithms: dt = %6.2fs" % (timeit.default_timer() - t0))

    if save_phase_chunks:
        # NOTE(review): '%04d' is left unformatted - presumably filled per chunk by
        # save_obj_probe_chunk. TODO confirm.
        filename_prefix = prefix_result + '_i0=%04d'
        print("################################################################")
        print(" Saving phased projections to hdf5 file: " + filename_prefix)
        print("################################################################")
        t0 = timeit.default_timer()
        p.save_obj_probe_chunk(chunk_prefix=filename_prefix, save_obj_phase=True,
                               save_obj_complex=False, save_probe=True, dtype=np.float16,
                               verbose=True, crop_padding=True, process_parameters=params)

        dt = timeit.default_timer() - t0
        print("Time to save phases:  %4.1fs" % dt)

    if save_edf:
        print("################################################################")
        print(" Saving phased images to edf files: " + prefix_result + "_%04d.edf")
        print("################################################################")
        t0 = timeit.default_timer()

        # Somehow getting the unwrapped phases in // using multiprocessing does not work
        # ... the process seems to hang, even with just a fork... Side effect of pinned memory ?
        ph = p.get_obj_phase_unwrapped(crop_padding=True, dtype=np.float32, idx=proj_idx)[1]
        print("Got unwrapped phases in %4.1fs" % (timeit.default_timer() - t0))
        vkw = [{'idx': proj_idx[i], 'ph': ph[i], 'prefix_result': prefix_result} for i in range(nb_proj - 1)]
        with multiprocessing.Pool(nproc) as pool:
            pool.map(save_phase_edf_kw, vkw)

        dt = timeit.default_timer() - t0
        print("Time to unwrap & save phases:  %4.1fs" % dt)

    if save_fbp_vol:
        print("################################################################")
        print(" 3D Tomography reconstruction (FBP) with Nabu")
        print("################################################################")
        t0 = timeit.default_timer()
        idx, ph = p.get_obj_phase_unwrapped(crop_padding=True, dtype=np.float32)
        # Free some memory before going further
        del p, data
        gc.collect()
        gc.collect()

        # print("Got unwrapped phases in %4.1fs" % (timeit.default_timer() - t0))
        if tomo_angular_step is None:
            scan_range, tomo_n = EDFTomoScan.get_scan_range(data_src), EDFTomoScan.get_tomo_n(data_src)
            tomo_angular_step_orig = scan_range / (tomo_n - 1)  # Keep in degrees to match user input
            # Take into account real indices used - could be 180 or 360°...
            print(iobs_align.shape, tomo_angular_step_orig, idx.max(), idx.min(), len(idx))
            tomo_angular_step = tomo_angular_step_orig * (idx.max() - idx.min()) / (len(idx) - 1)
            print("Calculated angular step: %6.2f° / %d = %6.4f°" % (scan_range, tomo_n, tomo_angular_step))
        else:
            print("Angular step: %6.4f°" % tomo_angular_step)

        if tomo_rot_center is None:
            print("No tomo_rot_center given - estimating rotation center using tomotools/nabu")
            if False:
                # Get 0/180° frames
                # im0, im1 = ph[0], np.fliplr(ph[int(round(180 / tomo_angular_step))])
                # Is using the original radios more stable ?
                im0 = iobs_align[0, reference_plane]
                im1 = np.fliplr(iobs_align[int(round(180 / tomo_angular_step)), reference_plane])
                if padding:
                    im0 = im0[padding:-padding, padding:-padding]
                    im1 = im1[padding:-padding, padding:-padding]

                cor_sliding = CenterOfRotationSlidingWindow().find_shift(im0, im1, side="center") + im0.shape[1] / 2
                print(" ...estimated CoR [sliding window]: %7.2f" % cor_sliding)
                cor_growing = CenterOfRotationGrowingWindow().find_shift(im0, im1, side="center") + im0.shape[1] / 2
                print(" ...estimated CoR [growing window]: %7.2f" % cor_growing)
                tomo_rot_center = cor_sliding
            else:
                # Analyse the dataset being processed (not a hard-coded one)
                di = analyze_dataset(data_src)
                di.rotation_angles = np.linspace(0, np.deg2rad(len(idx) * tomo_angular_step),
                                                 len(di.projections), True)
                cf = CORFinder(di)
                # The CoR is estimated on the raw (unbinned) frames
                cor_sliding = cf.find_cor(method="sliding-window") / binning
                print(" ...estimated CoR [sliding window]: %7.2f" % cor_sliding)
                # cor_growing = cf.find_cor(method="growing-window") / binning
                # print(" ...estimated CoR [growing window]: %7.2f" % cor_growing)
                tomo_rot_center = cor_sliding

        print("Using center of rotation = %7.2f" % tomo_rot_center)

        from nabu.reconstruction.fbp import Backprojector

        nproj, ny, nx = ph.shape
        vol = np.empty((ny, nx, nx), dtype=np.float16)
        # Mask outside of reconstruction using NaN
        ix, iy = np.meshgrid(np.arange(nx) - nx / 2, np.arange(nx) - nx / 2)
        r2 = ix ** 2 + iy ** 2
        idx_nan = r2 > (nx * nx / 4)
        B = Backprojector((nproj, nx), rot_center=tomo_rot_center,
                          angles=np.deg2rad(np.arange(nproj) * tomo_angular_step),
                          filter_name=None, cuda_options={'ctx': pu.cu_ctx})
        sys.stdout.write('Reconstructing %d slices (%d x %d): ' % (ny, nproj, nx))
        for i in range(ny):
            sino = ph[:, i, :].copy()
            if i % 50 == 0:
                sys.stdout.write('%d..' % (ny - i))
                sys.stdout.flush()
            res = B.fbp(sino)
            res[idx_nan] = np.nan
            vol[i] = res
        print("\n")
        del B

        dt = timeit.default_timer() - t0
        print("Time to perform FBP reconstruction:  %4.1fs" % dt)

        print("################################################################")
        print(" Saving central XYZ cuts of volume to: %s_volXYZ.png" % prefix_result)
        print("################################################################")
        plt.figure(figsize=(24, 8))
        plt.subplot(131)
        c = vol[ny // 2]
        vmin, vmax = np.nanpercentile(c, (0.1, 99.9))
        plt.imshow(c, vmin=vmin, vmax=vmax, cmap='gray')
        plt.title(os.path.split(prefix_result)[-1] + "[XY]")

        plt.subplot(132)
        c = vol[:, nx // 2]
        vmin, vmax = np.nanpercentile(c, (0.1, 99.9))
        plt.imshow(c, vmin=vmin, vmax=vmax, cmap='gray')
        plt.title(os.path.split(prefix_result)[-1] + "[XZ]")
        plt.colorbar()

        plt.subplot(133)
        c = vol[..., nx // 2]
        vmin, vmax = np.nanpercentile(c, (0.1, 99.9))
        plt.imshow(c, vmin=vmin, vmax=vmax, cmap='gray')
        plt.colorbar()
        plt.title(os.path.split(prefix_result)[-1] + "[YZ]")

        plt.tight_layout()
        plt.savefig(prefix_result + "_volXYZ.png")

        if isinstance(save_fbp_vol, str):
            if "tif" in save_fbp_vol.lower():
                save_fbp_vol = "tiff"
            else:
                save_fbp_vol = "hdf5"
        else:
            save_fbp_vol = "hdf5"
        if save_fbp_vol == "tiff":
            print("################################################################")
            print(" Exporting FBP volume as an uint16 TIFF file [SLOW]")
            print("################################################################")
            t1 = timeit.default_timer()
            from tifffile import imwrite
            filename = prefix_result + "_vol.tiff"
            # Rescale [min, max] to [0, 65535] in float32 (NaN-masked voxels -> 0)
            volint = _rescale_to_uint16(vol)
            pxum = pixel_size * 1e6
            imwrite(filename, volint, imagej=True, resolution=(1 / pxum, 1 / pxum),
                    metadata={'spacing': pxum, 'unit': 'um', 'hyperstack': False}, maxworkers=nproc)
            dt = timeit.default_timer() - t1
            print("Finished saving volume to: %s" % filename)
            print("Time to export volume as tiff:  %4.1fs" % dt)
        else:
            print("################################################################")
            print(" Saving FBP volume as an hdf5/float16")
            print("################################################################")
            t1 = timeit.default_timer()
            import h5py as h5
            filename = prefix_result + "pynx_vol.h5"
            with h5.File(filename, "w") as f:
                f.attrs['creator'] = 'PyNX'
                # f.attrs['NeXus_version'] = '2018.5'  # Should only be used when the NeXus API has written the file
                f.attrs['HDF5_Version'] = h5.version.hdf5_version
                f.attrs['h5py_version'] = h5.version.version
                f.attrs['default'] = 'entry_1'

                entry_1 = f.create_group("entry_1")
                entry_1.attrs['NX_class'] = 'NXentry'
                entry_1.attrs['default'] = 'data_1'
                data_1 = entry_1.create_group("data_1")
                data_1.attrs['NX_class'] = 'NXdata'
                data_1.attrs['signal'] = 'data'
                data_1.attrs['interpretation'] = 'image'
                data_1['title'] = 'PyNX (FBP:Nabu)'
                nz, ny, nx = vol.shape
                # vol is already float16, no need for a copy
                try:
                    data_1.create_dataset("data", data=vol, chunks=(1, ny, nx), shuffle=True,
                                          compression="zstd")
                except ValueError:
                    # h5py only accepts a few filter names ("gzip", "lzf", "szip") as strings:
                    # if "zstd" is not accepted, fall back to gzip rather than losing the volume.
                    print("WARNING: zstd compression is not available - using gzip instead")
                    data_1.create_dataset("data", data=vol, chunks=(1, ny, nx), shuffle=True,
                                          compression="gzip")
            print("Finished saving %s" % filename)

            dt = timeit.default_timer() - t1
            print("Time to save hdf5 volume:  %4.1fs" % dt)

        dt = timeit.default_timer() - t0
        print("Time for 3D FBP & save volume:  %4.1fs" % dt)
    else:
        ph = None

    if save_3sino:
        if ph is None:
            ph = p.get_obj_phase_unwrapped(crop_padding=True, dtype=np.float32)[1]
        filename = prefix_result + "_3sino.npz"
        print("################################################################")
        print(" Saving 3 sinograms to: %s" % filename)
        print("################################################################")
        ny = ph.shape[1]
        np.savez_compressed(filename, sino=np.take(ph, (ny // 4, ny // 2, 3 * ny // 4), axis=1))

    dt = timeit.default_timer() - t00
    print("Elapsed time since beginning:  %4.1fs" % dt)


if __name__ == "__main__":
    # Script entry point: run the whole processing chain. main() returns None,
    # so the process exits with status 0 unless main() itself calls sys.exit().
    sys.exit(main())
