########################################################################
#
# Example of the ptychographic reconstruction using OpenCL on simulated data
# (c) ESRF
# Authors: Vincent Favre-Nicolin <favre@esrf.fr>
#
########################################################################

from pylab import *
from pynx.ptycho import simulation, shape
from pynx.ptycho.old import clptycho
import pyopencl as cl

##################
# Simulation of the ptychographic data:
n = 256  # the probe arrays below are created with shape (n, n)
obj_info = {'type': 'phase_ampl', 'phase_stretch': pi / 2, 'alpha_win': .2}

# Probe used for the simulation. Only one description can be active at a time:
# assigning probe_info twice in a row silently discards the first value, so the
# alternative focused-aperture probe is kept as a comment. Swap the two lines to use it.
# probe_info = {'type': 'focus', 'aperture': (30e-6, 30e-6), 'focal_length': .08, 'defocus': 100e-6, 'shape': (n, n)}
probe_info = {'type': 'gauss', 'sigma_pix': (40, 40), 'defocus': 100e-6, 'shape': (n, n)}

# 50 scan positions correspond to 4 turns, 78 to 5 turns, 113 to 6 turns
scan_info = {'type': 'spiral', 'scan_step_pix': 30, 'n_scans': 128}
data_info = {'num_phot_max': 1e9, 'bg': 0, 'wavelength': 1.5e-10, 'detector_distance': 1, 'detector_pixel_size': 55e-6,
             'noise': 'poisson'}

# Initialisation of the simulation with the parameters specified above.
# A specific <object>, <probe> or <scan> can instead be passed directly as:
# s = simulation.Simulation(obj=<object>, probe=<probe>, scan=<scan>)
# omitting obj_info, probe_info or scan_info (or passing it as an empty dictionary "{}")
s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info, data_info=data_info)

# Data simulation.
# For visual checks: probe.show(), obj.show() and scan.show() display the individual components,
# and s.show_illumination_sum() visualises the integrated total coverage of the beam
# (presumably probe/obj/scan are the attributes of s -- TODO confirm).
s.make_data()

# Scan positions, unpacked as (y, x), and the simulated amplitudes.
posy, posx = s.scan.values
ampl = s.amplitude.values  # square root of the measured diffraction pattern intensity

##################
# Evaluation:
# Size of the reconstructed object (obj), computed from the scan positions and
# from ampl.shape[1:], i.e. the shape of ampl without its leading axis
# (presumably one frame per scan position -- TODO confirm).
frame_shape = ampl.shape[1:]
nyo, nxo = shape.calc_obj_shape(posy, posx, frame_shape)

# Starting guess for the object.
# Alternative (flat object): obj_init_info = {'type': 'flat', 'shape': (nyo, nxo)}
obj_init_info = {
    'type': 'random',
    'range': (0, 1, 0, 0.5),
    'shape': (nyo, nxo),
}

# Starting guess for the probe.
probe_init_info = {
    'type': 'focus',
    'aperture': (20e-6, 20e-6),
    'focal_length': .08,
    'defocus': 50e-6,
    'shape': (n, n),
}

# This rebinds data_info: from here on it only holds the three entries below.
data_info = {
    'wavelength': 1.5e-10,
    'detector_distance': 1,
    'detector_pixel_size': 55e-6,
}

# A second Simulation instance is used only to build the initial object and probe;
# no scan description is given and make_data() is not called on it.
init = simulation.Simulation(
    obj_info=obj_init_info,
    probe_info=probe_init_info,
    data_info=data_info,
)
init.make_obj()
init.make_probe()

# Selection of the OpenCL device used for the reconstruction.
gpu = None  # Use 'Tesla' or 'GTX' or 'Iris' etc... to choose a GPU
if gpu is None:
    # Nothing requested: let pyopencl create a context (interactive choice of
    # GPU & OpenCL device) and take the first device of that context.
    ctx = cl.create_some_context()
    gpu = ctx.devices[0]

doplot = True
lambdaz = 1.5e-10
pixel_size_object = 1e-8

# Reconstruction object: observed intensities are the squared amplitudes.
# NOTE(review): the probe passed here is a copy of the simulated probe (s.probe),
# while the object comes from init; init.probe is not used in this script --
# confirm this is intended.
p = clptycho.Ptycho2D(
    iobs=ampl ** 2,
    positions=(posy, posx),
    probe=s.probe.values.copy(),
    obj=init.obj.values.copy(),
    opencl_device=gpu,
    lambdaz=lambdaz,
    pixel_size_object=pixel_size_object,
)

# Reporting options shared by every stage below.
stage_options = {'verbose': 20, 'doplot': doplot}

# Run on GPU, first update only the object and keep the probe fixed...
p.run_difference_map(40, update_object=True, update_probe=False, **stage_options)
# ...then update both object and probe,
p.run_difference_map(40, update_object=True, update_probe=True, **stage_options)
# followed by alternating projections with the probe kept fixed,
p.run_alternating_projection(20, update_object=True, update_probe=False, **stage_options)
# and finally the run_ml_poisson stage updating both object and probe.
p.run_ml_poisson(50, update_object=True, update_probe=True, **stage_options)

use_probe_modes = True  # set to False to skip the multi-mode probe stage below
if use_probe_modes:
    # Optional alternative, kept from the original example (regularised run):
    # reg_fac = 1e-4
    # p.run_ml_poisson(100, reg_fac=reg_fac, update_object=True, update_probe=True, verbose=20, doplot=doplot)

    nbprobe = 3  # number of probe modes
    ny, nx = p.probe.shape
    probe0 = np.empty((nbprobe, ny, nx), dtype=np.complex64)
    for mode in range(nbprobe):
        # Each mode is the current probe multiplied by random factors drawn
        # uniformly in [0.8, 1.2) per pixel, and scaled by 0.1 ** mode.
        probe0[mode] = p.probe * np.random.uniform(0.8, 1.2, (ny, nx)) * 0.1 ** mode

    # Replace the single probe by the stack of modes.
    p.probe = probe0.copy()

    # Same reporting options for the three stages of this multi-mode run.
    mode_options = {'verbose': 20, 'doplot': doplot}
    p.run_difference_map(40, update_object=True, update_probe=True, **mode_options)
    p.run_alternating_projection(20, update_object=True, update_probe=False, **mode_options)
    p.run_ml_poisson(100, reg_fac=None, update_object=True, update_probe=True, **mode_options)

# Access to the working object: first entry of the (private) _cl_workers sequence of p
w = p._cl_workers[0]
