# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import numpy as np
from scipy.ndimage.measurements import center_of_mass


def rebin(a, rebin_f, normalize=False):
    """
    Rebin an N-dimensional array (N >= 1, typically 2, 3 or 4) by summing groups of adjacent pixels.
    Along each axis, trailing pixels which do not fill a complete group (i.e. when the dimension is not a
    multiple of the rebin factor) are cropped before rebinning.

    Args:
        a: the array to rebin
        rebin_f: the rebin factor - pixels will be summed by groups of rebin_f x rebin_f (x rebin_f ...). This can
                 also be a tuple/list of rebin values along each axis, e.g. rebin_f=(4,1,2) for a 3D array.
                 Python and numpy integer scalars are both accepted.
        normalize: if True, the average value of the array will be kept, by dividing the result by the number of
                   pixels in each group (the product of the rebin values along all axes).
                   By default no normalization is applied to preserve statistical properties.

    Returns:
        the rebinned array, with shape a.shape[i] // rebin_f[i] along each axis i.

    Raises:
        ValueError: if a is 0-dimensional, if the number of rebin values does not match a.ndim,
                    or if a rebin value is smaller than 1.
    """
    ndim = a.ndim
    if ndim < 1:
        raise ValueError("pynx.utils.array.rebin() requires an array with at least one dimension")
    # Accept numpy integer scalars (e.g. computed from a shape) as well as python ints
    if isinstance(rebin_f, (int, np.integer)):
        rebin_f = [rebin_f] * ndim
    elif len(rebin_f) != ndim:
        # Explicit raise rather than assert, so the check survives `python -O`
        raise ValueError("Rebin: number of dimensions does not agree with number of rebin values:" + str(rebin_f))
    if any(f < 1 for f in rebin_f):
        raise ValueError("Rebin: rebin values must be >= 1:" + str(rebin_f))

    # Crop each axis to a multiple of its rebin factor
    a = a[tuple(slice(0, n - n % f) for n, f in zip(a.shape, rebin_f))]

    # Split each axis of length n into (n // f, f), then sum over the f sub-axes (odd positions 1, 3, 5, ...)
    sh = []
    for n, f in zip(a.shape, rebin_f):
        sh += [n // f, f]
    b = a.reshape(sh).sum(axis=tuple(range(1, 2 * ndim, 2)))

    if normalize:
        npix = 1
        for f in rebin_f:
            npix *= f
        return b / npix
    return b


def center_array_2d(a, other_arrays=None, threshold=0.2, roi=None, iz=None):
    """
    Center an array in 2D so that its absolute value barycenter is in the middle.
    If the array is 3D, it is summed along the first axis to determine the barycenter, and all frames along the first
    axis are shifted.
    The array is 'rolled' so that values shifted from the right appear on the left, etc...
    Shifts are integer - no interpolation is done.
    If the barycenter is undefined (e.g. the array or the selected ROI only contains zeros after thresholding),
    no shift is applied.

    Args:
        a: the array to be shifted, can be a floating point or complex 2D or 3D array.
        other_arrays: can be another array or a list of arrays to be shifted by the same amount as a
        threshold: only the pixels above the maximum amplitude * threshold will be used for the barycenter.
                   If None, all pixels are used.
        roi: tuple of (x0, x1, y0, y1) corners coordinate of ROI to calculate center of mass. Note that the
             threshold is computed from the maximum of the whole 2D image, not of the ROI.
        iz: if a.ndim==3, the centering will be done based on the center of mass of the absolute value summed over all
            2D stacks. If iz is given, the center of mass will be calculated just on that stack

    Returns:
        the shifted array if other_arrays is None, otherwise a tuple (shifted_a, shifted_other), where shifted_other
        is a list of shifted arrays if other_arrays is a list, or a single shifted array otherwise.
    """
    # 2D image (float32 absolute value) used to compute the barycenter
    if a.ndim == 3 and iz is None:
        tmp = abs(a).astype(np.float32).sum(axis=0)
    elif a.ndim == 3:
        tmp = abs(a[iz]).astype(np.float32)
    else:
        tmp = abs(a).astype(np.float32)

    if threshold is not None:
        # Zero out pixels below the threshold so they do not bias the barycenter
        tmp *= tmp > (tmp.max() * threshold)

    ny, nx = tmp.shape
    if roi is None:
        x_min, y_min = 0, 0
        region = tmp
    else:
        x_min, x_max, y_min, y_max = roi
        region = tmp[y_min:y_max, x_min:x_max]

    # A region with zero total weight yields a 0/0 (NaN) barycenter: silence the warning, handled below
    with np.errstate(divide='ignore', invalid='ignore'):
        yc, xc = center_of_mass(region)
    yc += y_min
    xc += x_min

    if np.isfinite(yc) and np.isfinite(xc):
        dx, dy = int(round(nx // 2 - xc)), int(round(ny // 2 - yc))
    else:
        # Undefined barycenter: leave the arrays unshifted instead of failing on int(nan)
        dx, dy = 0, 0

    def shift(arr):
        # Nested single-axis rolls: multi-axis np.roll is only supported in numpy version >= 1.12 (2017)
        return np.roll(np.roll(arr, dy, axis=-2), dx, axis=-1)

    a1 = shift(a)
    if other_arrays is None:
        return a1
    if isinstance(other_arrays, list):
        return a1, [shift(b) for b in other_arrays]
    return a1, shift(other_arrays)
