# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import numpy as np
from scipy.fftpack import ifftshift, fftshift, fft2

# This can be used to select the GPU and/or the language, and must be called before other pynx imports
# Otherwise a default GPU will be used according to its speed, which is usually sufficient
#
# Alternatively, it is possible to use the PYNX_PU environment variable to choose the GPU or language,
# e.g. using PYNX_PU=opencl or PYNX_PU=cuda or PYNX_PU=cpu or PYNX_PU=Titan, etc.
# Manual processing-unit override: change False to True to use the language/GPU given below
# instead of the automatic selection (or set the PYNX_PU environment variable instead).
if False:
    from pynx.processing_unit import default_processing_unit
    default_processing_unit.select_gpu(language='OpenCL', gpu_name='R9')

from pynx.utils.pattern import siemens_star
from pynx.cdi import *


# Test on a simulated pattern (2D)
n = 512

# Siemens-Star object
obj0 = siemens_star(n, nb_rays=18, r_max=60, nb_rings=3)

# Start from a slightly loose disc support (radius 65 pixels, vs r_max=60 for the star)
x, y = np.meshgrid(np.arange(-n // 2, n // 2, dtype=np.float32), np.arange(-n // 2, n // 2, dtype=np.float32))
r = np.sqrt(x ** 2 + y ** 2)
support = r < 65

# Simulated intensity |FT(obj)|^2. fft2 puts the zero frequency at the array corner, so ifftshift moves
# it back to the centre of the array. Poisson noise is then added for an expected total of 1e10 counts.
iobs = abs(ifftshift(fft2(fftshift(obj0.astype(np.complex64))))) ** 2
iobs = np.random.poisson(iobs * 1e10 / iobs.sum())
mask = np.zeros_like(iobs, dtype=np.int16)
if True:
    # Mask some values in the central beam: the 2x2 pixels around the centre of the pattern.
    # The slice is derived from n so it stays on the central beam if the array size is changed
    # (for n=512 this is 255:257).
    center_box = slice(n // 2 - 1, n // 2 + 1)
    # Report the removed fraction *before* zeroing the pixels
    print("Removing %6.3f%% intensity" % (iobs[center_box, center_box].sum() / iobs.sum() * 100))
    iobs[center_box, center_box] = 0
    mask[center_box, center_box] = 1

# iobs, support and mask are all fftshift-ed before being handed to CDI
# (NOTE(review): presumably CDI expects the origin at the array corner - confirm against the CDI docs)
cdi = CDI(fftshift(iobs), obj=None, support=fftshift(support), mask=fftshift(mask), wavelength=1e-10,
          pixel_size_detector=55e-10)

# Toggle for the intermediate ShowCDI displays
show = True

# Initial scaling of the object [only useful if there are masked pixels!]
cdi = ScaleObj(method='F') * cdi

# First 100 cycles of RAAR, starting from the scaled object
raar_cycles = RAAR() ** 100
cdi = raar_cycles * cdi

# Log-likelihood: LLK is evaluated between a forward and an inverse Fourier transform.
# The chain's return value is not kept; the value is read back with cdi.get_llk()
# (NOTE(review): relies on the operators updating cdi in place - confirm).
llk_chain = IFT() * LLK() * FT()
llk_chain * cdi
print("LLK_n = %8.3f" % (cdi.get_llk(noise='poisson')))

if show:
    cdi = ShowCDI(fig_num=1) * cdi

# Main phasing loop: each step runs RAAR then ER cycles, followed by a support update
nb_raar_cycle, nb_er_cycle = 40, 5
for i in range(20):
    # Support update operator. The smoothing width decreases from 2.0 (i=0) towards 0.25
    # as the loop progresses.
    s = 0.25 + 1.75 * np.exp(-i / 4)
    sup = SupportUpdate(threshold_relative=0.3, smooth_width=s, force_shrink=False)

    if i == 2:
        # Applied once, at the third step (presumably to remove the twin image)
        cdi = DetwinRAAR(nb_cycle=10) * cdi

    # Operators apply right-to-left: 40 cycles of RAAR, then 5 of ER, then update support
    cdi = sup * ER() ** nb_er_cycle * RAAR() ** nb_raar_cycle * cdi

    if show:
        cdi = ShowCDI(fig_num=1) * cdi

    # Compute log-likelihood
    IFT() * LLK() * FT() * cdi

    # Cumulative RAAR+ER cycles done by this loop, including the current step
    # (the DetwinRAAR cycles and the initial RAAR cycles are not counted)
    print("RAAR+ER #%3d: LLK_n = %8.3f" % ((i + 1) * (nb_raar_cycle + nb_er_cycle),
                                            cdi.get_llk(noise='poisson')))

# Finish with ML (or could use ER)
nb_ml_cycle = 20
for i in range(5):
    # Do 20 cycles of ML (reg_fac: regularisation factor passed to the ML operator)
    cdi = ML(reg_fac=1e-2) ** nb_ml_cycle * cdi

    if show:
        cdi = ShowCDI(fig_num=1) * cdi

    IFT() * LLK() * FT() * cdi

    # Cumulative ML cycles done by this loop, including the current step; the label matches
    # the "LLK_n" used by the other progress prints
    print("ML #%3d: LLK_n = %8.3f" % ((i + 1) * nb_ml_cycle, cdi.get_llk(noise='poisson')))

# Final display of the result, shown regardless of the `show` flag
cdi = ShowCDI(fig_num=1) * cdi
