# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import timeit
import numpy as np
from scipy.ndimage import fourier_shift

from scipy import log10, sqrt, pi, exp
from scipy.fftpack import ifftshift, fftshift, fft2, ifft2, fftfreq, fftn, ifftn

from pynx import scattering

def shift(img, v):
    """
    Shift an N-dimensional (typically 2D or 3D) image, if necessary using a subpixel (FFT-based) approach.
    Values are wrapped around the array borders.

    Alternative methods:
        for real data: see scipy.ndimage.shift (uses spline interpolation, complex not supported)
        for complex data: scipy.ndimage.fourier_shift

    Args:
        img: the image to be shifted.
        v: the shift values, one per array dimension (e.g. 2 or 3 values for a 2D or 3D array).

    Returns:
        the shifted array. If all shift values are integers (including integer-valued floats such
        as 2.0), the array is rolled with np.roll and keeps its dtype. Otherwise the shift is applied
        in Fourier space and the result is complex: complex128 for float64/complex128 input,
        complex64 for any other input dtype.
    """
    assert img.ndim == len(v), "the number of shift values must match img.ndim"
    if all(x % 1 == 0 for x in v):
        # Integer shifting: np.roll needs true integers (a float such as 2.0 makes it fail
        # when building slices), so convert first. Rolls along different axes commute,
        # so a single multi-axis roll is equivalent to rolling each axis in turn.
        return np.roll(img, [int(x) for x in v], axis=tuple(range(img.ndim)))

    # Subpixel shifting: phase ramp in Fourier space, keeping double precision only
    # for double-precision inputs.
    if img.dtype == np.float64 or img.dtype == np.complex128:
        dtype = np.complex128
    else:
        dtype = np.complex64
    return ifftn(fourier_shift(fftn(img.astype(dtype)), v))


def register_translation(ref_img, img, upsampling=1, verbose = False, gpu='CPU'):
    """
    Calculate the translation shift between two images.

    Also see in scikit-image: skimage.feature.register_translation, for a more complete implementation
    
    Args:
        ref_img: the reference image
        img: the image to be translated to match ref_img
        upsampling: integer value - the pixel resolution of the computed shift will be equal to 1/upsampling

    Returns:
        the shift value along each axis of the image
    """
    assert(ref_img.shape == img.shape)
    assert(img.ndim == 2) # TODO: 3D images
    t00 = timeit.default_timer()
    ny, nx = img.shape
    ref_img_f = fftn(np.array(ref_img, dtype=np.complex64, copy=False))
    img_f = fftn(np.array(img, dtype=np.complex64, copy=False))

    cc_f = img_f * ref_img_f.conj()

    # maxshift = 8 #try to limit the range to search for the peak (only activated useful with a _very_ small range => deactivated)

    #if maxshift is None:
    #    firstfft = True
    #elif maxshift < np.sqrt(img.size)/4:
    #    firstfft = False
    #else:
    #    firstfft = True
    if True:
        # integer pixel registration
        nflop = 5 * nx * ny * (np.log2(nx) + np.log2(ny))
        t0 = timeit.default_timer()
        icc = ifftn(cc_f)

        theshift = np.array(np.unravel_index(np.argmax(abs(icc)), cc_f.shape))
        for i in range(len(theshift)):
            theshift[i] -= (theshift[i] > cc_f.shape[i]//2) * cc_f.shape[i]
        dt = timeit.default_timer() - t0
        if verbose:
            print("integer shift using FFT (%5.2fs, nGflop=%6.3f)" % (dt, nflop/1e9), theshift)
    #else:
    #    # integer pixel registration, assuming the shift is smaller than maxshift, using a DFT
    #    w = 2 * maxshift
    #    nflop = 2 * 8 * nx * ny * w ** 2
    #    t0 = timeit.default_timer()
    #    z, y, x= np.meshgrid(0, fftshift(np.arange(-.5, .5, 1 / ny)), fftshift(np.arange(-.5, .5, 1 / nx)), indexing='ij')
    #    hk = np.arange(-w / 2, w / 2, 1 / w)
    #    l, k, h= np.meshgrid(0, hk, hk, indexing='ij')

    #    cc = scattering.Fhkl_thread(h,k,l,x,y,z,occ=cc_f.real,gpu_name='Iris')[0] + 1j*scattering.Fhkl_thread(h,k,l,x,y,z,occ=cc_f.imag,
    #                                                                                                        gpu_name='Iris')[0]
    #    theshift = np.array(np.unravel_index(np.argmax(abs(cc)), cc.shape[-2:]))
    #    theshift = [k[0,int(np.round(theshift[0])),0],h[0,0,int(np.round(theshift[1]))]]
    #    dt = timeit.default_timer() - t0
    #    print("theshift using DFT (%5.2fs, nGflop=%6.3f)"%(dt, nflop/1e9), theshift)

    if gpu.lower() != 'CPU':
        language = 'opencl'
    else:
        language = 'cpu'
    if upsampling>1:
        # subpixel registration
        k1 = np.sqrt(upsampling)
        #for uw, us in [(1.5, 1 / upsampling)]: # one-step optimization
        for uw, us in [(1.5, 1.5 / k1), (3 / k1, 1 / upsampling)]:  # two-step optimization à la Guizar-Sicairos
            # uw: width of upsampled region
            # us: step size in upsampled region
            t0 = timeit.default_timer()

            z, y, x= np.meshgrid(0, fftshift(np.arange(-.5, .5, 1 / ny)), fftshift(np.arange(-.5, .5, 1 / nx)), indexing='ij')
            h = np.arange(-uw / 2, uw / 2, us) + theshift[-1]
            k = np.arange(-uw / 2, uw / 2, us) + theshift[-2]
            l, k, h= np.meshgrid(0, k, h, indexing='ij')

            cc = scattering.Fhkl_thread(h,k,l,x,y,z,occ=cc_f.real,gpu_name=gpu, language=language)[0] \
                 + 1j * scattering.Fhkl_thread(h,k,l,x,y,z,occ=cc_f.imag, gpu_name=gpu, language=language)[0]
            theshift = np.array(np.unravel_index(np.argmax(abs(cc)), cc.shape[-2:]))
            theshift = [k[0,int(np.round(theshift[0])),0],h[0,0,int(np.round(theshift[1]))]]
            dt = timeit.default_timer() - t0
            if verbose:
                print("subpixel shift using DFT (%5.2fs)"%(dt), theshift)
    if verbose:
        print("Final shift (%5.2fs)" % (timeit.default_timer()-t00), theshift)
    return theshift




if __name__ == '__main__':
    # Demo: register a noisy, subpixel-shifted copy of the 'ascent' test image.
    try:
        # SciPy >= 1.10 (downloading the image requires the optional 'pooch' package)
        from scipy.datasets import ascent
        a = ascent()
    except ImportError:
        # Older SciPy versions; scipy.misc.ascent was removed in recent releases
        from scipy.misc import ascent
        a = ascent()
    a = a[8:-8]
    a2 = shift(np.random.poisson(a), (-5.25, 2.4))
    print(register_translation(a, a2))
    print(register_translation(a, a2, upsampling=100, verbose=True, gpu='Iris'))
