# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2019-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
#         Julio Cesar da Silva (mailto: jdasilva@esrf.fr) (nyquist and ring_thickness code)


import numpy as np


def ring_thickness(shape):
    """
    Compute the shell (ring) index of every pixel/voxel of an array.

    Each axis gets centred integer coordinates, rescaled so that its half-width
    maps onto the half-width of the largest dimension. The shell index is the
    distance to the array centre in these rescaled coordinates, rounded to the
    nearest integer.

    :param shape: the shape of the 2D or 3D array (sequence of 2 or 3 integers)
    :return: a float array with the given shape, holding the shell index of each
        pixel/voxel
    :raises SystemExit: if the number of dimensions is neither 2 nor 3
    """
    ndim = len(shape)
    # Validate before touching individual dimensions, so that a 1-D shape is
    # reported with the intended message instead of failing with an IndexError.
    if ndim not in (2, 3):
        print('Number of dimensions is different from 2 or 3.Exiting...')
        raise SystemExit('Number of dimensions is different from 2 or 3.Exiting...')

    half_max = np.floor(np.max(shape) / 2.0)
    sumsquares = np.zeros(tuple(shape))
    for axis, n in enumerate(shape):
        half = np.floor(n / 2.0)
        # Centred coordinates: n values, with 0 at index n // 2
        coords = np.arange(-np.fix(n / 2.0), np.ceil(n / 2.0))
        if half > 0:
            coords = coords * half_max / half
        else:
            # Axis of size 1: its only coordinate is the centre (avoids 0/0 -> NaN)
            coords = np.zeros_like(coords)
        # Reshape so the coordinates vary along `axis` only, then broadcast-add
        bshape = [1] * ndim
        bshape[axis] = -1
        sumsquares += coords.reshape(bshape) ** 2
    return np.round(np.sqrt(sumsquares))


def nyquist(shape):
    """
    Return the shell frequencies and the Nyquist frequency for an array shape.

    The Nyquist frequency is taken as half the largest dimension (rounded down),
    in pixel units. The frequencies are 0, 1, ..., Nyquist (inclusive).

    :param shape: the shape of the array
    :return: a tuple (freq, fnyquist) with the array of frequencies and the
        Nyquist frequency
    """
    largest_dim = np.max(shape)
    half_width = np.floor(largest_dim / 2.0)
    # +1 so that the Nyquist frequency itself is the last entry
    shell_frequencies = np.arange(0, half_width + 1)
    return shell_frequencies, half_width


def prtf(icalc, iobs, mask=None, ring_thick=5, shell_averaging_method='before'):
    """
    Compute the phase retrieval transfer function, given calculated and observed intensity. Note that this
    function assumes that calc and obs are orthonormal arrays.

    :param icalc: the calculated intensity array (origin at array center), either 2D or 3D
    :param iobs: the observed intensity (origin at array center)
    :param mask: the mask for bad pixels in the observed intensity array. Only pixels where mask == 0
                 are used. If None, all pixels are used.
    :param ring_thick: the thickness of each shell or ring, in pixel units. If 0, only the pixels whose
                       shell index is exactly the current frequency are used.
    :param shell_averaging_method: by default ('before') the amplitudes are averaged over the shell before
                                   the PRTF=<calc>/<obs> is computed. If 'after' is given, then the ratio of calc
                                   and observed amplitudes will first be calculated (excluding zero-valued observed
                                   pixels), and then averaged to compute the PRTF=<calc/obs>.
    :return: a tuple with the (frequency, frequency_nyquist, prtf) arrays. The PRTF is set to 0 for shells
             without any unmasked pixel at the exact shell index, and for shells where the observed
             amplitude is zero everywhere (no ratio can be computed).
    """

    # TODO assumes uniform grid i.e pixel same size in all dimensions and no curvature
    #  -  need reciprocal space coords / an ortho-normalisation matrix
    calc = np.sqrt(icalc)
    obs = np.sqrt(abs(iobs))
    index = ring_thickness(iobs.shape)
    freq, fnyquist = nyquist(iobs.shape)
    if mask is None:
        mask = np.zeros(iobs.shape, dtype=np.int8)

    # Loop invariants
    valid = mask == 0
    average_after = 'after' in shell_averaging_method.lower()
    r2 = ring_thick / 2

    values = []
    for ii in freq:
        on_shell = np.logical_and(index == ii, valid)
        if not on_shell.any():
            # All values are masked (central stop ?)
            values.append(0)
            continue

        if ring_thick == 0:
            selection = on_shell
        else:
            # Ring of total thickness ring_thick centred on the current shell;
            # computed once and shared by the calculated and observed amplitudes
            selection = (index >= (ii - r2)) & (index <= (ii + r2)) & valid
        tmpcalc = calc[selection]
        tmpobs = obs[selection]

        # TODO: check how PRTF should be calculated, after or before summing intensities in ring thickness
        if average_after:
            # Average of calc/obs over the pixels with a non-zero observed amplitude
            nbvox = (tmpobs > 0).sum()
            if nbvox == 0:
                # No observed signal in this ring: avoid 0/0
                values.append(0)
            else:
                ratio = np.divide(tmpcalc, tmpobs, out=np.zeros_like(tmpcalc), where=tmpobs != 0)
                values.append(ratio.sum() / nbvox)
        else:
            # calc.sum(ring) / obs.sum(ring) - more optimistic
            obs_sum = tmpobs.sum()
            # No observed signal in this ring: avoid a division by zero
            values.append(tmpcalc.sum() / obs_sum if obs_sum != 0 else 0)
    return freq, fnyquist, np.array(values)


def plot_prtf(freq, fnyquist, prtf, pixel_size=None):
    """
    Plot the phase retrieval transfer function in figure 101, as a function of
    freq / fnyquist, along with the 1/e threshold line. If pixel_size is given,
    a secondary X-axis is added with the corresponding resolution.

    :param freq: the frequencies for which the phase retrieval transfer function was calculated
    :param fnyquist: the nyquist frequency
    :param prtf: the phase retrieval transfer function
    :param pixel_size: the pixel size in metres, for resolution axis
    :return: nothing. The figure is displayed with plt.show().
    """
    from pynx.utils.matplotlib import pyplot as plt
    plt.figure(101, figsize=(8, 4))
    ax1 = plt.gca()
    # NOTE(review): the abscissa is freq / fnyquist (0 to 1), while the label
    # below states nm^-1 - confirm which one is intended.
    ax1.plot(freq / fnyquist, prtf)
    ax1.grid()
    ax1.set_xlabel(r"spatial frequency (nm$^{-1}$)")
    ax1.set_ylabel("PRTF")
    ax1.set_xlim(0, 1)
    ax1.set_ylim(0, 1)
    ax1.hlines(1 / np.exp(1), 0, 1, 'r', '--', label="Threshold (1/e)")
    ax1.legend()

    if pixel_size is not None:
        # Choose the display unit from the order of magnitude of the pixel size
        magnitude = np.log10(pixel_size)
        if magnitude < -6:
            unit_name, scale = "nm", 1e9
        elif magnitude < -3:
            # Raw string: "\m" is not a valid escape sequence
            unit_name, scale = r"$\mu m$", 1e6
        elif magnitude < 0:
            unit_name, scale = "mm", 1e3
        else:
            unit_name, scale = "m", 1
        # Add secondary X-axis with resolution in metric units
        ax2 = ax1.twiny()
        # Skip the first tick to avoid dividing by a zero frequency
        x = plt.xticks()[0][1:]
        x2 = pixel_size * scale / x
        ax2.set_xticks(x)
        ax2.set_xticklabels(["%.1f" % xx for xx in x2])
        ax2.set_xlabel(r"Resolution in %s" % (unit_name))

    plt.tight_layout()
    plt.show()
