########################################################################
#
# Example of the ptychographic reconstruction using OpenCL on simulated data
# (c) ESRF
# Authors: Vincent Favre-Nicolin <favre@esrf.fr>
#
########################################################################

from pylab import *
# OLD
from pynx.ptycho import simulation, shape
from pynx.ptycho.old import clptycho
# NEW with operators
from pynx.ptycho.ptycho import *
from pynx.ptycho.operator import *  # Use CUDA > OpenCL > CPU operators, as available

##################
# Simulation of the ptychographic data:
n = 256  # the probe is created with shape (n, n)
detector_distance = np.float32(1)
wavelength = np.float32(1.5e-10)
# Single definition of the detector pixel size, shared by the data simulation
# and by the initial-guess simulation below so that both stay consistent.
detector_pixel_size = 55e-6
nb_frame = 16  # number of scan positions requested through scan_info['n_scans']
obj_info = {'type': 'phase_ampl', 'phase_stretch': pi / 2, 'alpha_win': .2}
# Alternative probe (disabled): a focused beam instead of the gaussian one used below.
# probe_info = {'type': 'focus', 'aperture': (30e-6, 30e-6), 'focal_length': .08, 'defocus': 100e-6, 'shape': (n, n)}
probe_info = {'type': 'gauss', 'sigma_pix': (40, 40), 'defocus': 100e-6, 'shape': (n, n)}

# 50 scan positions correspond to 4 turns, 78 to 5 turns, 113 to 6 turns
scan_info = {'type': 'spiral', 'scan_step_pix': 60, 'n_scans': nb_frame}
data_info = {'num_phot_max': 1e9, 'bg': 0, 'wavelength': wavelength, 'detector_distance': detector_distance,
             'detector_pixel_size': detector_pixel_size, 'noise': 'poisson'}

# Initialisation of the simulation with the specified parameters. A specific <object>, <probe> or <scan>
# can be passed instead, as:
# s = simulation.Simulation(obj=<object>, probe=<probe>, scan=<scan>)
# omitting obj_info, probe_info or scan_info (or passing it as an empty dictionary "{}")
s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info, data_info=data_info)

# Data simulation. probe.show(), obj.show() and scan.show() can be used for visualisation, and
# s.show_illumination_sum() will visualise the integrated total coverage of the beam.
s.make_data()

posy, posx = s.scan.values
ampl = s.amplitude.values  # square root of the measured diffraction pattern intensity

##################
# Evaluation:
# Size of the reconstructed object (obj), computed from the scan positions and the
# shape of a single frame (ampl.shape[1:], i.e. ampl without its first axis)
nyo, nxo = shape.calc_obj_shape(posy, posx, ampl.shape[1:])

# Initial object
# obj_init_info = {'type': 'flat', 'shape': (nyo, nxo)}
obj_init_info = {'type': 'random', 'range': (0, 1, 0, 0.5), 'shape': (nyo, nxo)}
# Initial probe
probe_init_info = {'type': 'focus', 'aperture': (20e-6, 20e-6), 'focal_length': .08, 'defocus': 50e-6, 'shape': (n, n)}
# Reuse the wavelength, distance and pixel size of the simulated data rather than repeating
# literal values, so the starting guess cannot silently drift from the experiment geometry.
data_info = {'wavelength': wavelength, 'detector_distance': detector_distance,
             'detector_pixel_size': detector_pixel_size}
init = simulation.Simulation(obj_info=obj_init_info, probe_info=probe_init_info, data_info=data_info)

# Only the object and probe are generated for the initial guess (no scan, no data).
init.make_obj()
init.make_probe()

doplot = True  # forwarded as doplot= to the old-API reconstruction calls

##################################################### END COMMON INITIALIZATION PART ##############################

##### Ptycho with Operators #####
# The observed intensity is the square of the simulated amplitude.
data = PtychoData(iobs=ampl ** 2, positions=(posy, posx), detector_distance=detector_distance, mask=None,
                  pixel_size_detector=55e-6, wavelength=wavelength)

# NOTE(review): the probe starts from the simulated probe (s.probe.values), not from
# init.probe, while the object starts from the random initial guess - confirm this is intended.
p = ScaleObjProbe() * Ptycho(probe=s.probe.values, obj=init.obj.values, data=data, background=None)
# p = ShowObjProbe(2) * p
# p = ShowObjProbe(3) * ScaleObjProbe() * p
# p = ShowObjProbe(1) * AP(update_object=True, update_probe=False) ** 20 * p
# p = ShowObjProbe(1) * AP(update_object=True, update_probe=True) ** 20 * p
# DM operator raised to the power 100, mirroring run_difference_map(100, ...) of the old API below.
p = DM(update_object=True, update_probe=True) ** 100 * p

# Set to False to skip the comparison with the old (non-operator) API. Every use of `pp`
# lives inside this guard, so disabling it cannot trigger a NameError further down.
run_old_api = True
if run_old_api:
    ##### Ptycho old API #####
    # The old API takes the geometry from the operator-based object so both runs share it.
    lambdaz = p.data.detector_distance * p.data.wavelength
    pixel_size_object = p.pixel_size_object
    # Copies are passed so the old API does not share its probe/object arrays with `s` and `init`.
    pp = clptycho.Ptycho2D(iobs=ampl ** 2, positions=(posy, posx), probe=s.probe.values.copy(),
                           obj=init.obj.values.copy(),
                           opencl_device=default_processing_unit.cl_device, lambdaz=lambdaz,
                           pixel_size_object=pixel_size_object)

    # pp.run_difference_map(40, update_object=True, update_probe=False, verbose=20, doplot=doplot)
    pp.run_difference_map(100, update_object=True, update_probe=True, verbose=20, doplot=doplot)
    # pp.run_alternating_projection(20, update_object=True, update_probe=False, verbose=20, doplot=doplot)
    # pp.run_alternating_projection(20, update_object=True, update_probe=True, verbose=20, doplot=doplot)
    # First OpenCL worker, kept for the (commented-out) old/new comparison plots at the end of the file.
    w = pp._cl_workers[0]

    pp.run_ml_poisson(100, update_object=True, update_probe=True, verbose=20, doplot=doplot)

# ML operator raised to the power 100, mirroring run_ml_poisson(100, ...) of the old API.
p = ML(update_object=True, update_probe=True, show_obj_probe=10, calc_llk=10) ** 100 * p
# wpsi1 = w.psi1[0]
# ppsi1 = p.psi1[0,0]
# figure(3,figsize=(12,5)); clf() ; subplot(121); imshow(np.log10(abs(wpsi1))); title("old") ; subplot(122) ; imshow(np.log10(abs(ppsi1))); title("new")
# figure(4); clf() ; imshow(abs(wpsi1/ppsi1)) ; colorbar()  # ,vmin=0.9,vmax=1.1
# print(abs(p.probe1).mean(), abs(w.probe1).mean())
# print(abs(p.obj1).mean(), abs(w.obj1).mean())
