# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.fftpack import fftshift
import fabio
from matplotlib import pyplot as plt

from pynx.cdi import *

# CDI example using an experimental data set from ESRF beamline ID10 (courtesy of Yuriy Chushkin)

# Load the observed intensities and the known object support; both frames are
# flipped upside-down (row order reversed) with np.flipud.
iobs = np.flipud(fabio.open("data/logo5mu_20sec.edf").data)
support = np.flipud(fabio.open("data/mask_logo5mu_20sec.edf").data)

# Pixel coordinate grids: x varies along columns, y along rows. Both axes use
# the row count n, so the frame is assumed to be square (NOTE(review): confirm).
n = len(iobs)
pixel_idx = np.arange(0, n, dtype=np.float32)
x, y = np.meshgrid(pixel_idx, pixel_idx)

# Mask specific to this dataset (from the beamstop, after symmetrization of the
# observed data). A pixel is masked when its observed intensity is exactly zero
# AND it lies inside at least one of the rectangular zones below.
# Convention: False = OK, True = masked.
mask = (iobs == 0) & (
    ((abs(x - 256) <= 30) & (abs(y - 256) <= 30))
    | ((abs(x - 266) <= 30) & (abs(y) >= 495))
    | ((abs(x - 246) <= 30) & (abs(y) <= 20))
    | ((abs(x - 10) <= 30) & (abs(y - 240) <= 5))
    | ((abs(x - 498) <= 30) & (abs(y - 270) <= 5))
)

# Every ShowCDI operator below draws into figure number 1 (8 x 8 inches)
plt.figure(1, figsize=(8, 8))

# ======================== Try first from the known support (easy!) ========================
# All input arrays go through fftshift before the CDI object is built
# (presumably CDI expects the origin at pixel [0, 0] - TODO confirm in the PyNX docs).
shifted_iobs = fftshift(iobs)
shifted_support = fftshift(support)
shifted_mask = fftshift(mask)
cdi = CDI(shifted_iobs, obj=None, support=shifted_support, mask=shifted_mask,
          wavelength=1e-10, pixel_size_detector=55e-6)

# Initial scaling of the object, required when a mask is used
cdi = ScaleObj(method='F') * cdi

# One group = 50 HIO cycles, then 20 ER cycles, then a display of the object
# (operators are applied right to left). The group is run 4 times.
hio_er_show = ShowCDI(fig_num=1) * ER() ** 20 * HIO() ** 50
cdi = hio_er_show ** 4 * cdi

# Banner announcing the second, harder reconstruction (one print, two lines)
print("\n======================================================================\n",
      "This was too easy - start again from a loose support !\n", sep="\n")

# ======================== Try again from a loose support ========================
# Reload the known support, blur it (Gaussian filter, sigma=4) and threshold the
# result at 0.2 to obtain the loose starting support.
raw_support = np.flipud(fabio.open("data/mask_logo5mu_20sec.edf").data).astype(np.float32)
support = gaussian_filter(raw_support, 4) > 0.2
cdi = CDI(fftshift(iobs), obj=None, support=fftshift(support), mask=fftshift(mask),
          wavelength=1e-10, pixel_size_detector=55e-6)

# Initial scaling of the object, required when a mask is used
cdi = ScaleObj(method='F') * cdi

# 4 groups of 50 HIO cycles (200 in total), displaying the object after each group
hio_show = ShowCDI(fig_num=1) * HIO() ** 50
cdi = hio_show ** 4 * cdi

# Compute the log-likelihood. The result of the operator chain is discarded and
# cdi is read right after, so the operators presumably act on cdi in place.
IFT() * LLK() * FT() * cdi
llk_n = cdi.get_llk(noise='poisson')
print("LLK_n = %8.3f" % llk_n)

# 20 batches of HIO + ER cycles, each followed by a support update and a display.
# The smoothing width passed to the support update decays from 2.5 (i=0) towards 0.5.
# The cycle counts are named so the progress label below cannot drift from them.
nb_hio, nb_er = 40, 5
for i in range(20):
    # Support update operator, rebuilt every batch because smooth_width changes
    s = 0.5 + 2 * np.exp(-i / 4)
    sup = SupportUpdate(threshold_relative=0.25, smooth_width=s, force_shrink=False)

    # nb_hio cycles of HIO, then nb_er cycles of ER (operators apply right to left)
    cdi = ER() ** nb_er * HIO() ** nb_hio * cdi

    # Update support & display current object & diffraction
    cdi = ShowCDI(fig_num=1) * sup * cdi

    # Log-likelihood of the current state, labelled with the number of HIO+ER
    # cycles completed so far in this loop (i.e. including the batch just run)
    IFT() * LLK() * FT() * cdi
    print("HIO+ER #%3d: LLK_n = %8.3f" % ((i + 1) * (nb_hio + nb_er), cdi.get_llk(noise='poisson')))

# Finish with ML: 5 batches of nb_ml cycles, displaying the object and reporting
# the log-likelihood after each batch. nb_ml is named so the progress label
# cannot drift from the nb_cycle actually passed to ML.
nb_ml = 20
for i in range(5):
    cdi = ML(reg_fac=1e-2, nb_cycle=nb_ml) * cdi

    # Display current object & diffraction
    cdi = ShowCDI(fig_num=1) * cdi

    # Compute log-likelihood, labelled with the number of ML cycles completed
    # so far (i.e. including the batch just run)
    IFT() * LLK() * FT() * cdi
    print("ML #%3d: LLK_n = %8.3f" % ((i + 1) * nb_ml, cdi.get_llk(noise='poisson')))
