# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2018-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
import logging
import ctypes
import numpy as np
from scipy.ndimage import zoom
from scipy.fft import rfftn, irfftn
import fabio
import silx
from silx.image.tomography import compute_fourier_filter
from silx.math.medianfilter import medfilt2d
from tomoscan.esrf.edfscan import EDFTomoScan

from ..utils.registration import phase_cross_correlation
from ..utils.array import rebin, pad2
from ..wavefront import Wavefront, PropagateNearField

# Module dictionary to hold shared multiprocessing.Array when using pools.
# It is filled by the pool initialiser functions below (zoom_pad_images_init,
# align_images_pool_init, load_data_pool_init, load_holoct_data_pool_init), and
# read by the corresponding *_kw worker functions, which wrap the shared buffers
# as numpy arrays using np.frombuffer(array.get_obj(), ...).
shared_arrays = {}


def remove_spikes(img, threshold=0.04):
    """
    Replace 'spikes' in an image by the corresponding median-filtered value.

    The image is median-filtered (kernel size 3, 'reflect' boundary mode), and every
    pixel whose absolute deviation from the filtered image exceeds `threshold` times
    the filtered value is replaced by that filtered value.

    :param img: the image to filter. It is modified in-place.
    :param threshold: relative threshold (e.g. 0.04 for 4% of the filtered value)
        above which a pixel is considered a spike and replaced.
    :return: nothing. The input image is modified in-place
    """
    filtered = medfilt2d(img, 3, mode='reflect')
    is_spike = abs(img - filtered) > (threshold * filtered)
    img[:] = np.where(is_spike, filtered, img)


def simulate_probe(shape, dz, pixel_size=1e-6, wavelength=0.5e-10, nb_line_v=10,
                   nb_line_h=10, nb_spot=10, defect_amplitude=1, amplitude=100):
    """
    Create a simulated probe corresponding to a HoloTomo object size, with vertical and horizontal lines and spots
    coming from optics defects, which are then propagated.

    :param shape: the 2D shape (ny, nx) of the probe
    :param dz: array of propagation distances (m) (unrelated to holo-tomo distances)
    :param pixel_size: detector pixel size (m)
    :param wavelength: the wavelength (m)
    :param nb_line_v: number of vertical lines. Width = max 5% of the probe horizontal size (at least 1 pixel).
    :param nb_line_h: number of horizontal lines. Width = max 5% of the probe vertical size (at least 1 pixel).
    :param nb_spot: number of spots. Radius = max 5% of the probe horizontal size (at least 1 pixel).
    :param defect_amplitude: the relative amplitude of the introduced optical defects (default:1).
    :param amplitude: the average amplitude of the calculated wavefront. The square amplitude shall
        correspond to the average number of incident photons per pixel.
    :return: the simulated probe, with shape (nz, nbmode=1, ny, nx), nz=len(dz)
    """
    ny, nx = shape
    nz = len(dz)
    # Exclusive upper bounds for the random extra width/radius (5% of the size).
    # Clamp to at least 1: np.random.randint(0, b) truncates b to an integer and
    # raises ValueError if it becomes 0, which happened for frames < 20 pixels.
    # For larger frames int() gives the same bound as before, so the random draws are unchanged.
    wmax_x = max(1, int(nx * 0.05))
    wmax_y = max(1, int(ny * 0.05))
    # take convenient dimensions for the wavefront
    d = np.zeros((nz, ny, nx), dtype=np.complex64)
    # Pixel coordinates used to draw the circular spots (loop-invariant)
    x, y = np.meshgrid(np.arange(0, nx), np.arange(0, ny))
    for j in range(nz):
        # Vertical lines
        for i in range(nb_line_v):
            w = 1 + np.random.randint(0, wmax_x)
            ii = np.random.randint(0, nx - w)
            t = np.random.randint(10)
            d[j, :, ii:ii + w] = t

        # Horizontal lines
        for i in range(nb_line_h):
            w = 1 + np.random.randint(0, wmax_y)
            ii = np.random.randint(0, ny - w)
            t = np.random.randint(10)
            d[j, ii:ii + w] = t

        # Circular spots, added on top of the lines
        for i in range(nb_spot):
            w = 1 + np.random.randint(0, wmax_x)
            ix = np.random.randint(w, nx - w)
            iy = np.random.randint(w, ny - w)
            r = np.sqrt((x - ix) ** 2 + (y - iy) ** 2)
            t = np.random.uniform(0, 10)
            d[j] += t * (r < w)

        # Unit-modulus (phase-only) field built from the defect map, then propagated over dz[j]
        wf = Wavefront(d=np.fft.fftshift(np.exp(1j * 1e-2 * d[j] * defect_amplitude)), pixel_size=pixel_size,
                       wavelength=wavelength)
        wf = PropagateNearField(dz=dz[j]) * wf
        d[j] = wf.get(shift=True).reshape(ny, nx)
        # Scale to the requested average amplitude
        d[j] *= amplitude / abs(d[j]).mean()
    return d.reshape((nz, 1, ny, nx))


def zoom_pad_images(x, magnification, padding, nz, pad_method='reflect_linear'):
    """
    Rescale the images from all distances to the magnification of the first one,
    and crop and/or pad them so they all have the same (padded) frame size.

    :param x: the images, indexed as x[iz] (2D images), with at least nz images
    :param magnification: sequence of magnifications, one per distance. Image iz is
        zoomed (linear interpolation) by magnification[0] / magnification[iz].
    :param padding: (pady, padx) number of padding pixels added on each side
    :param nz: number of images (distances) to process
    :param pad_method: the padding mode passed to pad2 (default 'reflect_linear')
    :return: a float32 array of shape (nz, ny + 2 * pady, nx + 2 * padx), with (ny, nx)
        the shape of x[0]. Along each axis, a zoomed image larger than the target is
        center-cropped, and a smaller one is padded symmetrically.
    """
    pady0, padx0 = padding
    ny0, nx0 = x[0].shape[0] + 2 * pady0, x[0].shape[1] + 2 * padx0
    xz = np.empty((nz, ny0, nx0), dtype=np.float32)
    for iz in range(nz):
        # Zoom to the magnification of the first distance
        mg = magnification[0] / magnification[iz]
        if np.isclose(mg, 1):
            d = x[iz]
        else:
            d = zoom(x[iz], zoom=mg, order=1)
        ny, nx = d.shape
        # easier with even size differences: symmetric crop/padding
        if (nx0 - nx) % 2:
            d = d[:, :-1]
            nx -= 1
        if (ny0 - ny) % 2:
            d = d[:-1]
            ny -= 1

        # Centered crop along each axis where the image is at least as large as the target
        if ny >= ny0:
            d = d[ny // 2 - ny0 // 2:ny // 2 - ny0 // 2 + ny0]
        if nx >= nx0:
            d = d[:, nx // 2 - nx0 // 2:nx // 2 - nx0 // 2 + nx0]

        # Symmetric padding along each axis where the image is smaller. The padded
        # result of pad2 has the full (ny0, nx0) frame size in all cases (previously
        # the "crop y / pad x" case assigned it into a slice of width nx, which failed).
        pady = max(ny0 - ny, 0) // 2
        padx = max(nx0 - nx, 0) // 2
        if pady or padx:
            xz[iz] = pad2(d, (pady, padx), mode=pad_method, mask2neg=True)
        else:
            xz[iz] = d
    return xz


def zoom_pad_images_init(iobs_array):
    """
    Pool initialiser: register the shared iobs array in the module-level
    `shared_arrays` dictionary, where zoom_pad_images_kw reads it.

    :param iobs_array: the shared array (accessed through its get_obj() method)
    :return: nothing
    """
    shared_arrays.update(iobs_array=iobs_array)


def zoom_pad_images_kw(kwargs):
    """
    Pool worker: zoom and pad, in-place in the shared array, the images of all
    distances for one projection (see zoom_pad_images).

    :param kwargs: dictionary with keys 'i' (projection index), 'nx', 'ny' (non-padded
        size), 'nz', 'nb_proj_sharray' (number of projections in the shared array),
        'padding' (pady, padx), 'pad_method' and 'magnification'.
    :return: nothing. The shared 'iobs_array' is updated in-place.
    """
    ny, nx, nz = kwargs['ny'], kwargs['nx'], kwargs['nz']
    pady, padx = kwargs['padding']
    method = kwargs['pad_method']
    shape = (kwargs['nb_proj_sharray'], nz, ny + 2 * pady, nx + 2 * padx)
    buffer = shared_arrays['iobs_array'].get_obj()
    proj = kwargs['i']
    iobs = np.frombuffer(buffer, dtype=ctypes.c_float).reshape(shape)
    # The input images are read from the [:ny, :nx] corner of each padded frame,
    # and the result overwrites the whole padded frame
    raw = iobs[proj, :, :ny, :nx]
    iobs[proj] = zoom_pad_images(raw, kwargs['magnification'], padding=(pady, padx),
                                 nz=nz, pad_method=method)


def align_images(x, x0=None, nz=4, upsample_factor=1, low_cutoff=None,
                 low_width=0.03, high_cutoff=None, high_width=0.03):
    """
    Measure the shifts between images recorded at successive distances.

    Each distance iz is registered against the previous one (iz-1), and the
    shifts are accumulated so they are expressed relative to the first distance.

    :param x: either a 3D array indexed as x[iz], in which case each image is
        divided by x0[iz], or a 4D array indexed as x[iproj, iz], in which case the
        differential image x[1, iz] - mean(x[:, iz]) is used, after subtracting
        its 0.001 percentile.
    :param x0: array indexed as x0[iz], used to normalise x when x is 3D
        (required in that case).
    :param nz: number of distances to process
    :param upsample_factor: upsampling factor, forwarded to phase_cross_correlation
    :param low_cutoff, low_width, high_cutoff, high_width: filtering parameters,
        forwarded to phase_cross_correlation
    :return: a tuple (dx, dy) of two lists with nz cumulated shifts (in pixels),
        the first value of each being 0
    """
    def fourier_image(iz):
        # Real-to-complex FT of the image used to register distance iz
        if x.ndim == 3:
            img = x[iz] / x0[iz]
        else:
            # Differential alignment: subtract the average over the projections
            img = x[1, iz] - x[:, iz].mean(axis=0)
            img -= np.percentile(img, 0.001)
        return rfftn(img)

    dx, dy = [0], [0]
    previous = fourier_image(0)
    for iz in range(1, nz):
        current = fourier_image(iz)
        # normalization=None gives better results
        shift = phase_cross_correlation(previous, current, upsample_factor=upsample_factor,
                                        reference_mask=None, moving_mask=None,
                                        return_error=False, normalization=None,
                                        low_cutoff=low_cutoff, low_width=low_width,
                                        high_cutoff=high_cutoff, high_width=high_width,
                                        space='fourier', r2c=True)
        # The shift is relative to the previous distance: accumulate it
        dy.append(shift[0] + dy[-1])
        dx.append(shift[1] + dx[-1])
        previous = current
    return dx, dy


def align_images_pool_init(iobs_array):
    """
    Pool initialiser for align_images_kw: register the shared iobs array in the
    module-level `shared_arrays` dictionary.

    :param iobs_array: the shared array (accessed through its get_obj() method)
    :return: nothing
    """
    shared_arrays.update(iobs_array=iobs_array)


def align_images_kw(kwargs):
    """
    Pool worker: measure the shifts between distances for one projection, using
    differential alignment with the neighbouring projections (see align_images).

    :param kwargs: dictionary with keys 'i' (projection index), 'padding' (pady, padx),
        'nproj', 'nz', 'ny', 'nx' (non-padded frame size). All keys except 'i',
        'padding', 'ny', 'nx' and 'nproj' (so including 'nz') are forwarded as keyword
        arguments to align_images(). The dictionary itself is not modified.
    :return: the (dx, dy) shifts returned by align_images()
    """
    i = kwargs['i']
    pady, padx = kwargs['padding']
    nproj, nz, ny, nx = kwargs['nproj'], kwargs['nz'], kwargs['ny'], kwargs['nx']
    iobs_array = shared_arrays['iobs_array']
    iobs_shape = (nproj, nz, ny + 2 * pady, nx + 2 * padx)
    iobs = np.frombuffer(iobs_array.get_obj(), dtype=ctypes.c_float).reshape(iobs_shape)
    # Remove the padded margins (a zero padding keeps the full range)
    iobs = iobs[..., pady:pady + ny, padx:padx + nx]
    # align_images() aligns x[1], using the mean over x[:] as differential reference
    if i >= 1:
        # Projections i-1, i, i+1 (only i-1 and i for the last projection): x[1] is projection i
        x = iobs[i - 1:i + 2]
    else:
        # First projection: use projections 1 and 0 in reversed order, so that x[1]
        # is projection 0 (the mean over the two projections is unchanged).
        # The previous iobs[0:2] slice wrongly aligned projection 1 instead of 0.
        x = iobs[1::-1]
    # Forward the remaining parameters without mutating the caller's dictionary
    align_kwargs = {k: v for k, v in kwargs.items() if k not in ('i', 'padding', 'ny', 'nx', 'nproj')}
    return align_images(x, **align_kwargs)


def load_data(i, dark, planes, img_name, binning=None):
    """
    Load the images of a single projection for all the requested planes (distances).

    :param i: projection index, inserted in the file name
    :param dark: array indexed as dark[iz], subtracted from each image, or None
    :param planes: sequence of plane numbers
    :param img_name: file name pattern, formatted as img_name % (plane, plane, i)
    :param binning: if not None and > 1, each image is rebinned by this factor
    :return: a float32 array of shape (len(planes), ny, nx), or None if planes is empty
    """
    nz = len(planes)
    stack = None
    for iz, plane in enumerate(planes):
        frame = fabio.open(img_name % (plane, plane, i)).data
        if binning is not None and binning > 1:
            frame = rebin(frame, binning)
        if stack is None:
            # Allocate the output once the (possibly binned) frame size is known
            ny, nx = frame.shape
            stack = np.empty((nz, ny, nx), dtype=np.float32)
        stack[iz] = frame
        if dark is not None:
            stack[iz] -= dark[iz]
    return stack


def load_data_pool_init(iobs_array, dark_array):
    """
    Pool initialiser for load_data_kw: register the shared iobs and dark arrays
    in the module-level `shared_arrays` dictionary.

    :param iobs_array: the shared array receiving the loaded images
    :param dark_array: the shared array holding the dark images
    :return: nothing
    """
    shared_arrays.update(iobs_array=iobs_array, dark_array=dark_array)


def load_data_kw(kwargs):
    """
    Pool worker: load the images of one projection for all planes, subtract the dark,
    optionally remove spikes, and copy them to the shared iobs array.

    :param kwargs: dictionary with keys 'i' (index in the shared array), 'idx' (file
        index), 'nproj', 'ny', 'nx', 'planes', 'padding' (pady, padx),
        'spikes_threshold' (spikes are removed if > 0), 'img_name' and 'binning'
    :return: nothing. The images are written to iobs[i, :, :ny, :nx] in the shared array
    """
    planes = kwargs['planes']
    nz = len(planes)
    nproj, ny, nx = kwargs['nproj'], kwargs['ny'], kwargs['nx']
    pady, padx = kwargs['padding']
    i, idx = kwargs['i'], kwargs['idx']
    threshold = kwargs['spikes_threshold']

    def as_array(name, shape):
        # float32 numpy view (no copy) of one of the shared arrays
        return np.frombuffer(shared_arrays[name].get_obj(), dtype=ctypes.c_float).reshape(shape)

    iobs = as_array('iobs_array', (nproj, nz, ny + 2 * pady, nx + 2 * padx))
    dark = as_array('dark_array', (nz, ny, nx))
    stack = load_data(idx, dark, planes, kwargs['img_name'], kwargs['binning'])
    if threshold > 0:
        for iz in range(nz):
            remove_spikes(stack[iz], threshold)
    iobs[i, :, :ny, :nx] = stack


def load_holoct_data(i, img_name, binning=None):
    """
    Load a single image.

    :param i: the image index, used as img_name % i
    :param img_name: the file name pattern
    :param binning: if not None and > 1, the image is rebinned by this factor
    :return: the (possibly rebinned) image data
    """
    data = fabio.open(img_name % i).data
    if binning is not None and binning > 1:
        return rebin(data, binning)
    return data


def load_holoct_data_pool_init(ph_array):
    """
    Pool initialiser for load_holoct_data_kw: register the shared phase array in
    the module-level `shared_arrays` dictionary.

    :param ph_array: the shared array receiving the loaded images
    :return: nothing
    """
    shared_arrays.update(ph_array=ph_array)


def load_holoct_data_kw(kwargs):
    """
    Pool worker: load one image and store it in the shared 'ph_array'.

    :param kwargs: dictionary with keys 'i' (index in the shared array), 'idx' (file
        index), 'nproj', 'ny', 'nx', 'ph_array_dtype', 'img_name' and 'binning'
    :return: nothing. ph[i] is updated in the shared array.
    """
    ny, nx = kwargs['ny'], kwargs['nx']
    shape = (kwargs['nproj'], ny, nx)
    buffer = shared_arrays['ph_array'].get_obj()
    dtype = kwargs['ph_array_dtype']
    i, idx = kwargs['i'], kwargs['idx']
    ph = np.frombuffer(buffer, dtype=dtype).reshape(shape)
    img = load_holoct_data(idx, kwargs['img_name'], kwargs['binning'])
    # The data shape may be larger than the original frame size, depending
    # on mcrop/ncrop values: keep the centered (ny, nx) region
    if img.shape[0] > ny:
        top = img.shape[0] // 2 - ny // 2
        img = img[top:top + ny]
    if img.shape[1] > nx:
        left = img.shape[1] // 2 - nx // 2
        img = img[:, left:left + nx]
    ph[i] = img


def save_phase_edf(idx, ph, prefix_result):
    """
    Save an array as a float32 EDF file named '<prefix_result>_<idx, 4 digits>.edf'.

    :param idx: integer index used in the file name
    :param ph: the array to save (converted to float32)
    :param prefix_result: file name prefix (including the path)
    :return: nothing
    """
    image = fabio.edfimage.EdfImage(data=ph.astype(np.float32))
    filename = "%s_%04d.edf" % (prefix_result, idx)
    image.write(filename)


def save_phase_edf_kw(kwargs):
    """
    Wrapper for pool map calls: call save_phase_edf() with the kwargs dictionary
    ('idx', 'ph', 'prefix_result') expanded as keyword arguments.

    :param kwargs: the keyword arguments for save_phase_edf
    :return: whatever save_phase_edf returns
    """
    outcome = save_phase_edf(**kwargs)
    return outcome


def get_params_id16b(scan_path, planes=(1,), sx0=None, sx=None, cx=None, pixel_size_detector=None, logger=None):
    """
    Get the pixel size(s) and magnified propagation distances from an ID16B dataset.
    By default, this will use the 'info' file parsed by EDFTomoScan, but if
    sx0 is given, the motor positions will be used instead. If sx or cx are given,
    the given values supersede those from the EDF files.

    :param scan_path: the path to the scan, either "path/to/sample_name_1_" or
        "path/to/sample_name_%d_" for multi-distance datasets.
    :param planes: the list of planes to consider, usually [1] or [1, 2, 3, 4]
    :param sx0: focus offset in mm (optional)
    :param sx: sample motor position in mm (optional). If len(planes)>1, it should be
        iterable (list, array) with one value per plane, in the same order as `planes`
        (a dict keyed by plane number is also accepted). For a single plane, a scalar
        or a one-element sequence can be given.
    :param cx: detector motor position in mm (optional)
    :param pixel_size_detector: the detector pixel size in m, without the sample
        magnification (optional, only used when sx0 is given). If None, it is read
        from the 'optic_used' field of the EDF header.
    :param logger: a logger object
    :return: a dictionary with keys 'wavelength' (m), 'nx', 'ny' (frame size),
        'pixel_size' and 'distance', the latter two being arrays, after taking into
        account magnification. If planes is empty, only 'pixel_size' and 'distance'
        (empty arrays) are present.
    """
    if logger is None:
        logger = logging.getLogger()
    logger.info("Getting ID16B experiment parameters")
    px = []
    d = []
    result = {}
    for ip, iz in enumerate(planes):
        if "%d" not in scan_path:
            scans = scan_path[:-2] + "%d_" % iz
        else:
            scans = scan_path % iz
        s = EDFTomoScan(scans, n_frames=1)
        if iz == planes[0]:
            result['wavelength'] = 12.3984e-10 / EDFTomoScan.retrieve_information(scans, dataset_basename=None,
                                                                                  key="Energy", ref_file=None,
                                                                                  type_=float)
            result['nx'] = s.dim_1
            result['ny'] = s.dim_2
            w = result['wavelength']
            logger.info("Energy: %6.2fkeV   Wavelength: %8.5fnm" % (12.3984e-10 / w, w * 1e9))
            logger.info("Frame size: (%d, %d)" % (result['ny'], result['nx']))

        logger.info("Pixel and distances for dataset: %s" % scans)
        if sx0 is not None:
            # Compute the distances from motor positions.
            # silx.io is not necessarily loaded by a bare 'import silx'
            from silx.io import open as silx_open
            edfname = "%s/%s0050.edf" % (scans, s.get_dataset_basename())
            # Read all the needed header values, then close the file
            with silx_open(edfname) as edf:
                if cx is None:
                    cx1 = edf["scan_0/instrument/positioners/cx"][0]
                else:
                    cx1 = cx
                if sx is None:
                    sx1 = edf["scan_0/instrument/positioners/sx"][0]
                elif isinstance(sx, dict):
                    sx1 = sx[iz]
                elif len(planes) == 1 and np.ndim(sx) == 0:
                    sx1 = float(sx)
                else:
                    # One value per plane: index by position in `planes`, not by plane
                    # number (sx[iz] picked the wrong value, or failed for the last plane)
                    sx1 = sx[ip]
                # Magnified pixel detector
                try:
                    px_magnified_edf = edf["scan_0/instrument/detector_0/others/pixel_size"][0] / 1e6
                except KeyError:
                    logger.info("No pixel_size field available from edf files")
                    px_magnified_edf = None
                if pixel_size_detector is None:
                    px1 = edf["scan_0/instrument/detector_0/others/optic_used"][0] / 1e6
                else:
                    px1 = pixel_size_detector

            z1 = (sx1 - sx0) / 1000  # Focus-sample distance (m)
            # Sample-detector distance (m). NOTE(review): +sx0 to match .info's Distance ??
            z2 = d0 = (cx1 - sx1) / 1000
            magnification = (z1 + z2) / z1
            px.append(px1 / magnification)

            if px_magnified_edf is not None:
                if not np.isclose(px[-1], px_magnified_edf, rtol=5e-4, atol=0):
                    logger.warning("Magnified pixel size (%6.2fnm) calculated from sx, cx, sx0 "
                                   "differs from the one in edf file (%6.2fnm)" % (
                                       px[-1] * 1e9, px_magnified_edf * 1e9))
        else:
            d0 = s.distance  # real distance in meters, from the .info Distance
            px.append(s.pixel_size)  # Magnified pixel size
            # real pixel size (taking into account on-detector optics)
            px1 = EDFTomoScan.retrieve_information(scans, dataset_basename=None, key="Optic_used",
                                                   ref_file=None, type_=float) * 1e-6
        m = px1 / px[-1]  # Magnification
        logger.info("  Pixel size: %8.2fnm (magnified) %8.2fnm (detector), "
                    "Distance: %8.6fm (magnified) %8.6fm (real)" % (px[-1] * 1e9, px1 * 1e9, d0 / m, d0))
        d.append(d0 / m)
    result['pixel_size'] = np.array(px)
    result['distance'] = np.array(d)
    return result


def get_params_id16a(scan_path, planes=(1,), logger=None, **kwargs):
    """
    Get the pixel size(s) and magnified propagation distances from an ID16A dataset.
    This will use the 'info' file parsed by EDFTomoScan.

    :param scan_path: the path to the scan, either "path/to/sample_name_1_" or
        "path/to/sample_name_%d_" for multi-distance datasets.
    :param planes: the list of planes to consider, usually [1] or [1, 2, 3, 4]
    :param logger: a logger object
    :param kwargs: ignored (accepted so the call signature can match get_params_id16b)
    :return: a dictionary with keys 'wavelength' (m), 'nx', 'ny' (frame size),
        'pixel_size' and 'distance', the latter two being arrays, after taking into
        account magnification.
    """
    log = logging.getLogger() if logger is None else logger
    log.info("Getting ID16A experiment parameters")
    pixel_sizes = []
    distances = []
    result = {}
    for plane in planes:
        scans = scan_path % plane if "%d" in scan_path else scan_path[:-2] + "%d_" % plane
        scan = EDFTomoScan(scans, n_frames=1)
        if plane == planes[0]:
            energy = EDFTomoScan.retrieve_information(scans, dataset_basename=None, key="Energy",
                                                      ref_file=None, type_=float)
            result['wavelength'] = 12.3984e-10 / energy
            result['nx'], result['ny'] = scan.dim_1, scan.dim_2
            wavelength = result['wavelength']
            log.info("Energy: %6.2fkeV   Wavelength: %8.5fnm" % (12.3984e-10 / wavelength, wavelength * 1e9))
            log.info("Frame size: (%d, %d)" % (result['ny'], result['nx']))

        log.info("Pixel and distances for dataset: %s" % scans)
        # Convention differs from ID16B: here the distance read is already the magnified one
        dist = scan.distance
        # EDFTomoScan assumes the distance is given in mm: a value below 1 mm suggests it
        # was already in meters, so undo the conversion
        if dist < 0.001:
            log.info("ID16A distance<.001 from EDFTomoScan.retrieve_information, apparently "
                     "it was already in meters so x1000 and hope for the best")
            dist *= 1000
        magnified_px = scan.pixel_size
        pixel_sizes.append(magnified_px)
        # Detector pixel size (taking into account on-detector optics), micron -> m
        detector_px = EDFTomoScan.retrieve_information(scans, dataset_basename=None, key="Optic_used",
                                                       ref_file=None, type_=float) * 1e-6
        magnification = detector_px / magnified_px
        log.info("  Pixel size: %8.2fnm (magnified) %8.2fnm (detector), "
                 "Distance: %8.6fm (magnified) %8.6fm (real)" % (magnified_px * 1e9, detector_px * 1e9,
                                                                 dist, dist * magnification))
        distances.append(dist)
    result['pixel_size'] = np.array(pixel_sizes)
    result['distance'] = np.array(distances)
    return result


def sino_filter_pad(d, filter_name, padding):
    """
    Perform a 1D filtering along the x-axis, prior to filtered back-projection.
    This is only useful for padded arrays, in order to use the array values
    in the padded areas before cropping for the back-projection.
    In order to have the same normalisation as in nabu, the resulting array
    will need to be multiplied by pi/nb_proj.

    :param d: the array to filter, can have any dimensions, filtering will be
        done along the last (fastest) dimension
    :param filter_name: Available filters (from silx): Ram-Lak,
        Shepp-Logan, Cosine, Hamming, Hann, Tukey, Lanczos. Case-insensitive.
    :param padding: the number of padding pixels on each side of the last dimension
    :return: the filtered array, with the original array shape
    :raises ValueError: if the non-padded width nx = d.shape[-1] - 2 * padding is < 1,
        or if the padded width d.shape[-1] exceeds 2 * nx
    """
    nxd = d.shape[-1]
    # nx is the non-padded range. Filtering is done by padding to 2*nx
    nx = nxd - 2 * padding
    if nx < 1 or nxd > 2 * nx:
        raise ValueError("sino_filter_pad: invalid padding=%s for an array of width %d: the non-padded "
                         "width must be >= 1 and the padded width must not exceed twice the "
                         "non-padded width" % (padding, nxd))
    filter_f = compute_fourier_filter(2 * nx, filter_name=filter_name)[:nx + 1]  # R2C filter
    # filter_f *= np.pi / nproj  # normalisation as in nabu. Needs to be done elsewhere
    # Pad d to exactly 2*nx along the last axis. When the total padding is odd (odd nx),
    # the extra pixel goes on the right: a symmetric padding would give a width of
    # 2*nx-1, whose R2C transform has nx (not nx+1) coefficients, mismatching filter_f.
    pad_total = 2 * nx - nxd
    pad_left = pad_total // 2
    if pad_total:
        vpad = [(0, 0)] * (d.ndim - 1) + [(pad_left, pad_total - pad_left)]
        d1 = np.pad(d, vpad, mode='edge')  # edge or reflect ?
    else:
        d1 = d
    filtered = irfftn(rfftn(d1, axes=(-1,)) * filter_f, s=(2 * nx,), axes=(-1,))
    # Crop back to the original (padded) width
    return filtered[..., pad_left:pad_left + nxd]
