# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2019-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
#         Julio Cesar da Silva (mailto: jdasilva@esrf.fr) (nyquist and ring_thickness code)


import numpy as np


def ring_thickness(shape):
    """
    Compute, for each pixel (2D) or voxel (3D) of an array of the given shape,
    the index of the ring or shell it belongs to.

    Coordinates are centred on the array. Along each axis they are rescaled so
    that the half-length of that axis maps onto the half-length of the longest
    axis, floor(max(shape) / 2). The ring index is the Euclidean norm of the
    rescaled coordinates, rounded to the nearest integer.

    :param shape: the array shape, a sequence of 2 or 3 integers
    :return: a float array of the given shape holding the integer-valued ring index
    :raises SystemExit: if the number of dimensions is not 2 or 3 (exception type
        kept for backward compatibility)
    """
    # Validate before indexing into shape, so that e.g. a 1D shape triggers the
    # intended error instead of an IndexError.
    if len(shape) not in (2, 3):
        print('Number of dimensions is different from 2 or 3.Exiting...')
        raise SystemExit('Number of dimensions is different from 2 or 3.Exiting...')

    half_max = np.floor(np.max(shape) / 2.0)

    def _scaled_axis(size):
        # Centred coordinates of one axis, rescaled to the longest axis.
        coords = np.arange(-np.fix(size / 2.0), np.ceil(size / 2.0))
        half = np.floor(size / 2.0)
        if half == 0:
            # Axis of length 0 or 1: the only possible coordinate is 0.
            # Dividing by half would give 0/0 = NaN for the whole array.
            return np.zeros_like(coords)
        return coords * half_max / half

    # Sparse 'ij' grids broadcast to the full shape when summed, which avoids
    # allocating one full-size coordinate array per axis.
    grids = np.meshgrid(*[_scaled_axis(size) for size in shape], indexing='ij', sparse=True)
    sumsquares = sum(grid ** 2 for grid in grids)
    return np.round(np.sqrt(sumsquares))


def nyquist(shape):
    """
    Return the frequency axis up to the Nyquist frequency for an array shape.

    The Nyquist frequency is taken as floor(max(shape) / 2), i.e. half the
    length of the longest axis.

    :param shape: the array shape (sequence of integers)
    :return: a tuple (freq, fnyquist) where freq holds the values
        0, 1, ..., fnyquist as floats
    """
    fnyquist = np.floor(np.max(shape) / 2.0)
    return np.arange(fnyquist + 1), fnyquist


def prtf(icalc, iobs, mask=None, ring_thick=5):
    """
    Compute the phase retrieval transfer function, given calculated and observed intensity.

    For each integer frequency up to the Nyquist frequency, the calculated and
    observed amplitudes (square roots of the intensities) are summed over the
    unmasked pixels of the corresponding ring (2D) or shell (3D), and the PRTF
    is the ratio of the two sums.

    :param icalc: the calculated intensity array (origin at array center), either 2D or 3D
    :param iobs: the observed intensity (origin at array center); its absolute
        value is used
    :param mask: the mask, non-zero values indicating bad pixels in the observed
        intensity array. If None, no pixel is masked.
    :param ring_thick: the thickness of each shell or ring, in pixel units. Each
        frequency sums the pixels whose ring index lies within ring_thick / 2
        of it; with 0, only the pixels exactly at that ring index are used.
    :return: a tuple with the (frequency, prtf) arrays. prtf is a masked array,
        masked at frequencies where the summed observed amplitude is zero. This
        includes frequencies with no unmasked pixel at exactly that ring index.
    """
    calc = np.sqrt(icalc)
    obs = np.sqrt(np.abs(iobs))
    index = ring_thickness(iobs.shape)
    freq, _ = nyquist(iobs.shape)
    if mask is None:
        mask = np.zeros(iobs.shape, dtype=np.int8)

    # Loop invariants, computed once.
    valid = mask == 0
    r2 = ring_thick / 2

    calc_f, obs_f = [], []
    for ii in freq:
        on_ring = np.logical_and(index == ii, valid)
        if not on_ring.any():
            # All values are masked (central stop ?)
            calc_f.append(0)
            obs_f.append(0)
            continue
        if ring_thick == 0:
            selection = on_ring
        else:
            # Build the thick-shell selection once and use it for both arrays.
            selection = (index >= (ii - r2)) & (index <= (ii + r2)) & valid
        calc_f.append(calc[selection].sum())
        obs_f.append(obs[selection].sum())

    calc_f = np.array(calc_f)
    obs_f = np.array(obs_f)
    # Adding (obs_f == 0) turns zero denominators into 1, avoiding a division
    # by zero; those entries are masked in the result.
    pr = np.ma.masked_array(calc_f / (obs_f + (obs_f == 0)), mask=(obs_f == 0))
    return freq, pr
