# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

from packaging.version import parse as version_parse
import numpy as np
from scipy.ndimage import fourier_shift

from scipy.fft import ifftshift, fftshift, rfftn, rfftfreq, fftfreq, fftn, ifftn, irfftn
from scipy.special import erfc
# from skimage.registration import phase_cross_correlation as phase_cross_correlation_orig
# from skimage import __version__ as skimage_version
from ._phase_cross_correlation import phase_cross_correlation as phase_cross_correlation_orig


def shift(img, v):
    """
    Shift a 2D or 3D image, if necessary using a subpixel (FFT-based) approach.
    Values are wrapped around the array borders.

    Alternative methods:
        for real data: see scipy.ndimage.shift (uses spline interpolation, complex not supported)
        for complex or real data: scipy.ndimage.fourier_shift

    Args:
        img: the image (2D or 3D) to be shifted. Real-valued input is shifted using a
            real-to-complex FFT and returned as a real array; complex input is shifted
            using a complex FFT and returned as a complex array.
        v: the shift values (2 or 3 values according to array dimensions). If all values
            are integral (e.g. 2 or 2.0), an exact circular roll is used instead of an FFT.

    Returns:
        the shifted array, with the same shape as img
    """
    assert (img.ndim == len(v))
    if all((x % 1 == 0) for x in v):
        # Integer shifting: np.roll requires integer offsets, so integral floats
        # (e.g. 2.0) are converted before rolling along all axes at once.
        return np.roll(img, tuple(int(x) for x in v), axis=tuple(range(img.ndim)))
    if np.isrealobj(img):
        # Use the r2c transform. The real-space shape must be given to irfftn: it cannot
        # infer an odd last-axis length from the half-hermitian spectrum.
        return irfftn(fourier_shift(rfftn(img), v, n=img.shape[-1]), s=img.shape)
    return ifftn(fourier_shift(fftn(img), v))


def phase_cross_correlation(reference_image, moving_image, low_cutoff=None, low_width=0.03,
                            high_cutoff=None, high_width=0.03, **kwargs):
    """
    phase cross-correlation from scikit-image. This version automatically adds
    the normalization keyword for skimage>=0.19, and adds low and/or high-bandpass
    filters.

    :param reference_image: the reference image
    :param moving_image: the moving image (same shape as reference_image)
    :param low_cutoff: a 0<value<<0.5 can be given (typically it should be a few 0.01),
        an erfc filter with a cutoff at low_cutoff*N (where N is the size along each dimension)
        will be applied, after the images have been FT'd
    :param low_width: the width of the low cutoff filter, also as a percentage of the size
    :param high_cutoff: same as low_cutoff for the high frequency filter, should be close below 0.5
    :param high_width: same as low_width
    :param kwargs: keyword arguments which will be passed to the underlying
        phase_cross_correlation. 'normalization' defaults to None.
        If space='fourier' is given, the images are taken as already Fourier-transformed
        (half-hermitian along the last axis if r2c=True); the caller's arrays are not
        modified by the filtering.
        If filtering is requested on real-space images, they are FT'd here (rfftn when
        both are real, with r2c=True passed on; fftn otherwise), and space='fourier'
        is passed on.
    :return: the result of the underlying phase_cross_correlation, unchanged.
    """
    kwargs.setdefault('normalization', None)

    sh = reference_image.shape

    assert np.allclose(sh, moving_image.shape)

    if low_cutoff is not None or high_cutoff is not None:
        r2c = kwargs.get('r2c', False)
        need_ft = kwargs.get('space') != 'fourier'
        if need_ft:
            if np.isrealobj(reference_image) and np.isrealobj(moving_image):
                reference_image = rfftn(reference_image)
                moving_image = rfftn(moving_image)
                r2c = True
                kwargs['r2c'] = True
            else:
                reference_image = fftn(reference_image)
                moving_image = fftn(moving_image)
            kwargs['space'] = 'fourier'

        # Radial frequency grid (cycles/pixel) matching the layout of the FT'd arrays
        ndim = len(sh)
        r = 0
        for i, n in enumerate(sh):
            if r2c and i == ndim - 1:
                if need_ft:
                    # n is the real-space size of the last axis
                    freq = rfftfreq(n)
                else:
                    # space == 'fourier' and r2c: n is the half-hermitian length n//2+1
                    freq = rfftfreq(2 * (n - 1))
            else:
                freq = fftfreq(n)
            # Shape (n, 1, ..., 1) so that the per-axis terms broadcast to the full grid
            freq = freq.astype(np.float32).reshape((-1,) + (1,) * (ndim - i - 1))
            r = r + freq ** 2
        r = np.sqrt(r)

        # Multiply out-of-place: with space='fourier' these are the caller's own arrays,
        # which must not be modified.
        if low_cutoff:
            filt = 1 - 0.5 * erfc((r - low_cutoff) / low_width)
            reference_image = reference_image * filt
            moving_image = moving_image * filt
        if high_cutoff:
            filt = 0.5 * erfc((r - high_cutoff) / high_width)
            reference_image = reference_image * filt
            moving_image = moving_image * filt
    return phase_cross_correlation_orig(reference_image, moving_image, **kwargs)


def paraboloid_max(cc, align_max_shift=None, align_max_shift_center=None):
    """
    Get the coordinates of the maximum in a 2D array (or if cc is a stack of 2D
    arrays, the coordinates of the maximum for each array), with sub-pixel accuracy
    from a paraboloid fit of the 3x3 neighbourhood around the integer maximum.

    The 2D arrays are treated as periodic with the zero shift at index [0, 0]
    (e.g. a cross-correlation map): positions beyond the midpoint are returned as
    negative values.

    :param cc: a 2D real array, or a stack of 2D arrays (3D, 4D,...). It is not modified.
    :param align_max_shift: maximum shift along each axis. Optional. Values outside the
        allowed window (inclusive of +/-align_max_shift) are set to 0, in a copy, before
        searching the maximum.
    :param align_max_shift_center: if align_max_shift is used, this can be used to provide
        a center value (y0, x0): then the shift will be limited to y0+/-align_max_shift
        and x0+/-align_max_shift. This is useful when aligning series of images and the
        average shift is far from zero while the distribution of shifts remains small.
    :return: a tuple(cy, cx) of shifts for 2D images, or two arrays (cy, cx) for the pixel
        shifts, each having the same shape as the original arrays, minus the last two dimensions.
        A stack holding a single 2D array also returns scalars.
    """
    sh0 = cc.shape
    ny, nx = sh0[-2:]
    nz = int(np.prod(sh0[:-2]))  # np.prod(()) == 1 for a single 2D array
    # Collapse all leading (stack) dimensions so 2D, 3D, 4D... share one code path
    ccz = cc.reshape((nz, ny, nx))

    midpoints = np.array([np.fix(ny / 2), np.fix(nx / 2)])

    if align_max_shift is not None:
        if align_max_shift_center is not None:
            # NOTE(review): the midpoints are subtracted, so the center seems to be expected
            # in fftshift-ed pixel coordinates (center == midpoints <=> zero shift), while
            # the docstring reads as a shift value - confirm with callers.
            y0, x0 = np.array(align_max_shift_center) - midpoints
        else:
            y0, x0 = 0, 0
        # Allowed indices, taken modulo the array size (handles windows crossing index 0)
        keep_y = np.zeros(ny, dtype=bool)
        keep_y[np.arange(int(y0 - align_max_shift), int(y0 + align_max_shift) + 1) % ny] = True
        keep_x = np.zeros(nx, dtype=bool)
        keep_x[np.arange(int(x0 - align_max_shift), int(x0 + align_max_shift) + 1) % nx] = True
        # np.where returns a new array, so the caller's cc is left untouched
        ccz = np.where(keep_y[:, np.newaxis] & keep_x[np.newaxis, :], ccz, 0)

    # Integer position of the maximum for each 2D array
    iy, ix = np.unravel_index(np.argmax(ccz.reshape((nz, ny * nx)), axis=-1), (ny, nx))
    # Convert to signed positions in the periodic array
    iy = np.where(iy > midpoints[0], iy - ny, iy)
    ix = np.where(ix > midpoints[1], ix - nx, ix)

    # Neighbour indices (periodic)
    ixm = (ix - 1) % nx
    ixp = (ix + 1) % nx
    iym = (iy - 1) % ny
    iyp = (iy + 1) % ny

    # Paraboloid fit
    # CC(x,y)=Ax^2 + By^2 + Cxy + Dx + Ey + F
    vz = np.arange(nz)
    vf = ccz[vz, iy, ix]
    va = (ccz[vz, iy, ixp] + ccz[vz, iy, ixm]) / 2 - vf
    vb = (ccz[vz, iyp, ix] + ccz[vz, iym, ix]) / 2 - vf
    vd = (ccz[vz, iy, ixp] - ccz[vz, iy, ixm]) / 2
    ve = (ccz[vz, iyp, ix] - ccz[vz, iym, ix]) / 2
    vc = (ccz[vz, iyp, ixp] - ccz[vz, iym, ixp] - ccz[vz, iyp, ixm] + ccz[vz, iym, ixm]) / 4

    # Stationary point of the paraboloid, relative to the integer maximum
    denom = vc * vc - 4 * va * vb
    with np.errstate(divide='ignore', invalid='ignore'):
        ox = (2 * vb * vd - vc * ve) / denom
        oy = (2 * va * ve - vc * vd) / denom
    # A flat neighbourhood gives 0/0 (NaN): keep the integer position in that case.
    # Offsets are limited to +/-0.5 pixel around the integer maximum.
    ox = np.clip(np.nan_to_num(ox, nan=0.0), -0.5, 0.5)
    oy = np.clip(np.nan_to_num(oy, nan=0.0), -0.5, 0.5)

    dx = ix + ox
    dy = iy + oy
    cx = dx - nx * (dx > (nx / 2)) + nx * (dx < (-nx / 2))
    cy = dy - ny * (dy > (ny / 2)) + ny * (dy < (-ny / 2))

    if nz == 1:
        return cy[0], cx[0]
    return cy.reshape(sh0[:-2]), cx.reshape(sh0[:-2])


def phase_cross_correlation_paraboloid(reference_image, moving_image,
                                       low_cutoff=None, low_width=0.03,
                                       high_cutoff=None, high_width=0.03,
                                       align_max_shift=None,
                                       align_max_shift_center=None):
    """
    Image registration of 2D images using an FFT-based cross-correlation, providing
    sub-pixel accuracy through the paraboloid fit of the cross-correlation map
    (see paraboloid_max). No spectral normalisation is applied to the product of the
    Fourier transforms.
    If a stack of 2D images is provided (instead of just 1), the shifts are returned for
    all images.

    :param reference_image: the reference image
    :param moving_image: the moving image (same shape as reference_image)
    :param low_cutoff: a 0<value<<0.5 can be given (typically it should be a few 0.01),
        an erfc filter with a cutoff at low_cutoff*N (where N is the size along each dimension)
        will be applied, after the images have been FT'd
    :param low_width: the width of the low cutoff filter, also as a percentage of the size
    :param high_cutoff: same as low_cutoff for the high frequency filter, should be close below 0.5
    :param high_width: same as low_width
    :param align_max_shift: maximum shift along each axis. Optional.
    :param align_max_shift_center: if align_max_shift is used, this can be used to provide
        a center value (y0, x0): then the shift will be limited to y0+/-align_max_shift
        and x0+/-align_max_shift. This is useful when aligning series of images and the
        average shift is far from zero while the distribution of shifts remains small.
    :return: a tuple(cy, cx) of shifts for 2D images, or two arrays (cy, cx) for the pixel
        shifts, each having the same shape as the original arrays, minus the last two dimensions.
    """

    assert np.allclose(reference_image.shape, moving_image.shape)
    # Make sure we have a stack of 2D images
    sh0 = reference_image.shape
    if len(sh0) > 2:
        nz = np.prod(sh0[:-2])
    else:
        nz = 1
    ny, nx = sh0[-2:]
    reference_image = reference_image.reshape((nz, ny, nx))
    moving_image = moving_image.reshape((nz, ny, nx))

    # Real-to-complex FFT only if *both* arrays are real, otherwise the imaginary
    # part of the moving image would be discarded by the float32 cast.
    r2c = np.isrealobj(reference_image) and np.isrealobj(moving_image)
    if r2c:
        reference_image = rfftn(reference_image.astype(np.float32), axes=(-2, -1))
        moving_image = rfftn(moving_image.astype(np.float32), axes=(-2, -1))
    else:
        reference_image = fftn(reference_image.astype(np.complex64), axes=(-2, -1))
        moving_image = fftn(moving_image.astype(np.complex64), axes=(-2, -1))

    if low_cutoff is not None or high_cutoff is not None:
        # Radial frequency grid (cycles/pixel), half-hermitian along x for r2c
        fy = fftfreq(ny).astype(np.float32)[:, np.newaxis]
        fx = (rfftfreq(nx) if r2c else fftfreq(nx)).astype(np.float32)[np.newaxis, :]
        r = np.sqrt(fy ** 2 + fx ** 2)
        if low_cutoff:
            tmp = 1 - 0.5 * erfc((r - low_cutoff) / low_width)
            reference_image *= tmp
            moving_image *= tmp
        if high_cutoff:
            tmp = 0.5 * erfc((r - high_cutoff) / high_width)
            reference_image *= tmp
            moving_image *= tmp

    ccf = reference_image * moving_image.conj()
    if r2c:
        # s is required: without it irfftn returns an even width (nx-1 for odd nx),
        # which would make the periodic map and hence the shifts wrong.
        cc = irfftn(ccf, s=(ny, nx), axes=(-2, -1))
    else:
        cc = np.abs(ifftn(ccf, axes=(-2, -1)))

    cy, cx = paraboloid_max(cc, align_max_shift, align_max_shift_center)
    if nz == 1:
        return cy, cx
    return cy.reshape(sh0[:-2]), cx.reshape(sh0[:-2])


def register_l2_norm(reference_image, moving_image, align_max_shift,
                     low_cutoff=None, low_width=0.03,
                     high_cutoff=None, high_width=0.03,
                     align_max_shift_center=None,
                     gauss_fwhm=100, gauss_kernel=None,
                     return_gauss_kernel=False):
    """
    Image registration of 2D images by minimising the L2 norm difference between
    images, after normalisation. [EXPERIMENTAL - not performing well enough]
    This follows the following algorithm:
    - each array (reference and moving) are subtracted by a Gaussian-convolution
      of the arrays
    - each array is then normalised by the local standard deviation,
      based on the same Gaussian convolution
    - The minimum is then found for (R-M)**2 where R and M are the reference
      and (shifted) moving arrays. This is done by convoluting the moving
      array by a window function (based on the maximum shift)
      and using FFT for efficiency.

    Sub-pixel accuracy is achieved through a paraboloid fit.
    If a stack of 2D images is provided (instead of just 1), the shifts are returned for
    all images.

    :param reference_image: the reference image
    :param moving_image: the moving image
    :param align_max_shift: maximum shift along each axis. NOTE: currently ignored
        (overridden to None for debugging).
    :param low_cutoff: a 0<value<<0.5 can be given (typically it should be a few 0.01),
        an erfc filter with a cutoff at low_cutoff*N (where N is the size along each dimension)
        will be applied, after the images have been FT'd
    :param low_width: the width of the low cutoff filter, also as a percentage of the size
    :param high_cutoff: same as low_cutoff for the high frequency filter, should be close below 0.5
    :param high_width: same as low_width
    :param align_max_shift_center: if align_max_shift is used, this can be used to provide
        a center value (y0, x0): then the shift will be limited to y0+/-align_max_shift
        and x0+/-align_max_shift. This is useful when aligning series of images and the
        average shift is far from zero while the distribution of shifts remains small.
    :param gauss_fwhm: the full-width at half maximum for the Gaussian convolution
        used to remove a local average of each array, and normalise to the local
        standard deviation.
    :param gauss_kernel: the FT of the Gaussian kernel, to avoid re-computing it
        for multiple images alignments.
    :param return_gauss_kernel: if True, will also return the gaussian kernel
        in Fourier space (half-hermitian array).
    :return: a tuple(cy, cx) of shifts for 2D images, or two arrays (cy, cx) for the pixel
        shifts, each having the same shape as the original arrays, minus the last
        two dimensions. If return_gauss_kernel is True, return (cy, cx, g)
    """
    # Debug override: the align_max_shift argument is ignored, so no window is applied
    # below and the maximum is searched over the whole map.
    align_max_shift = None  # TODO debug
    # Prepare and filter arrays (same as phase cross-correlation)
    sh0 = reference_image.shape
    if len(sh0) > 2:
        nz = np.prod(sh0[:-2])
    else:
        nz = 1
    ny, nx = sh0[-2:]
    # Real-space shape of the registration axes, passed to every irfftn so that an odd
    # nx is not turned into nx-1 (which would break broadcasting with w below).
    sh = (ny, nx)
    reference_image = reference_image.reshape((nz, ny, nx))
    moving_image = moving_image.reshape((nz, ny, nx))

    # Move to Fourier space (half-hermitian along the last axis)
    reference_image = rfftn(reference_image.astype(np.float32, copy=False), axes=(-2, -1))
    moving_image = rfftn(moving_image.astype(np.float32, copy=False), axes=(-2, -1))

    if low_cutoff is not None or high_cutoff is not None:
        # Radial frequency grid matching the r2c layout: full range along y,
        # non-negative frequencies only along x.
        fy = fftfreq(ny).astype(np.float32)[:, np.newaxis]
        fx = rfftfreq(nx).astype(np.float32)[np.newaxis, :]
        r = np.sqrt(fy ** 2 + fx ** 2)
        if low_cutoff:
            tmp = 1 - 0.5 * erfc((r - low_cutoff) / low_width)
            reference_image *= tmp
            moving_image *= tmp
        if high_cutoff:
            tmp = 0.5 * erfc((r - high_cutoff) / high_width)
            reference_image *= tmp
            moving_image *= tmp

    # Subtract Gaussian-averaged
    if gauss_kernel is None:
        # NOTE(review): ix/iy are FFT frequencies (in [-0.5, 0.5)) rather than pixel
        # coordinates, and FWHM = 2*sqrt(2*ln2)*sigma implies sigma = FWHM / 2.3548, not
        # FWHM * 2.3548 - as written this kernel is nearly flat. Left unchanged since the
        # function is experimental; confirm the intended kernel before relying on it.
        iy, ix = fftfreq(ny), fftfreq(nx)
        iy, ix = np.meshgrid(iy, ix, indexing='ij')
        sig = gauss_fwhm * 2 * np.sqrt(2 * np.log(2))
        g = 1 / (sig * np.sqrt(2 * np.pi)) * np.exp(-(ix ** 2 + iy ** 2) / (2 * sig ** 2))
        g = rfftn(g, axes=(-2, -1), norm='ortho').conj()  # This could be computed directly in Fourier space...
    else:
        g = gauss_kernel

    # Remove Gaussian-convoluted image in Fourier space
    reference_image -= reference_image * g
    moving_image -= moving_image * g

    # Back to real space & normalise to rolling standard deviation (Gaussian convolution)
    # NOTE(review): this divides by the Gaussian-filtered (already high-passed) image
    # itself, not by a local standard deviation (which would require filtering the
    # squared image) - confirm intent.
    reference_image = irfftn(reference_image, s=sh, axes=(-2, -1), norm='ortho') / \
                      irfftn(reference_image * g, s=sh, axes=(-2, -1), norm='ortho')
    moving_image = irfftn(moving_image, s=sh, axes=(-2, -1), norm='ortho') / \
                   irfftn(moving_image * g, s=sh, axes=(-2, -1), norm='ortho')

    do_plot = False
    if do_plot:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(18, 5))
        plt.subplot(141)
        t = reference_image[0]
        vmin, vmax = np.percentile(t, (1, 99))
        plt.imshow(t, vmin=vmin, vmax=vmax)
        plt.subplot(142)
        t = moving_image[0]
        vmin, vmax = np.percentile(t, (1, 99))
        plt.imshow(t, vmin=vmin, vmax=vmax)

    reference_image = rfftn(reference_image, axes=(-2, -1), norm='ortho')

    # Windowed moving image
    w = np.ones((ny, nx), dtype=np.float32)
    if align_max_shift is not None:
        w[:align_max_shift] = 0
        w[-align_max_shift:] = 0
        w[:, :align_max_shift] = 0
        w[:, -align_max_shift:] = 0
        # w = fftshift(w)

        # Sum of the squared moving image inside the window, for each shift
        b2 = irfftn(rfftn(moving_image ** 2, axes=(-2, -1), norm='ortho')
                    * rfftn(w, norm='ortho').conj(), s=sh, axes=(-2, -1), norm='ortho')
        moving_image *= w
    else:
        # NOTE(review): without a window the sum of squares of the moving image does not
        # depend on the shift; subtracting the per-pixel squared map here looks
        # inconsistent with the windowed branch above - confirm.
        b2 = moving_image ** 2
    moving_image = rfftn(w * moving_image, axes=(-2, -1), norm='ortho')

    # Negative normalised L2 diff
    # (negative so we search the max like for CC, not min)
    l2_diff = 2 * irfftn(reference_image * moving_image.conj(), s=sh, axes=(-2, -1), norm='ortho')
    l2_diff -= b2
    l2_diff0 = l2_diff.copy()

    cy, cx = paraboloid_max(l2_diff, align_max_shift=align_max_shift,
                            align_max_shift_center=align_max_shift_center)

    if do_plot:
        plt.subplot(143)
        t = fftshift(l2_diff0[0])
        vmin, vmax = np.percentile(t, (1, 100))
        plt.imshow(t, vmin=vmin, vmax=vmax)

        plt.subplot(144)
        t = fftshift(b2[0])
        vmin, vmax = np.percentile(t, (1, 100))
        plt.imshow(t, vmin=vmin, vmax=vmax)
        plt.tight_layout()
    if return_gauss_kernel:
        return cy, cx, g
    return cy, cx
