#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
from __future__ import division

import numpy as np
from pynx.ptycho import simulation

from .runner import PtychoRunner, PtychoRunnerScan, PtychoRunnerException, params_generic

# Simulation-specific usage notes. PtychoRunnerSimul.__init__ appends this text to the
# generic runner help text (self.help_text). It is a runtime string shown to users, so
# edit its content deliberately.
helptext_beamline = """
Script to perform a ptychography analysis on a simulated dataset

Examples:
    pynx-simulationpty.py frame_nb=128 frame_size=256 algorithm=analysis,ML**100,DM**200,nbprobe=2,probe=1 
                          saveplot liveplot

command-line arguments:
    frame_nb: number of simulated frames (will be generated along a spiral)
    
    frame_size: size along x and y of each frame
"""

# Default parameter values specific to the simulation "instrument".
# NOTE(review): 'probe' and 'defocus' are not read by the simulation itself (it uses a
# hard-coded probe); presumably they describe the initial probe used for the reconstruction.
params_beamline = {
    'frame_nb': 128,
    'frame_size': 256,
    'probe': 'focus,60e-6x200e-6,0.09',
    'defocus': 100e-6,
    'algorithm': 'ML**100,DM**200,probe=1',
    'instrument': 'simulation',
    'nrj': 8,
    'detectordistance': 1,
    'pixelsize': 55e-6,
}

# Work on a copy so that the shared generic defaults are left untouched,
# then let the simulation-specific values override the generic ones.
params = params_generic.copy()
params.update(params_beamline)


class PtychoRunnerScanSimul(PtychoRunnerScan):
    """
    Scan runner which generates a simulated ptychographic dataset instead of loading one.
    """

    def __init__(self, params, scan):
        super(PtychoRunnerScanSimul, self).__init__(params, scan)

    def load_data(self):
        """
        Simulate a ptychographic dataset and store it as this scan's data.

        The simulation is configured from self.params['frame_nb'], ['frame_size'],
        ['pixelsize'], ['nrj'] and ['detectordistance']. The object, probe, scan step and
        photon count are fixed here and are not taken from self.params.

        Sets:
          - self.x, self.y: scan positions, converted from pixels using the object pixel size
          - self.raw_data: squared simulated amplitudes
        then calls self.load_data_post_process().
        """
        cfg = self.params
        n_frames = cfg['frame_nb']
        frame_size = cfg['frame_size']
        det_pixel_size = cfg['pixelsize']
        # 12.3984e-10 is h*c in keV.m -- assumes 'nrj' is in keV, giving a wavelength in metres
        wavelength = 12.3984e-10 / cfg['nrj']
        distance = cfg['detectordistance']

        obj_info = dict(type='phase_ampl', phase_stretch=np.pi / 2, alpha_win=.2)
        probe_info = dict(type='focus', aperture=(50e-6, 120e-6), focal_length=.08, defocus=200e-6,
                          shape=(frame_size, frame_size))
        # Spiral scan: 50 positions correspond to 4 turns, 78 to 5 turns, 113 to 6 turns
        scan_info = dict(type='spiral', scan_step_pix=30, n_scans=n_frames)
        data_info = dict(num_phot_max=1e6, bg=0, wavelength=wavelength, detector_distance=distance,
                         detector_pixel_size=det_pixel_size, noise='poisson')

        sim = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info,
                                    data_info=data_info)
        sim.make_data()

        # Object-plane pixel size: wavelength * distance / (detector pixel size * frame size)
        object_pixel_size = wavelength * distance / det_pixel_size / frame_size

        # The simulated scan positions are expressed in pixels
        pix_x, pix_y = sim.scan.values
        self.x = pix_x * object_pixel_size
        self.y = pix_y * object_pixel_size

        amplitude = sim.amplitude.values
        self.raw_data = amplitude ** 2
        self.load_data_post_process()


class PtychoRunnerSimul(PtychoRunner):
    """
    Class to process a series of scans with a series of algorithms, given from the command-line,
    using simulated data rather than data recorded on a beamline.
    """

    def __init__(self, argv, params, ptycho_runner_scan_class):
        super(PtychoRunnerSimul, self).__init__(argv, params, ptycho_runner_scan_class)
        # Append the simulation-specific usage notes to the generic help text
        self.help_text += helptext_beamline

    def parse_arg_beamline(self, k, v):
        """
        Parse argument in a beamline-specific way. This function only parses single arguments.
        If an argument is recognized and interpreted, the corresponding value is added to self.params

        For the simulation, only 'frame_nb' and 'frame_size' are recognized; both are converted
        with int(). Presumably called by PtychoRunner for arguments it does not handle itself.

        Args:
            k: the argument name
            v: the argument value, as given on the command-line
        Returns:
            True if the argument is interpreted, False otherwise
        """
        if k in ['frame_nb', 'frame_size']:
            self.params[k] = int(v)
            return True
        return False

    def check_params_beamline(self):
        """
        Check if self.params includes a minimal set of valid parameters, specific to the simulation.

        Returns: Nothing.
        Raises:
            PtychoRunnerException: if 'frame_nb' or 'frame_size' is missing, cannot be converted
                to an integer, or is not strictly positive.
        """
        for k in ('frame_nb', 'frame_size'):
            v = self.params.get(k)
            try:
                valid = int(v) > 0
            except (TypeError, ValueError):
                # None (missing key) or a non-numeric value
                valid = False
            if not valid:
                # Fail early with a clear message rather than deep inside the simulation
                raise PtychoRunnerException("%s must be a strictly positive integer, got: %r" % (k, v))
