#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import sys
import timeit
import numpy as np
import fabio
from silx.io.specfile import SpecFile
from .runner import CDIRunner, CDIRunnerException, CDIRunnerScan, params_generic
from pynx.cdi import *

# id01-specific parameters and defaults, overriding/extending the generic CDI runner parameters
params_beamline = dict(detwin=True, support_size=None, nb_raar=600, nb_hio=0, nb_er=200, nb_ml=0,
                       instrument='ESRF id01', specfile=None, imgcounter='auto', imgname=None, scan=None,
                       auto_center_resize=True)

# Beamline-specific help text, appended to the generic runner help by CDIRunnerID01
helptext_beamline = """
Script to perform a CDI reconstruction of data from id01@ESRF.
command-line/file parameters arguments: (all keywords are case-insensitive):

    specfile=/some/dir/to/specfile.spec: path to specfile [mandatory, unless data= is used instead]

    scan=56: scan number in specfile [mandatory, unless data= is used instead].
             Alternatively a list or range of scans can be given:
                scan=12,23,45 or scan="range(12,25)" (note the quotes)

    imgcounter=mpx4inr: spec counter name for image number
                        [default='auto', will use either 'ei2minr' or 'mpx4inr']

    imgname=/dir/to/images/prefix%05d.edf.gz: images location, given as a pattern including
            the image number format
            [default: will be extracted from the ULIMA_eiger2M entry (if present) or else
            the ULIMA_mpx4 entry in the spec scan header]

    Specific defaults for this script:
        auto_center_resize = True
        detwin = True
        nb_raar = 600
        nb_hio = 0
        nb_er = 200
        nb_ml = 0
        support_size = None
"""

# Full parameter set for this script: generic runner parameters, overridden by the id01-specific ones
params = params_generic.copy()
params.update(params_beamline)


class CDIRunnerScanID01(CDIRunnerScan):
    """
    CDI runner for one scan (or several combined scans) recorded at ESRF id01.

    Data is loaded either with the generic loader, or - when both specfile= and scan= are given - from
    the spec file and the detector frames it references.
    """

    def __init__(self, params, scan):
        super(CDIRunnerScanID01, self).__init__(params, scan)

    def load_data(self):
        """
        Loads data. If no id01-specific keywords have been supplied, use the default data loading.

        Otherwise the image numbers are read from the spec scan(s) using params['imgcounter'], and the
        corresponding frames are read with fabio and stacked along the first axis of self.iobs.
        If self.scan is a string, it is split on '+' and the scans are combined. In that case the first
        image of a scan is skipped when the first value of its first data column equals the last value
        of the previous scan (duplicate position).
        The wavelength and detector distance are read from the spec header if they were not supplied.

        Raises:
            CDIRunnerException: if imgcounter='auto' and no known image counter is found, or if no
                image number could be read.
        """
        if self.params['specfile'] is None or self.scan is None:
            super(CDIRunnerScanID01, self).load_data()
            return

        if isinstance(self.scan, str):
            # do we combine several scans ?
            vs = self.scan.split('+')
        else:
            vs = [self.scan]

        # Parse the spec file once for all the (possibly combined) scans
        specfile = SpecFile(self.params['specfile'])
        imgn = None
        scan_motor_last_value = None
        h = None
        for scan in vs:
            scan = int(scan)
            s = specfile['%d.1' % scan]
            h = s.scan_header_dict

            if self.params['imgcounter'] == 'auto':
                if 'ei2minr' in s.labels:
                    self.params['imgcounter'] = 'ei2minr'
                elif 'mpx4inr' in s.labels:
                    self.params['imgcounter'] = 'mpx4inr'
                else:
                    raise CDIRunnerException("imgcounter=auto: neither 'ei2minr' nor 'mpx4inr' found in "
                                             "scan %d - please supply imgcounter=" % scan)
                print("Using image counter: %s" % (self.params['imgcounter']))

            if self.params['wavelength'] is None and 'UMONO' in h:
                nrj = float(h['UMONO'].split('mononrj=')[1].split('ke')[0])
                # lambda = hc / E, with hc ~= 12.398 keV.Angstrom
                # NOTE(review): stored in Angstrom, as printed below - confirm the unit expected by the runner
                w = 12.398 / nrj
                self.params['wavelength'] = w
                print("Reading nrj from spec data: nrj=%6.3fkeV, wavelength=%6.3fA" % (nrj, w))

            if self.params['detector_distance'] is None and 'UDETCALIB' in h:
                if 'det_distance_CC=' in h['UDETCALIB']:
                    # UDETCALIB cen_pix_x=18.347,cen_pix_y=278.971,pixperdeg=445.001,det_distance_CC=1402.175,det_distance_COM=1401.096,timestamp=20170926...
                    # NOTE(review): value is stored as written in the header - confirm its unit vs. the runner's
                    self.params['detector_distance'] = float(h['UDETCALIB'].split('stance_CC=')[1].split(',')[0])
                    print("Reading detector distance from spec data: %6.3fm" % self.params['detector_distance'])
                else:
                    print('No detector distance given. No det_distance_CC in UDETCALIB ??: %s' % (h['UDETCALIB']))

            # Image numbers for this scan (builtin int: np.int was only an alias and is gone from numpy)
            scan_imgn = s.data_column_by_name(self.params['imgcounter']).astype(int)
            if imgn is None:
                imgn = scan_imgn
            else:
                # Exclude the first image if the first motor position is the one at the end of the last
                # scan, to avoid recording the same position twice
                if s.data[0, 0] == scan_motor_last_value:
                    print("Scan %d: excluding first image at same position as previous one" % (scan))
                    i0 = 1
                else:
                    i0 = 0
                imgn = np.append(imgn, scan_imgn[i0:])
            # silx Scan.data is indexed [column, point]: remember the LAST position of the first column
            scan_motor_last_value = s.data[0, -1]

        if len(imgn) == 0:
            raise CDIRunnerException('No image number found in spec scan(s): %s' % str(self.scan))

        self.iobs = None
        t0 = timeit.default_timer()
        if self.params['imgname'] is None:
            # Image name pattern deduced from the header of the last scan read
            if 'ULIMA_eiger2M' in h:
                imgname = h['ULIMA_eiger2M'].strip().split('_eiger2M_')[0] + '_eiger2M_%05d.edf.gz'
                print("Using Eiger 2M detector images: %s" % (imgname))
            else:
                imgname = h['ULIMA_mpx4'].strip().split('_mpx4_')[0] + '_mpx4_%05d.edf.gz'
                print("Using Maxipix mpx4 detector images: %s" % (imgname))
        else:
            imgname = self.params['imgname']

        sys.stdout.write('Reading frames: ')
        for ii, i in enumerate(imgn):
            # Progress is based on the frame index, as image numbers of combined scans may not be contiguous
            if ii % 20 == 0:
                sys.stdout.write('%d ' % ii)
                sys.stdout.flush()
            frame = fabio.open(imgname % i).data
            if self.iobs is None:
                # Allocate the whole stack once the frame size and dtype are known
                self.iobs = np.empty((len(imgn), frame.shape[0], frame.shape[1]), dtype=frame.dtype)
            self.iobs[ii] = frame
        print("\n")
        dt = timeit.default_timer() - t0
        print('Time to read all frames: %4.1fs [%5.2f Mpixel/s]' % (dt, self.iobs.size / 1e6 / dt))

    def run(self):
        """
        Main work function. Runs RAAR then HIO, followed by ER and optionally ML, according to parameters.
        If params['algorithm'] is set, it is run through run_algorithm() instead.

        During RAAR/HIO, the support is updated every support_update_period cycles. If detwin is set,
        the update at about one third of the RAAR/HIO cycles is replaced by 10 cycles of detwinning.
        If psf is set, the point-spread function is estimated every 50 cycles during the last quarter
        of RAAR/HIO, and every 50 cycles (starting at cycle 10) during ER.

        :return: nothing. self.cdi is updated in place.
        """
        if self.params['algorithm'] is not None:
            self.run_algorithm(self.params['algorithm'])
            return
        nb_raar = self.params['nb_raar']
        nb_hio = self.params['nb_hio']
        nb_er = self.params['nb_er']
        nb_ml = self.params['nb_ml']
        positivity = self.params['positivity']
        support_only_shrink = self.params['support_only_shrink']
        beta = self.params['beta']
        detwin = self.params['detwin']
        psf = self.params['psf']  # Experimental, to take into account partial coherence
        live_plot = self.params['live_plot']
        support_update_period = self.params['support_update_period']
        support_smooth_width_begin = self.params['support_smooth_width_begin']
        support_smooth_width_end = self.params['support_smooth_width_end']
        support_threshold = self.params['support_threshold']
        support_threshold_method = self.params['support_threshold_method']
        support_post_expand = self.params['support_post_expand']
        verbose = self.params['verbose']
        if live_plot:
            live_plot = verbose
        print(nb_raar, nb_hio, nb_er, nb_ml, beta, positivity, detwin, support_only_shrink, live_plot, verbose)

        nb_total = nb_raar + nb_hio
        # First RAAR/HIO cycle of the last quarter. An integer is needed: the former float test
        # (i - 0.75 * nb_total) % 50 == 0 could never be true when nb_total is not a multiple of 4.
        psf_start = int(np.ceil(0.75 * nb_total))

        for i in range(0, nb_total):
            if support_update_period > 0:
                if i % support_update_period == 0 and i > 0:
                    if detwin and i // support_update_period == nb_total // (3 * support_update_period):
                        print("Detwinning with 10 cycles of RAAR or HIO and a half-support")
                        if i < nb_raar:
                            self.cdi = DetwinRAAR(beta=beta, positivity=False, nb_cycle=10) * self.cdi
                        else:
                            self.cdi = DetwinHIO(beta=beta, positivity=False, nb_cycle=10) * self.cdi
                    else:
                        s = support_smooth_width_begin, support_smooth_width_end, nb_total
                        sup = SupportUpdate(threshold_relative=support_threshold, smooth_width=s,
                                            force_shrink=support_only_shrink,
                                            method=support_threshold_method, post_expand=support_post_expand)
                        self.cdi = sup * self.cdi

            if psf and i >= psf_start and (i - psf_start) % 50 == 0:
                print("Evaluating point-spread-function (partial coherence correction)")
                self.cdi = EstimatePSF() ** 100 * self.cdi

            # RAAR cycles come first, then HIO
            if i < nb_raar:
                self.cdi = RAAR(beta=beta, positivity=positivity, calc_llk=verbose, show_cdi=live_plot) * self.cdi
            else:
                self.cdi = HIO(beta=beta, positivity=positivity, calc_llk=verbose, show_cdi=live_plot) * self.cdi

        for i in range(0, nb_er):
            if psf and i % 50 == 10:
                print("Evaluating point-spread-function (partial coherence correction)")
                self.cdi = EstimatePSF() ** 100 * self.cdi

            self.cdi = ER(positivity=positivity, calc_llk=verbose, show_cdi=live_plot) * self.cdi

        if nb_ml > 0:
            if psf:
                print("ML deactivated - PSF is unimplemented in ML")
            else:
                print("Finishing with %d cycles of ML" % (nb_ml))
                self.cdi = ML(nb_cycle=nb_ml, calc_llk=verbose) * self.cdi
        self.cdi = FreePU() * self.cdi

    def prepare_cdi(self):
        """
        Prepare CDI object from input data.

        Returns: nothing. Creates self.cdi object, with the initial object scaled using ScaleObj(method='F')

        """
        super(CDIRunnerScanID01, self).prepare_cdi()
        # Scale initial object
        self.cdi = ScaleObj(method='F') * self.cdi


class CDIRunnerID01(CDIRunner):
    """
    Command-line runner for id01 data: processes a series of scans with a series of algorithms.
    The id01-specific help text is appended to the generic one.
    """

    def __init__(self, argv, params, ptycho_runner_scan_class):
        super(CDIRunnerID01, self).__init__(argv, params, ptycho_runner_scan_class)
        self.help_text += helptext_beamline

    def parse_arg_beamline(self, k, v):
        """
        Interpret a single id01-specific argument.

        The keywords specfile, imgcounter, imgname and scan are stored unchanged in self.params.

        Returns:
            True if k is one of these keywords (and was stored), False otherwise
        """
        is_beamline_arg = k in ('specfile', 'imgcounter', 'imgname', 'scan')
        if is_beamline_arg:
            self.params[k] = v
        return is_beamline_arg

    def check_params_beamline(self):
        """
        Check that self.params describes where to get the data: either data=, or both specfile= and scan=.

        Returns: Nothing.
        Raises:
            CDIRunnerException: if neither data= nor (specfile= and scan=) are given
        """
        print()
        if self.params['data'] is not None:
            return
        if self.params['specfile'] is None or self.params['scan'] is None:
            raise CDIRunnerException('No data provided. Need at least data=, or specfile= and scan=')
