# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2018-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
"""
This module includes functions to:
- compare and match solutions (objects) from CDI optimisations
- provide figures-of-merit for optimised objects
- combine several solutions
"""

import numpy as np
from scipy.optimize import minimize
from scipy.ndimage import fourier_shift
from skimage.feature import register_translation


def match_shape(arrays, method='min'):
    """
    Match the shape of two or more arrays by cropping the borders, or zero-padding if necessary.
    Cropping and padding are centred along each axis.
    :param arrays: a list or tuple of arrays, all with the same number of dimensions (typically 2D or 3D,
        but any number of dimensions is supported)
    :param method: either 'min' (use the largest size along each dimension which is smaller than all array sizes)
        or 'median': use the median value for the size along each dimension. The 'median' option is better when
        matching more than 2 arrays, when one may be an outlier with incorrect dimensions.
        Any other value is treated as 'min'.
    :raises ValueError: if the arrays do not all have the same number of dimensions
    :return: a list of the arrays cropped or zero-padded to the same shape. The data type of each individual array
        is preserved
    """
    ndim = arrays[0].ndim
    if any(d.ndim != ndim for d in arrays):
        raise ValueError("match_shape: all arrays must have the same number of dimensions")

    if method == 'median':
        # int() truncates a half-integer median (even number of arrays)
        shape = tuple(int(np.median([d.shape[i] for d in arrays])) for i in range(ndim))
    else:
        shape = tuple(min(d.shape[i] for d in arrays) for i in range(ndim))

    v = []
    for d in arrays:
        # Centred crop along every axis which is at least as large as the target size
        crop = tuple(slice(n1 // 2 - n // 2, n1 // 2 - n // 2 + n) if n <= n1 else slice(None)
                     for n, n1 in zip(shape, d.shape))
        d = d[crop]
        # Centred zero-padding along the axes which are still smaller than the target size
        # (for axes already at the target size this slice covers the whole axis)
        pad = tuple(slice(n // 2 - n1 // 2, n // 2 - n1 // 2 + n1) for n, n1 in zip(shape, d.shape))
        tmp = np.zeros(shape, dtype=d.dtype)
        tmp[pad] = d
        v.append(tmp)

    return v


def flipn(d: np.ndarray, flip_axes):
    """
    Flip an array along any combination of its axes
    :param d: the array to flip
    :param flip_axes: an indexable sequence with (at least) d.ndim entries. Entry i is truthy if the array
                      must be flipped along axis i.
    :return: the flipped array. As with numpy.flip this is a view of d, not a copy; d itself is returned
             when no axis is flipped.
    """
    axes = tuple(ax for ax in range(d.ndim) if flip_axes[ax])
    if not axes:
        return d
    return np.flip(d, axis=axes)


def corr_phase(pars, d: np.ndarray):
    """
    Apply a linear phase shift to a complex array of any dimensionality (typically 2D or 3D)
    :param pars: the shift parameters for the phase. The correction applied is a multiplication by exp(-1j * dphi),
        with dphi = pars[0] + pars[1] * x0 + pars[2] * x1 {+ pars[3] * x2 ...}, where x0, x1, x2... are
        coordinates arrays covering [0 ; 1[ along each axis. At least d.ndim + 1 values are required;
        extra values are ignored.
    :param d: the array for which the phase will be shifted
    :return: a new array after phase-shifting, cast back to d.dtype (for a real-valued d this cast
        discards the imaginary part)
    """
    # Sparse (broadcastable) coordinate grids: one small array per axis instead of full-size meshgrids.
    # The terms are summed in the same order, so the result matches a dense computation exactly.
    coords = np.meshgrid(*(np.arange(n) / n for n in d.shape), indexing='ij', sparse=True)
    dphi = pars[0]
    for i, x in enumerate(coords):
        dphi = dphi + pars[i + 1] * x
    return (d * np.exp(-1j * dphi)).astype(d.dtype)


def fit_phase2(pars, phi1: np.ndarray, phi2: np.ndarray, a1: np.ndarray, a2: np.ndarray, xyz: tuple):
    """
    Fit function to match the phase between two arrays, with linear phase shift parameters.
    :param pars: a list of ndim+1 linear phase shift parameters (one constant plus one along each axis)
    :param phi1: the phase (angle) of the first array, assumed to be within [-pi, pi]
    :param phi2: the phase (angle) of the second array, assumed to be within [-pi, pi]
    :param a1: the amplitude (or weight) of the first array. Should be >=0
    :param a2:  the amplitude (or weight) of the second array. Should be >=0
    :param xyz: a tuple of ndim arrays with (or broadcastable to) the same shape as the arrays, giving a set
        of [0;1[ coordinates along each axis.
    :return: a floating point figure of merit, ((a1 + a2) * delta_phi ** 2).sum(), where delta_phi is the
        phase difference wrapped into [-pi, pi]
    """
    # Phase of the second array after removing the constant + linear phase model
    dphi = phi2 - pars[0]
    for i, x in enumerate(xyz):
        dphi -= pars[i + 1] * x
    # Wrap the difference into [-pi, pi[ before squaring. min(d, 2*pi - d) is only a valid wrap for
    # |d| <= 3*pi, which large offset/ramp parameters easily exceed during the optimisation.
    dphi = np.abs(np.mod(phi1 - dphi + np.pi, 2 * np.pi) - np.pi)
    return ((a1 + a2) * dphi ** 2).sum()


def r_match(d1: np.ndarray, d2: np.ndarray, percent=99, threshold=0.05):
    """
    Compute an unweighted R-factor between two arrays (complex or real)
    :param d1: the first array
    :param d2: the second array
    :param percent: a percent value between 0 and 100. Only the data points whose amplitude is above this
        percentile (multiplied by the threshold value) in at least one of the arrays are used.
    :param threshold: the multiplier applied to the percentile of each array's amplitudes to select the
        data points included in the R factor calculation
    :return: the R-factor calculated as sqrt(sum(abs(d1-d2)**2)/(0.5*sum(abs(d1)**2+abs(d2)**2)))
    """
    amp1 = abs(d1)
    amp2 = abs(d2)
    # Keep the points which are significant in at least one of the two arrays
    keep = (amp1 > threshold * np.percentile(amp1, percent)) | (amp2 > threshold * np.percentile(amp2, percent))
    diff2 = (abs(d1[keep] - d2[keep]) ** 2).sum()
    norm2 = (amp1[keep] ** 2 + amp2[keep] ** 2).sum()
    return np.sqrt(2 * diff2 / norm2)


def _best_flip(a1: np.ndarray, a2: np.ndarray, verbose=False):
    """
    Find the combination of axis flips of a2 which best matches a1, according to the error returned by
    register_translation.
    :param a1: the reference array (amplitudes)
    :param a2: the array to be flipped (amplitudes), with the same shape as a1
    :param verbose: if True, print the best flip combination and its error
    :return: a list of a1.ndim 0/1 values, suitable for flipn()
    """
    from itertools import product

    ndim = a1.ndim
    # Used if no registration error is a finite number (e.g. all NaN): keep the current orientation
    flip_min = [0] * ndim
    errmin = np.inf
    # TODO: do we really need to try all 2**ndim orientations ?
    # combo[::-1] makes axis 0 vary fastest and the last axis slowest; with the strict '<' below,
    # the first combination reaching the minimum error wins ties.
    for combo in product((0, 1), repeat=ndim):
        flips = list(combo[::-1])
        err = register_translation(a1, flipn(a2, flips))[1]
        if err < errmin:
            flip_min, errmin = flips, err
    if verbose:
        if ndim <= 3:
            names = ['flipz', 'flipy', 'flipx'][3 - ndim:]
        else:
            names = ['flip%d' % i for i in range(ndim)]
        print('Best orientation match: ' + ' '.join('%s=%d' % (k, f) for k, f in zip(names, flip_min)) +
              ' error=%6.3f' % errmin)
    return flip_min


def _match_phase_ramp(d1: np.ndarray, d2: np.ndarray, verbose=False):
    """
    Fit a constant phase shift plus a linear phase ramp so that the phase of d2 (or of its conjugate)
    best matches the phase of d1, using fit_phase2, and apply it with corr_phase.
    :param d1: the reference array
    :param d2: the array whose phase is corrected, with the same shape as d1
    :param verbose: if True, print the two optimisation results
    :return: d2 or its conjugate (whichever fits better) after phase correction
    """
    phi1, a1 = np.angle(d1), abs(d1)
    phi2, a2 = np.angle(d2), abs(d2)
    # [0;1[ coordinates along each axis. Sparse grids broadcast inside fit_phase2 and save memory.
    xyz = np.meshgrid(*(np.arange(n) / n for n in d1.shape), indexing='ij', sparse=True)

    p0 = minimize(fit_phase2, np.zeros(1 + d1.ndim), args=(phi1, phi2, a1, a2, xyz), method='powell')
    if verbose:
        print(p0)
    # The phase of conj(d2) is -phase(d2) (modulo 2*pi)
    p1 = minimize(fit_phase2, np.zeros(1 + d1.ndim), args=(phi1, -phi2, a1, a2, xyz), method='powell')
    if verbose:
        print(p1)
    if p0['fun'] < p1['fun']:
        return corr_phase(p0['x'], d2)
    return corr_phase(p1['x'], d2.conj())


def match2(d1: np.ndarray, d2: np.ndarray, match_phase=True, verbose=False, upsample_factor=1):
    """
    Match array d2 against array d1, by flipping it along one or several axis and/or calculating its conjugate,
    translation registration, and matching amplitudes. Both arrays should have the same number of
    dimensions (typically 2D or 3D).
    NOTE(review): skimage.feature.register_translation is deprecated in recent scikit-image versions in
    favour of skimage.registration.phase_cross_correlation - confirm the installed version still provides it.
    :param d1: the first array
    :param d2: the second array
    :param match_phase: if True (the default), the phase ramps and shift will be matched as well. This optimisation
        is slower.
    :param verbose: print some info while matching arrays
    :param upsample_factor: upsampling factor for subpixel registration (default: 1 - no subpixel). Good values
                            are 10 or 20.
    :return: (d1, d2, r) the two arrays after matching their shape, orientation, translation and (optionally) phase.
        The first array is only cropped if necessary, but otherwise unmodified.
        r is r_match(d1, d2, percent=99, threshold=0.05) if match_phase is True, otherwise the same R-factor
        computed on the amplitudes abs(d1) and abs(d2).
    """
    # Match the two arrays shape
    d1, d2 = match_shape((d1, d2))

    # Match the orientation and translation based on the amplitudes
    a1 = abs(d1)
    d2 = flipn(d2, _best_flip(a1, abs(d2), verbose=verbose))

    s, err, dphi = register_translation(a1, abs(d2), upsample_factor=upsample_factor)
    # Subpixel translation via the Fourier shift theorem (np.roll could be used for pixel-only registration)
    d2 = np.fft.ifftn(fourier_shift(np.fft.fftn(d2), s)) * np.exp(1j * dphi)
    if verbose:
        print(register_translation(a1, abs(d2)))

    if match_phase:
        # Match phase shift and ramp, testing both conjugates
        d2 = _match_phase_ramp(d1, d2, verbose=verbose)

    # Match amplitudes: the scale factor S minimising (abs(a1 - S * a2) ** 2).sum()
    a2 = abs(d2)
    norm2 = (a2 ** 2).sum()
    if norm2 > 0:  # an all-zero d2 cannot be scaled (the division would give NaN)
        d2 *= (a1 * a2).sum() / norm2

    if match_phase:
        r = r_match(d1, d2, percent=99, threshold=0.05)
    else:
        r = r_match(a1, abs(d2), percent=99, threshold=0.05)

    if verbose:
        print("Final R_match between arrays: R=%6.3f%%" % (r * 100))

    return d1, d2, r
