#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import timeit
import numpy as np

from .runner import CDIRunner, CDIRunnerException, CDIRunnerScan, params_generic
from pynx.cdi import *

# id10-specific default parameters, overriding the generic ones (see `params` below)
params_beamline = {'auto_center_resize': False, 'support_type': 'circle', 'detwin': False, 'support_size': 50,
                   'nb_raar': 0, 'nb_hio': 600, 'nb_er': 200, 'nb_ml': 0, 'instrument': 'ESRF id10', 'mask': 'zero'}

# Help text appended to the runner help (CDIRunnerID10.__init__).
# The defaults listed here must be kept in sync with params_beamline above.
helptext_beamline = """
Script to perform a CDI reconstruction of data from id10@ESRF.
command-line/file parameters arguments: (all keywords are case-insensitive):

    Specific defaults for this script:
        auto_center_resize = False
        detwin = False
        instrument = ESRF id10
        mask = zero
        nb_raar = 0
        nb_hio = 600
        nb_er = 200
        nb_ml = 0
        support_size = 50
        support_type = circle
"""

# Default parameters for this script: the generic defaults, overridden by the id10-specific ones.
# dict.update() is used instead of an explicit loop so that no loop variables leak into the module namespace.
params = params_generic.copy()
params.update(params_beamline)


class CDIRunnerScanID10(CDIRunnerScan):
    """
    Process a single id10@ESRF scan.

    Compared to CDIRunnerScan, the default reconstruction (used when the 'algorithm' parameter is None)
    is a fixed RAAR -> HIO -> ER -> ML sequence, and the initial object is scaled after the CDI object
    has been prepared.
    """

    def __init__(self, params, scan):
        super(CDIRunnerScanID10, self).__init__(params, scan)

    def run(self):
        """
        Main work function. Runs the reconstruction algorithms according to self.params.

        If self.params['algorithm'] is not None, it is passed to self.run_algorithm() and nothing else
        is done. Otherwise the following sequence is applied to self.cdi:

        - nb_raar cycles of RAAR followed by nb_hio cycles of HIO. If support_update_period > 0, a support
          update is done every support_update_period cycles (except at cycle 0). If detwin is True, the
          support update falling at about 1/3 of the RAAR+HIO cycles is replaced by 10 cycles of
          DetwinRAAR (during the RAAR part) or DetwinHIO (during the HIO part).
        - if psf is True, 10 cycles of EstimatePSF every 50 cycles, starting at 75% of the RAAR+HIO cycles,
          and during ER at cycles 10, 60, 110, ...
        - nb_er cycles of ER.
        - nb_ml cycles of ML, skipped if psf is True (the PSF is not implemented in ML).
        - finally FreePU().

        :return: nothing. The result is stored in self.cdi
        """
        if self.params['algorithm'] is not None:
            self.run_algorithm(self.params['algorithm'])
            return

        nb_raar = self.params['nb_raar']
        nb_hio = self.params['nb_hio']
        nb_er = self.params['nb_er']
        nb_ml = self.params['nb_ml']
        positivity = self.params['positivity']
        support_only_shrink = self.params['support_only_shrink']
        beta = self.params['beta']
        detwin = self.params['detwin']
        psf = self.params['psf']  # Experimental, to take into account partial coherence
        live_plot = self.params['live_plot']
        support_update_period = self.params['support_update_period']
        support_smooth_width_begin = self.params['support_smooth_width_begin']
        support_smooth_width_end = self.params['support_smooth_width_end']
        support_threshold = self.params['support_threshold']
        support_threshold_method = self.params['support_threshold_method']
        support_post_expand = self.params['support_post_expand']
        verbose = self.params['verbose']
        if live_plot:
            # When enabled, live plotting uses the same value as verbose (also passed as calc_llk)
            live_plot = verbose
        print(nb_raar, nb_hio, nb_er, nb_ml, beta, positivity, detwin, support_only_shrink)

        nb_total = nb_raar + nb_hio
        # First RAAR/HIO cycle with PSF estimation: ceil(0.75 * nb_total), computed with integers.
        # The previous float test (i - 0.75 * nb_total) % 50 == 0 could never be true when
        # 0.75 * nb_total was not an integer, silently disabling the PSF estimation.
        psf_start = (3 * nb_total + 3) // 4
        # Support smoothing width parameters: (begin, end, number of cycles). Loop-invariant.
        smooth_width = (support_smooth_width_begin, support_smooth_width_end, nb_total)

        for i in range(nb_total):
            if support_update_period > 0 and i > 0 and i % support_update_period == 0:
                if detwin and i // support_update_period == nb_total // (3 * support_update_period):
                    # Detwinning replaces the support update at ~1/3 of the RAAR+HIO cycles
                    print("Detwinning with 10 cycles of RAAR or HIO and a half-support")
                    if i < nb_raar:
                        self.cdi = DetwinRAAR(beta=beta, positivity=False, nb_cycle=10) * self.cdi
                    else:
                        self.cdi = DetwinHIO(beta=beta, positivity=False, nb_cycle=10) * self.cdi
                else:
                    sup = SupportUpdate(threshold_relative=support_threshold, smooth_width=smooth_width,
                                        force_shrink=support_only_shrink,
                                        method=support_threshold_method, post_expand=support_post_expand)
                    self.cdi = sup * self.cdi

            if psf and i >= psf_start and (i - psf_start) % 50 == 0:
                print("Evaluating point-spread-function (partial coherence correction)")
                self.cdi = EstimatePSF() ** 10 * self.cdi

            # RAAR cycles come first, then HIO
            if i < nb_raar:
                self.cdi = RAAR(beta=beta, positivity=positivity, calc_llk=verbose, show_cdi=live_plot) * self.cdi
            else:
                self.cdi = HIO(beta=beta, positivity=positivity, calc_llk=verbose, show_cdi=live_plot) * self.cdi

        for i in range(nb_er):
            if psf and i % 50 == 10:
                print("Evaluating point-spread-function (partial coherence correction)")
                self.cdi = EstimatePSF() ** 10 * self.cdi

            self.cdi = ER(positivity=positivity, calc_llk=verbose, show_cdi=live_plot) * self.cdi

        if nb_ml > 0:
            if psf:
                print("ML deactivated - PSF is unimplemented in ML")
            else:
                print("Finishing with %d cycles of ML" % (nb_ml))
                self.cdi = ML(nb_cycle=nb_ml, calc_llk=verbose, show_cdi=live_plot) * self.cdi
        self.cdi = FreePU() * self.cdi

    def prepare_cdi(self):
        """
        Prepare CDI object from input data, using the parent class, then scale the initial object
        with ScaleObj(method='F').

        Returns: nothing. Creates self.cdi object
        """
        super(CDIRunnerScanID10, self).prepare_cdi()
        # Scale initial object
        self.cdi = ScaleObj(method='F') * self.cdi


class CDIRunnerID10(CDIRunner):
    """
    Class to process a series of scans with a series of algorithms, given from the command-line
    """

    def __init__(self, argv, params, ptycho_runner_scan_class):
        """
        :param argv: the command-line arguments (presumably sys.argv - confirm against the caller),
            passed to CDIRunner
        :param params: the default parameters dictionary, passed to CDIRunner
        :param ptycho_runner_scan_class: the class used to process each scan, passed to CDIRunner.
            NOTE(review): despite the name, this is presumably a CDI scan runner class such as
            CDIRunnerScanID10 - the parameter name is kept for compatibility with keyword callers.
        """
        super(CDIRunnerID10, self).__init__(argv, params, ptycho_runner_scan_class)
        # Append the id10-specific defaults to the help text
        self.help_text += helptext_beamline

    # No id10-specific keyword is parsed, so parse_arg_beamline() is not overridden here.

    def check_params_beamline(self):
        """
        Check if self.params includes a minimal set of valid parameters, specific to a beamline.

        Returns: Nothing.
        Raises: CDIRunnerException if no data was provided (self.params['data'] is None).
        """
        if self.params['data'] is None:
            raise CDIRunnerException('No data provided. Need at least data=..., or a parameters input file')
