########################################################################
#
# Example of ptychographic reconstruction using OpenCL on simulated data
# (c) ESRF
# Authors: Vincent Favre-Nicolin <favre@esrf.fr>
#
########################################################################

from pylab import *
import numpy as np
from pynx.ptycho import simulation, shape
from pynx.ptycho.old import clptycho
import pyopencl as cl

##################
# Simulation of the ptychography data
n = 256  # size in pixels of the square probe array (n, n)

obj_info = {'type': 'phase_ampl', 'phase_stretch': np.pi / 2, 'alpha_win': .2}
probe_info = {
    'type': 'focus',
    'aperture': (60e-6, 200e-6),
    'focal_length': .08,
    'defocus': 100e-6,
    'shape': (n, n),
}

# Spiral scan: 50 positions give 4 turns, 78 give 5 turns, 113 give 6 turns
scan_info = {'type': 'spiral', 'scan_step_pix': 30, 'n_scans': 100}
data_info = {
    'num_phot_max': 1e9,
    'bg': 0,
    'wavelength': 1.5e-10,
    'detector_dist': 1,
    'detector_pixel_size': 55e-6,
    'noise': None,
}

# Build the simulation from the parameter dictionaries above.
# Specific <object>, <probe> or <scan> positions can be given instead, e.g.
#   s = simulation.Simulation(obj=<object>, probe=<probe>, scan=<scan>)
# omitting obj_info, probe_info or scan_info (or passing an empty dictionary "{}").
s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info, data_info=data_info)

# Generate the data. For inspection: probe.show(), obj.show(), scan.show(), and
# s.show_illumination_sum() to visualise the integrated total coverage of the beam.
s.make_data()

posy, posx = s.scan.values
# Simulated diffraction intensities: the squared simulated amplitudes, stacked as
# (nb_frames, ny, nx) frames (iterated per frame / shape[1:] used as frame shape)
iobs = s.amplitude.values ** 2

# Add a smooth background: a Gaussian plus a constant offset of 1/4 of its peak,
# scaled so that its sum equals 5% of the mean pixel value of the frame-summed data.
# NOTE(review): an earlier comment described this as "20% of average frame intensity",
# which this scaling does not produce - confirm the intended background level.
iobs_mean = iobs.sum(axis=0).mean()
background = simulation.gauss2D((n, n), sigma=(n / 4, n / 4))
background += background.max() / 4
background *= iobs_mean * 0.05 / background.sum()

# Use Poisson statistics for the detector, on all frames at once: the (n, n)
# background broadcasts over the frame axis, and the integer counts are written
# back in place into iobs (keeping its dtype).
iobs[:] = np.random.poisson(iobs + background)

##################
# Evaluation:
# Size of the reconstructed object, returned in numpy (rows, columns) = (ny, nx)
# order from the scan positions and the frame shape
ny, nx = shape.calc_obj_shape(posy, posx, iobs.shape[1:])

# Initial object guess. 'shape' uses the same (ny, nx) order as the unpacking above,
# so a non-square object is not transposed.
# NOTE(review): assumes Simulation uses 'shape' directly as the array shape - confirm.
# Alternative flat start:
# obj_init_info = {'type': 'flat', 'shape': (ny, nx)}
obj_init_info = {'type': 'random', 'range': (0, 1, 0, 0.5), 'shape': (ny, nx)}
# Initial probe guess: a focused beam whose aperture and defocus differ from the
# values used to simulate the data
probe_init_info = {'type': 'focus', 'aperture': (50e-6, 180e-6), 'focal_length': .08, 'defocus': 50e-6, 'shape': (n, n)}
# Rebinds data_info: only the geometry (no photon count / noise) is given for the guess
data_info = {'wavelength': 1.5e-10, 'detector_dist': 1, 'detector_pixel_size': 55e-6}
init = simulation.Simulation(obj_info=obj_init_info, probe_info=probe_init_info, data_info=data_info)

init.make_obj()
init.make_probe()

# Select the OpenCL device used for the reconstruction
gpu = None  # Use 'Tesla' or 'GTX' or 'Iris' etc... to choose a GPU
if gpu is None:
    ctx = cl.create_some_context()  # interactive choice of GPU & OpenCL device
    gpu = ctx.devices[0]

# Reconstruction parameters
doplot = True  # plotting flag forwarded to the reconstruction runs
# Left as None here; 1.5e-10 presumably corresponds to wavelength (1.5e-10) x detector
# distance (1) from the simulation parameters - confirm against Ptycho2D docs
lambdaz = None  # 1.5e-10
pixel_size_object = 1e-8

# Start without a background model (background=None); it is only updated in later runs.
# The probe is seeded with the initial guess built from probe_init_info (init.probe),
# not with the probe used to simulate the data (s.probe); pass s.probe.values.copy()
# instead to start from the true probe.
p = clptycho.Ptycho2D(iobs=iobs, positions=(posy, posx), probe=init.probe.values.copy(), obj=init.obj.values.copy(),
                      opencl_device=gpu, lambdaz=lambdaz, pixel_size_object=pixel_size_object, background=None)

# Keyword arguments shared by every run below: update both object and probe
run_kw = {'update_object': True, 'update_probe': True, 'verbose': 10, 'doplot': doplot}

# 100 iterations each: difference map, then alternating projections (AP) and Poisson
# maximum-likelihood (ML) runs; the background is updated during AP, kept fixed during ML
p.run_difference_map(100, **run_kw)
p.run_alternating_projection(100, **run_kw)
p.run_ml_poisson(100, update_background=False, **run_kw)
p.run_alternating_projection(100, update_background=True, **run_kw)
p.run_ml_poisson(100, update_background=False, **run_kw)
p.run_alternating_projection(100, update_background=True, **run_kw)
# ML optimisation of background is not very efficient yet
p.run_ml_poisson(100, update_background=True, **run_kw)
