# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2008-2015 : Univ. Joseph Fourier (Grenoble 1), CEA/INAC/SP2M
#   (c) 2013-2014 : Fondation Nanosciences, Grenoble
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
#         Ondrej Mandula

import numpy as np
from scipy.linalg import polar


def primes(n):
    """ Returns the prime decomposition of n as a list.

    The list always starts with 1, followed by the prime factors of n in
    non-decreasing order, with multiplicity, e.g. primes(12) -> [1, 2, 2, 3].

    Args:
        n: the number to decompose; must be > 0 (checked with an assert).
    Returns:
        the list [1, p1, p2, ...] with p1*p2*... == n
    """
    assert (n > 0)
    factors = [1]
    divisor = 2
    # Trial division: only divisors up to sqrt(n) need to be tried, any remainder > 1 is itself prime
    while divisor * divisor <= n:
        quotient, remainder = divmod(n, divisor)
        if remainder == 0:
            factors.append(divisor)
            n = quotient
        else:
            divisor += 1
    if n > 1:
        factors.append(n)
    return factors


def test_smaller_primes(n, maxprime=13, required_dividers=(4,)):
    """
    Check whether n has no prime factor larger than maxprime and, optionally, is a multiple of
    every value in required_dividers.

    Args:
        n: the (positive) integer whose prime decomposition is tested
        maxprime: the largest acceptable prime factor. The default is the largest prime factor
            accepted by the clFFT library for OpenCL GPU FFT.
        required_dividers: sequence of integers which must all divide n. If None, this check is skipped.
    Returns:
        True if both conditions are met, False otherwise.
    """
    largest_factor = max(primes(n))
    if largest_factor > maxprime:
        return False
    if required_dividers is None:
        return True
    # all() stops at the first divider which does not divide n
    return all(n % divider == 0 for divider in required_dividers)


def smaller_primes(n, maxprime=13, required_dividers=(4,), decrease=True):
    """ Find the closest integer <= or >=n (or list/array of integers), for which the largest prime divider
    is <=maxprime, and has to include some dividers.
    The default value for maxprime is the largest prime factor accepted by the clFFT library for OpenCL GPU FFT.

    Note that values of n which are <= maxprime are returned unchanged, without checking required_dividers.

    Args:
        n: the integer number, or a list/tuple/numpy array of integers (processed element by element)
        maxprime: the largest prime factor acceptable
        required_dividers: a list of required dividers for the returned integer.
        decrease: if True (the default), the integer returned will be <=n, otherwise it will be >=n.
    Returns:
        the integer (or list/tuple/array of integers, matching the input container type) fulfilling
        the requirements. When decreasing, 0 is returned if no suitable integer > 0 exists.
    """
    container_type = type(n)
    if container_type in (list, tuple, np.ndarray):
        results = [smaller_primes(value, maxprime=maxprime, required_dividers=required_dividers,
                                  decrease=decrease) for value in n]
        if container_type is np.ndarray:
            return np.array(results)
        return container_type(tuple(results))

    # Small values are accepted as-is
    if not (maxprime < n):
        return n

    assert (n > 1)
    while test_smaller_primes(n, maxprime=maxprime, required_dividers=required_dividers) is False:
        if not decrease:
            n += 1
            continue
        n -= 1
        if n == 0:
            # TODO: should raise an exception
            return 0
    return n


def ortho_modes(m):
    """
    Orthogonalize modes from an array of shape (nbmode, ...), using polar decomposition.
    Each mode may have any number of dimensions (e.g. (nbmode, ny, nx) for 2D modes, or
    (nbmode, nz, ny, nx) for 3D modes).
    The decomposition is such that the total intensity (i.e. (abs(m)**2).sum()) is conserved.

    Args:
        m: the stack of modes to orthogonalize, with shape (nbmode, ...). Presumably the number of
           modes is not larger than the number of points in each mode - TODO confirm for callers.
    Returns:
        an array (mo) with the same shape and dtype as given in input, but with orthogonal modes, i.e.
        (mo[i]*mo[j].conj()).sum()=0 for i!=j
        The modes are sorted by decreasing norm.
    """
    nbmode = m.shape[0]
    # Flatten each mode so that the stack becomes a (nbmode, npoints) matrix
    mm = np.reshape(m, (nbmode, -1))
    # Left polar decomposition: mm = dot(p, u), with p a (nbmode, nbmode) Hermitian matrix and u with
    # orthonormal rows (the orthogonal modes)
    u, p = polar(mm, side='left')

    # Weight each orthonormal mode by the norm of the matching column of p. Since the rows of u are
    # orthonormal, the sum of these squared weights is ||p||_F^2 = ||mm||_F^2: total intensity is conserved.
    weights = np.sqrt((abs(p) ** 2).sum(axis=0))
    mo = np.empty_like(m)
    mo[...] = np.reshape(u * weights[:, np.newaxis], m.shape)

    # Sort by decreasing intensity, summing over all the axes of each mode
    intensity = (abs(mo) ** 2).sum(axis=tuple(range(1, mo.ndim)))
    return mo[np.argsort(-intensity)]


def full_width(x, y, ratio=0.5, outer=False):
    """
    Determine the full-width from an XY dataset.

    The width is measured at the level ratio*max(y), using a linear interpolation between the two
    points bracketing that level on each side of the peak. If the profile never crosses that level on
    one side, the corresponding end of the x range is used for that side.

    Args:
        x: the abscissa (1d array) of the data. Assumed to be in monotonic ascending order
        y: the y-data as a 1d array (or sequence). The absolute value of the array will be taken if it is complex.
        ratio: the fraction (default=0.5 for FWHM) of the maximum at which the width should be measured
        outer: if True, the width will be measured by taking the outermost points which fall below the ratio to
               the maximum. This requires max(y) > ratio*max(y), i.e. a positive maximum and ratio < 1.
               Otherwise, the width will be taken as the width around the maximum (regardless of secondary
               peaks which may be above max*ratio)

    Returns:
        the full width
    """
    y = np.asarray(y)
    if np.iscomplexobj(y):
        y = abs(y)
    n = len(y)
    imax = y.argmax()
    ay2 = y[imax] * ratio
    if outer:
        # Find the first and last values above max*ratio in the entire array, and step one point outside.
        # An index of -1 (resp. n) means the level is not crossed within the data on that side.
        ix1 = np.where(y[:imax + 1] > ay2)[0][0] - 1
        ix2 = np.where(y[imax:] > ay2)[0][-1] + imax + 1
    else:
        # Find the first values below max*ratio, left and right of the peak. This allows the width to be
        # found even if secondary peaks are > max*ratio
        below_left = np.where(y[:imax] < ay2)[0]
        below_right = np.where(y[imax:] < ay2)[0]
        # If the profile does not drop below the level on one side (e.g. peak at the edge of the data),
        # use the end of the range instead of indexing an empty result (which raises IndexError)
        ix1 = below_left[-1] if below_left.size else -1
        ix2 = below_right[0] + imax if below_right.size else n

    if ix1 >= 0:
        # Linear interpolation of the crossing between (x[ix1], y[ix1]) and (x[ix1+1], y[ix1+1])
        v0, v1 = y[ix1], y[ix1 + 1]
        xleft = (x[ix1 + 1] * (ay2 - v0) + x[ix1] * (v1 - ay2)) / (v1 - v0)
    else:
        xleft = x[0]

    if ix2 <= (n - 1):
        # Linear interpolation of the crossing between (x[ix2-1], y[ix2-1]) and (x[ix2], y[ix2])
        v2, v3 = y[ix2 - 1], y[ix2]
        xright = (x[ix2 - 1] * (ay2 - v3) + x[ix2] * (v2 - ay2)) / (v2 - v3)
    else:
        xright = x[n - 1]
    return xright - xleft


def llk_poisson(iobs, imodel):
    """
    Compute the Poisson log-likelihood for a calculated intensity, given observed values.
    The value computed is normalized so that its asymptotic value (for a large number of points) is equal
    to the number of observed points.

    For each point this is 2*(imodel - iobs + iobs*log(iobs/imodel)) where iobs > 0, and 2*imodel elsewhere.

    Args:
        iobs: the observed intensities (scalar, array or sequence)
        imodel: the calculated/model intensity (scalar, array or sequence, broadcastable against iobs)

    Returns:
        The negative log-likelihood, per point. For a scalar iobs the value is returned as computed;
        otherwise a float32 array with the broadcast shape of iobs and imodel is returned.
    """
    if np.isscalar(iobs):
        if iobs > 0:
            llk = imodel - iobs + iobs * np.log(iobs / imodel)
        else:
            llk = imodel
        return 2 * llk

    iobs, imodel = np.broadcast_arrays(np.asarray(iobs), np.asarray(imodel))
    llk = np.empty(iobs.shape, dtype=np.float32)
    positive = iobs > 0
    # Only evaluate the logarithm where iobs > 0: evaluating it everywhere would compute 0*log(0) for
    # empty pixels, emitting divide-by-zero/invalid RuntimeWarnings for values which are discarded anyway.
    io, im = iobs[positive], imodel[positive]
    llk[positive] = im - io + io * np.log(io / im)
    llk[~positive] = imodel[~positive]
    return 2 * llk


def llk_gaussian(iobs, imodel):
    """
    Gaussian negative log-likelihood of a calculated intensity, given observed values.
    The value computed is normalized so that its asymptotic value (for a large number of points) is equal
    to the number of observed points.

    Args:
        iobs: the observed intensities
        imodel: the calculated/model intensity

    Returns:
        The negative log-likelihood, computed point by point as (imodel - iobs)**2 / (iobs + 1).
    """
    residual = imodel - iobs
    # iobs + 1 presumably stands in for the (Poisson) variance; the +1 keeps the denominator
    # non-zero for empty pixels (assuming iobs >= 0)
    variance = iobs + 1
    return residual ** 2 / variance


def llk_euclidian(iobs, imodel):
    """
    Euclidian negative log-likelihood of a calculated intensity, given observed values.
    The value computed is normalized so that its asymptotic value (for a large number of points) is equal
    to the number of observed points. This model is valid if obs and calc are reasonably close.

    Args:
        iobs: the observed intensities
        imodel: the calculated/model intensity

    Returns:
        The negative log-likelihood, computed point by point as 4 * (sqrt(imodel) - sqrt(iobs))**2.
    """
    # The comparison is made on amplitudes (square roots of the intensities)
    amplitude_difference = np.sqrt(imodel) - np.sqrt(iobs)
    return 4 * amplitude_difference ** 2


if __name__ == '__main__':
    # Local import: matplotlib is only needed for this self-test script, not by the library functions
    import matplotlib.pyplot as plt

    def _plot_llk_asymptotics(llk_func, title_tex, nb=2 ** 20, nexp=16):
        """
        Plot the average log-likelihood per point versus the average observed intensity, for observed
        intensities drawn from a Poisson distribution around model intensities, which are themselves
        drawn from a uniform ('Uniform', black) or an exponential ('Exponential', red) distribution.

        Args:
            llk_func: the log-likelihood function, called as llk_func(iobs, imodel)
            title_tex: the (LaTeX) title of the figure
            nb: the number of simulated points for each intensity scale
            nexp: the intensity scale m takes the values 2**0, 2**1, ... 2**(nexp-1)
        """
        plt.figure()
        vim = np.arange(0, nexp, dtype=np.float32)
        model_generators = (('Uniform', 'k.', lambda m: np.random.uniform(0, m, nb)),
                            ('Exponential', 'r.', lambda m: np.random.exponential(m, nb)))
        for label, fmt, draw_model in model_generators:
            vsum = np.zeros_like(vim)
            vsumllk = np.zeros_like(vim)
            for i, e in enumerate(vim):
                m = 2 ** e
                print(m)
                imodel = draw_model(m)
                iobs = np.random.poisson(imodel)
                vsum[i] = iobs.sum()
                vsumllk[i] = llk_func(iobs, imodel).sum()
            plt.semilogx(vsum / nb, vsumllk / nb, fmt, label=label)
        plt.legend()
        plt.xlabel("$<I_{obs}>$")
        plt.ylabel("$<LLK>$")
        plt.title(title_tex)

    plt.rc('text', usetex=True)

    # Testing asymptotic values for the Poisson log-likelihood, for observed data following Poisson statistics.
    _plot_llk_asymptotics(llk_poisson,
                          r"$<LLK_{Poisson}>=\frac{2}{N_{obs}}\left\{\displaystyle\sum_{I_{obs}>0}\left[I_{obs}"
                          r"*ln(\frac{I_{obs}}{I_{calc}}) + I_{calc} -I_{obs}\right]+ \sum_{I_{obs}=0}I_{calc}\right\}$")

    # Testing asymptotic values for the Gaussian log-likelihood, for observed data following Poisson statistics.
    _plot_llk_asymptotics(llk_gaussian,
                          r"$<LLK_{Gaussian}>=\frac{1}{N_{obs}}\displaystyle\sum\frac{(I_{obs}-I_{calc})^2}{I_{obs}+1}$")

    # Testing asymptotic values for the Euclidian log-likelihood, for observed data following Poisson statistics.
    _plot_llk_asymptotics(llk_euclidian,
                          r"$<LLK_{Euclidian}>=\frac{4}{N_{obs}}\displaystyle\sum(\sqrt{I_{obs}}-\sqrt{I_{calc}})^2$")

    # When run as a script (non-interactive mode), figures are only displayed by an explicit show()
    plt.show()
