# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

from __future__ import division

import numpy as np
from numpy import pi
from scipy import misc
import sys


def get_img(index=0):
    """
    Returns a 2D test image (numpy array) from SciPy's sample images.

    Args:
        index:
            0 -> grayscale raccoon face, i.e. face(gray=True), cropped to 512x512
            any other value -> the 512x512 ascent() image

    Returns:
        the selected image, flipped vertically (np.flipud).

    Notes:
        The sample-image loaders were deprecated in scipy.misc (SciPy 1.10) and later
        removed from it in favour of scipy.datasets, which downloads the data on first
        use and needs the optional 'pooch' package. scipy.misc is used when it still
        provides the loader, otherwise scipy.datasets.
    """
    name = 'face' if index == 0 else 'ascent'
    loader = getattr(misc, name, None)
    if loader is None:
        # Recent SciPy: the images only live in scipy.datasets
        from scipy import datasets
        loader = getattr(datasets, name)
    if index == 0:
        return np.flipud(loader(gray=True)[100:612, 280:792])
    return np.flipud(loader())


def siemens_star(dsize=512, nb_rays=36, r_max=None, nb_rings=8, cheese_holes_nb=0, cheese_hole_max_radius=5):
    """
    Calculate a binary Siemens star.

    Args:
        dsize: size in pixels of the (square) 2D array holding the star
        nb_rays: number of radial branches for the star. Must be > 0
        r_max: maximum radius for the star in pixels. If None, dsize // 2 is used.
            If 0, the star is not limited radially and the rings are spread up to
            the corners of the array.
        nb_rings: number of rings (the rays will have some holes between successive rings).
            If 0 or None, no rings are cut.
        cheese_holes_nb: number of random elliptical "cheese" holes punched over the entire
            area, giving a more varied frequency content. Hole centres and radii are drawn
            from numpy's global random generator (call np.random.seed for reproducibility).
        cheese_hole_max_radius: maximum axial radius for the holes (with random radius along x and y)

    Returns:
        a float32 array of shape (dsize, dsize) with the Siemens star (values 0 or 1).
    """
    if r_max is None:
        r_max = dsize // 2
    # Centred pixel coordinates. -dsize // 2 == floor(-dsize / 2), so for odd dsize the
    # range is asymmetric; the array index of coordinate c is c + i0 (i0 = ceil(dsize / 2)).
    coords = np.arange(-dsize // 2, dsize // 2, dtype=np.float32)
    i0 = -(-dsize // 2)
    x, y = np.meshgrid(coords, coords)

    a = np.arctan2(y, x)
    r = np.sqrt(x ** 2 + y ** 2)
    am = 2 * pi / nb_rays
    # Keep the first half of each angular sector -> nb_rays opaque rays
    d = (a % am) < (am / 2)
    if r_max:
        d &= r < r_max
    if nb_rings:
        if r_max:
            rm = r_max / nb_rings
        else:
            # No radial limit: spread the rings up to the array corners. Using r_max here
            # would give a zero ring period and an all-zero star.
            rm = dsize * np.sqrt(2) / 2 / nb_rings
        # Keep the first 90% of each ring period, cut the rest
        d &= (r % rm) < (rm * 0.9)
    if cheese_holes_nb > 0:
        cx = np.random.randint(int(x.min()), int(x.max()), cheese_holes_nb)
        cy = np.random.randint(int(y.min()), int(y.max()), cheese_holes_nb)
        rx = np.random.uniform(1, cheese_hole_max_radius, cheese_holes_nb)
        ry = np.random.uniform(1, cheese_hole_max_radius, cheese_holes_nb)
        for xc, yc, rxi, ryi in zip(cx, cy, rx, ry):
            dn = int(np.ceil(max(rxi, ryi)))
            # Bounding box of the ellipse in array indices. The start is clipped at 0: a
            # negative start index would wrap around and yield an empty slice, silently
            # dropping the hole. The (exclusive) end is +1 so the box includes the ellipse edge.
            x0, x1 = max(int(xc) + i0 - dn, 0), int(xc) + i0 + dn + 1
            y0, y1 = max(int(yc) + i0 - dn, 0), int(yc) + i0 + dn + 1
            xs, ys = x[y0:y1, x0:x1], y[y0:y1, x0:x1]
            # Zero every pixel inside (or on) the ellipse
            d[y0:y1, x0:x1] &= (((xs - xc) / rxi) ** 2 + ((ys - yc) / ryi) ** 2) > 1
    return d.astype(np.float32)


def fibonacci_urchin(dsize=256, nb_rays=100, r_max=64, nb_rings=8):
    """
    Calculate a binary urchin (in 3D): nb_rays conical rays whose directions are spread
    over the unit sphere with a golden-angle (Fibonacci) spiral, optionally limited
    radially and cut into rings.

    Args:
        dsize: size in pixels of each side of the cubic 3D array holding the urchin
        nb_rays: number of radial branches. Must be > 0
        r_max: maximum radius in pixels. If None or 0, the rays are not limited radially
        nb_rings: number of rings (the rays will have some holes between successive rings).
            If None or 0, no rings are cut.

    Returns:
        a float32 array of shape (dsize, dsize, dsize), indexed [z, y, x]. Each voxel holds
        the number of rays whose cone contains it (the cones are narrow compared to the
        ray spacing, so presumably 0 or 1 in practice - not enforced).
    """
    tmp = np.arange(-dsize // 2, dsize // 2, dtype=np.float32)
    z, y, x = np.meshgrid(tmp, tmp, tmp, indexing='ij')
    # Small offset avoids a division by zero at the centre voxel
    r = np.sqrt(x ** 2 + y ** 2 + z ** 2) + 1e-6

    # Generate points on a sphere of radius=1: evenly spaced in z, golden-angle steps in phi
    i = np.arange(nb_rays)
    z1 = i * 2. / nb_rays - 1 + 1. / nb_rays
    r1 = np.sqrt(1 - z1 ** 2)
    phi1 = i * np.pi * (3 - np.sqrt(5))
    x1 = np.cos(phi1) * r1
    y1 = np.sin(phi1) * r1

    # Approximate distance between points on a sphere with unit radius (approximation assuming an hexagonal packing)
    d = np.sqrt(8. / 3. / np.sqrt(3) * 4 * np.pi / nb_rays)
    # Approximate average angular distance between points
    da = d / 2
    # A voxel belongs to a ray when the cosine of the angle between its direction and the
    # ray direction exceeds this threshold (cone half-angle da / 2)
    cos_lim = np.cos(da / 2)

    rho = np.zeros_like(x)

    sys.stdout.write("Simulating 3d binary urchin (this WILL take a while)...")
    sys.stdout.flush()
    for k, (xi, yi, zi) in enumerate(zip(x1, y1, z1)):
        # This could go MUCH faster on a GPU
        # Progress: number of rays still to be computed (counts down from nb_rays)
        sys.stdout.write('%d ' % (nb_rays - k))
        sys.stdout.flush()
        rho += ((x * xi + y * yi + z * zi) / r) > cos_lim
    print("\n")

    if r_max:
        rho *= r < r_max

    if nb_rings:
        if r_max:
            rm = r_max / nb_rings
        else:
            # No radial limit: spread the rings up to the cube corners
            rm = dsize * np.sqrt(3) / 2 / nb_rings
        # Keep the first 80% of each ring period, cut the rest
        rho *= (r % rm) < (rm * 0.8)

    return rho
