########################################################################
#
# Example of the ptychographic reconstruction of simulated data with refinement of the scan positions.
# (c) Univ. Grenoble Alpes
# (c) Fondation Nanosciences
# (c) CEA Grenoble / INAC
# Authors: Vincent Favre-Nicolin, Ondrej Mandula
#
########################################################################

from pylab import *
from pynx.ptycho import simulation, shape
from pynx.ptycho.old import clptycho
import pyopencl as cl

savethis = True  # flag currently not referenced elsewhere in this script

# --- Simulation of the ptychography data ---
# Object: phase/amplitude test pattern. A 2D complex image can be used instead by
# passing it directly, e.g. simulation.Simulation(obj=img, ...).
obj_info = dict(type='phase_ampl', phase_stretch=pi / 5, alpha_win=.2)
# Probe: disc of radius 50 pixels in a 256x256 frame ('FZP', 'gauss' or 'flat' are other possibilities).
probe_info = dict(type='disc', radius_pix=50, shape=(256, 256))
# Scan: spiral with a 30-pixel step. 50 positions correspond to 4 turns, 78 to 5 turns, 113 to 6 turns.
scan_info = dict(type='spiral', scan_step_pix=30, n_scans=50)
# Data: direct-space pixel size, maximum photon count and background level.
data_info = dict(pix_size_direct_nm=10, num_phot_max=1e6, bg=0)

# Build the simulation from the parameters above. A specific <object>, <probe> or <scan>
# can be supplied instead via simulation.Simulation(obj=<object>, probe=<probe>, scan=<scan>),
# omitting the matching *_info argument (or passing an empty dictionary "{}").
s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info,
                          scan_info=scan_info, data_info=data_info)

# Generate the data. For inspection, s.probe.show(), s.obj.show() and s.scan.show() display the
# inputs, and s.show_illumination_sum() visualises the integrated total coverage of the beam.
s.make_data()

pos = s.scan.values  # scan positions
ampl = s.amplitude.values  # square root of the measured diffraction pattern intensity

# --- OpenCL device used for the reconstruction ---
# Set gpu to a device name such as 'Tesla', 'GTX' or 'Iris' to pick a specific GPU.
# With None, pyopencl lets the user choose the platform/device, and the first
# device of the resulting context is used.
gpu = None
if gpu is None:
    ctx = cl.create_some_context()  # interactive choice of OpenCL platform & device
    gpu = ctx.devices[0]

posy, posx = pos

# TODO: add a random displacement to the true scan positions.
# sigPos scales standard-normal noise, i.e. it is the standard deviation of the displacement
# (presumably in pixels, like scan_step_pix). With 0 the exact positions are kept.
sigPos = 0
# x is drawn before y, so the random stream is consumed in a fixed order.
posx_rand, posy_rand = (coord + sigPos * np.random.randn(coord.size) for coord in (posx, posy))
pos_rand = (posy_rand, posx_rand)

# Size of the reconstructed object, derived from the (possibly perturbed) scan positions
# and the size of one diffraction frame (ampl.shape[1:]).
ny, nx = shape.calc_obj_shape(posy_rand, posx_rand, ampl.shape[1:])

# Extra pixels added along each axis of the initial object. The margins leave room
# for later optimisation of the scan positions.
obj_margin = 40

# Initial object: flat. The shape is given in (ny, nx) order, the same order that
# calc_obj_shape returns. This assumes 'shape' follows the numpy (rows, columns)
# convention, as for probe_info['shape'] (TODO confirm).
obj_init_info = {'type': 'flat', 'shape': (ny + obj_margin, nx + obj_margin)}

# Initial probe. Here the exact simulated probe is used (the case where the probe is known).
# To start from a guess instead, use e.g.:
# probe_init_info = {'type': 'disc', 'shape': s.probe.values.shape, 'radius_pix': 40}
probe0 = s.probe.values
init = simulation.Simulation(obj_info=obj_init_info, probe_info=probe_info)
init.make_obj()
# init.make_probe()
# Build the Ptycho2D reconstruction object from:
#  - the observed intensities (ampl ** 2),
#  - the (possibly perturbed) scan positions,
#  - the initial probe (probe0),
#  - the initial object (init.obj.values).
# Copies are passed, presumably so the simulated arrays are not modified in place.

doplot = True

p = clptycho.Ptycho2D(iobs=ampl ** 2, positions=pos_rand, probe=probe0.copy(),
                      obj=init.obj.values.copy(), opencl_device=gpu)

# Reconstruction schedule, 100 cycles per stage:
#   1. difference map, object only (probe kept fixed),
#   2. difference map, object and probe,
#   3. run_ml_poisson (presumably maximum likelihood with Poisson noise), object and probe.
for run_algorithm, refine_probe in ((p.run_difference_map, False),
                                    (p.run_difference_map, True),
                                    (p.run_ml_poisson, True)):
    run_algorithm(100, update_object=True, update_probe=refine_probe, verbose=20, doplot=doplot)

# TODO: refinement of the scan positions.
