#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr


import timeit
import argparse

from ...utils import h5py as h5
import numpy as np

from .runner import PtychoRunner, PtychoRunnerScan, PtychoRunnerException, default_params as params0
from .parser import ActionPtychoMotors

# Help epilog (examples) shown by the command-line parser.
# NB: a backslash at the end of a line inside this string joins it with the next
# line, so a space is kept before each backslash to separate command-line arguments.
helptext_epilog = """
Examples:

* ``pynx-ptycho-id16b-nf --data raw_data_master.h5 \
--algorithm analysis,ML**100,DM**200,probe=1``

Using a reference (direct beam) frame allows the use of Paganin (which requires
a delta/beta value, to be adapted to the sample) and, with ``--use_direct_beam``,
provides an absolute scale for the probe.
In this case, only projection 57 (scan #) is reconstructed.
Also, only the object phase is saved, for a smaller output:

* ``pynx-ptycho-id16b-nf --data raw_data_master.h5 --scan 57 \
--algorithm analysis,ML**100,DM**200,Paganin,probe=1 \
--delta_beta 100 --use_direct_beam --cxi_output object_phase``

"""

# Beamline-specific default parameters, which override the generic ptycho runner defaults.
# NB: for id16 we start from an almost flat object (high energy, high transmission)
params_beamline = dict(
    adu_scale=None,
    autocenter=False,
    delta_beta=None,
    detector='pco1',
    instrument='ESRF id16b',
    maxsize=10000,
    near_field=True,
    object='random,0.95,1,0,0.1',
    ptychomotors=['sy', 'sz', 'x', '-y'],
    roi='full',
    saveprefix='{data_prefix}{scan:04d}',
)

# Generic runner defaults, updated with the beamline values above
default_params = params0.copy()
default_params.update(params_beamline)


class PtychoRunnerScanID16bNF(PtychoRunnerScan):
    """
    Scan loader for ID16B near-field ptychography, reading a bliss master (hdf5) file.

    Each bliss 'projection' entry which includes ``measurement/<detector>`` is used
    as one ptychographic position. ``scan`` is the index of the frame read from each
    of these entries (and of the ``srot`` angle), i.e. the projection number.
    Dark/flat entries, pixel size, propagation distance and energy are also looked
    up in the same file, the last three only when not already set in the parameters.
    """

    def __init__(self, params, scan, timings=None):
        super().__init__(params, scan, timings=timings)
        self._refs_entry = None  # hdf5 path to the flat-field (reference) frame(s), if any
        self._scans_entry = []  # names of the 'projection' entries, one per ptycho position
        self._scans_data = []  # hdf5 paths to the detector data of each projection entry
        self._technique_entry = []
        self.tomo_metadata = {}

    @staticmethod
    def _decode(v):
        """Return a scalar hdf5 string value as a str (h5py may return bytes or str)."""
        return v.decode() if isinstance(v, bytes) else str(v)

    @staticmethod
    def _get_data_units(v):
        """
        Read the value of an hdf5 dataset, converted to meters according to its 'units' attribute.

        :param v: the hdf5 dataset
        :return: the dataset value, multiplied by 1e-3, 1e-6 or 1e-9 if the 'units' attribute
            is respectively 'mm', 'um' (or 'µm') or 'nm'. Otherwise the value is returned unchanged.
        """
        scales = {'mm': 1e-3, 'um': 1e-6, 'µm': 1e-6, 'nm': 1e-9}
        if 'units' in v.attrs:
            u = v.attrs['units']
            if isinstance(u, bytes):
                # Depending on how the file was written, string attributes may be read as bytes
                u = u.decode(errors='replace')
            if isinstance(u, str) and u in scales:
                return v[()] * scales[u]
        return v[()]

    def _find_bliss_scan_entries(self):
        """
        Find the bliss scan entries in the hdf5 raw data master file, from their title:

        * 'dark': params['dark'] is set to the detector data of that entry
        * 'flat': its detector data is used as reference (flat-field) frame(s)
        * 'projection': each entry with detector data is one ptychographic position.
          The pixel size is read from it if params['pixelsize'] is None.
        * 'tomo:basic': propagation distance and energy, if not already set in params

        :return: nothing. self.params, self._refs_entry, self._scans_entry and
            self._scans_data are updated.
        """
        # Start from scratch so that calling this again does not duplicate entries
        self._refs_entry = None
        self._scans_entry = []
        self._scans_data = []
        det = self.params['detector']
        with h5.File(self.params['data'], mode='r', enable_file_locking=False) as h:
            for k, e in h.items():
                t = self._decode(e['title'][()])
                if 'dark' in t:
                    # NOTE(review): this overrides any dark given on the command line - confirm intended
                    self.params['dark'] = f"{self.params['data']}:{k}/measurement/{det}"
                    print(f"Found dark entry: {self.params['dark']}")
                elif 'flat' in t:
                    self._refs_entry = f"{k}/measurement/{det}"
                    print(f"Found flat entry : {self._refs_entry}")
                elif 'projection' in t:
                    if f"measurement/{det}" in e:
                        self._scans_entry.append(k)
                        self._scans_data.append(f"{k}/measurement/{det}")
                    if self.params['pixelsize'] is None:
                        self.params['pixelsize'] = \
                            self._get_data_units(e["instrument/tomo_config/sample_pixel_size"])
                elif 'tomo:basic' in t:
                    try:
                        if self.params['detectordistance'] is None:
                            self.params['detectordistance'] = \
                                self._get_data_units(e["technique/scan/effective_propagation_distance"])
                        if self.params['nrj'] is None:
                            self.params['nrj'] = e["technique/scan/energy"][()]
                    except KeyError as ex:
                        # Missing metadata is not fatal here: it may be given on the command line
                        print(ex)
                        print(k, e, t)
        if self.params['dark'] is None:
            print("WARNING-ID16B runner: no dark found in bliss file ! Hint: --dark")

    def load_scan(self):
        """
        Locate the bliss entries, read the tomographic angle, sample name and ptycho motor
        positions, and select the frames to be used (moduloframe, maxframe).

        :raises PtychoRunnerException: if no projection entry with detector data is found,
            or if less than 4 scan positions remain.
        """
        self._find_bliss_scan_entries()
        if not self._scans_entry:
            raise PtychoRunnerException(f"No 'projection' entry with measurement/{self.params['detector']} "
                                        f"found in: {self.params['data']}")
        if self.scan is None:
            self.scan = 0
        with h5.File(self.params['data'], mode='r', enable_file_locking=False) as h:
            e = h[self._scans_entry[0]]

            # Tomographic angle for projection #scan, taken from the first projection entry
            a = e["instrument/positioners/srot"][self.scan]
            self.tomo_metadata['angle'] = np.deg2rad(a)

            # Sample
            self.sample_name = self._decode(e["sample/name"][()])

            # Motor positions: one value per projection entry, converted to meters
            m = self.params['ptychomotors']
            xmot, ymot = m[0:2]
            self.x = np.array([self._get_data_units(h[f"{k}/instrument/positioners/{xmot}"])
                               for k in self._scans_entry], dtype=np.float32)
            self.y = np.array([self._get_data_units(h[f"{k}/instrument/positioners/{ymot}"])
                               for k in self._scans_entry], dtype=np.float32)

            if len(m) == 4:
                # The expressions refer to the raw positions as 'x' and 'y'.
                # NOTE: eval() of command-line input - only use with trusted parameters.
                x, y = self.x, self.y
                self.x, self.y = eval(m[2]), eval(m[3])

        self.tomo_metadata['data'] = self.params['data']

        # imgn[i] is the index (in self._scans_data) of the i-th selected position
        imgn = np.arange(len(self.x), dtype=np.int32)
        if self.params['moduloframe'] is not None:
            n1, n2 = self.params['moduloframe']
            idx = np.where(imgn % n1 == n2)[0]
            imgn = imgn.take(idx)
            self.x = self.x.take(idx)
            self.y = self.y.take(idx)

        if self.params['maxframe'] is not None:
            N = self.params['maxframe']
            if len(imgn) > N:
                print(f"MAXFRAME: only using first {N} frames")
                imgn = imgn[:N]
                self.x = self.x[:N]
                self.y = self.y[:N]

        if len(self.x) < 4:
            raise PtychoRunnerException("Less than 4 scan positions, is this a ptycho scan ?")

        self.imgn = np.array(imgn)

    def load_data(self):
        """
        Read the flat-field reference (if any) and the frames selected in load_scan(),
        apply adu_scale if given, then call load_data_post_process().
        """
        with h5.File(self.params['data'], mode='r', enable_file_locking=False) as h:
            # refs
            if self._refs_entry is not None:
                d = h[self._refs_entry][()].astype(np.float32, copy=False)
                # Several flat frames are averaged
                self.data_ref = d.mean(axis=0) if d.ndim == 3 else d
                if self.params['adu_scale'] is not None:
                    self.data_ref *= self.params['adu_scale']

            # read the selected frames: frame #scan from each selected projection entry
            imgn = self.imgn
            n = len(imgn)
            ny, nx = h[self._scans_data[0]].shape[-2:]
            print("Reading frames:")
            t0 = timeit.default_timer()
            self.raw_data = np.empty((n, ny, nx), dtype=np.float32)
            for i, j in enumerate(imgn):
                # Only read the positions kept by moduloframe/maxframe
                self.raw_data[i] = h[self._scans_data[j]][self.scan]
            dt = timeit.default_timer() - t0
            print('Time to read all frames: %4.1fs [%5.2f Mpixel/s]' %
                  (dt, self.raw_data[0].size * len(self.raw_data) / 1e6 / max(dt, 1e-6)))

            if self.params['adu_scale'] is not None:
                self.raw_data *= self.params['adu_scale']

        self.load_data_post_process()


class PtychoRunnerID16bNF(PtychoRunner):
    """
    Class to process a series of scans with a series of algorithms, given from the command-line,
    for near-field ptychography data recorded on ID16B@ESRF (bliss hdf5 files).
    """

    def __init__(self, argv, params, *args, **kwargs):
        # Extra positional/keyword arguments are accepted for compatibility but not used
        super().__init__(argv, default_params if params is None else params)
        self.PtychoRunnerScan = PtychoRunnerScanID16bNF

    @classmethod
    def make_parser(cls, default_par, description=None, script_name="pynx-ptycho-id16b-nf", epilog=None):
        """
        Create the command-line parser, adding the ID16B near-field arguments to the
        generic ptycho runner ones.

        :param default_par: dictionary of default parameter values
        :param description: the parser description, by default an ID16B near field description
        :param script_name: the name of the script
        :param epilog: the help epilog, by default the ID16B examples
        :return: the parser
        """
        if epilog is None:
            epilog = helptext_epilog
        if description is None:
            description = "Script to perform a ptychography analysis on NEAR FIELD data from ID16B@ESRF"
        p = default_par

        parser = super().make_parser(p, script_name, description, epilog)
        grp = parser.add_argument_group("ID16B (near field) parameters")
        grp.add_argument('--data', type=str, default=p['data'], required=True,
                         help="path to the bliss file, e.g. path/to/data.h5")

        grp.add_argument('--detector', type=str, default=p['detector'], required=False,
                         help="Name of the detector used, which gives the path of the data "
                              "in each bliss entry: <entry>/measurement/<detector>")

        grp.add_argument('--ptycho_motors', '--ptychomotors', type=str, default=p['ptychomotors'],
                         nargs='+', dest='ptychomotors', action=ActionPtychoMotors,
                         help="name of the two motors used for ptychography, optionally followed "
                              "by a mathematical expression to be used to calculate the actual "
                              "motor positions (axis convention, angle..) in meters. Values are read "
                              "from the bliss hdf5 file, and converted to meters if their 'units' "
                              "attribute is mm, um or nm:\n\n"
                              "* ``--ptychomotors=sy,sz,x,-y``  (the default)\n"
                              "* ``--ptychomotors sy sz``\n"
                              "Note that if the ``--xy=-y,x`` command-line argument is used, "
                              "it is applied _after_ this, using ``--ptychomotors=sy,sz,-x,y`` "
                              "is equivalent to ``--ptychomotors sy sz --xy=-x,y``")

        grp.add_argument('--delta_beta', type=float, default=p['delta_beta'],
                         help="delta/beta value, required for Paganin or CTF algorithms. "
                              "Can also be set in the algorithm string")

        grp.add_argument('--adu_scale', type=float, default=p['adu_scale'],
                         help="optional scale factor by which all the data (including data_ref and dark) "
                              "will be multiplied in order to have photon counts.")

        # Change const value
        grp.add_argument('--save_plot', '--saveplot', type=str,
                         default=False, const='object_phase', dest='saveplot',
                         choices=['object_phase', 'object_rgba'], nargs='?',
                         help='will save plot at the end of the optimization (png file).\n'
                              'A string can be given to specify if only the object phase '
                              '(default if only --saveplot is given) or rgba should be plotted.')

        # Hidden option: add a tapering (cosine/Tukey) window so that the intensities
        # reach zero on the boundaries of the observed intensity data
        grp.add_argument('--tapering', type=int, default=0, nargs='?', const=200,
                         help=argparse.SUPPRESS)

        return parser

    def check_params_beamline(self):
        """
        Check if self.params includes a minimal set of valid parameters, specific to a beamline

        :return: Nothing.
        :raises PtychoRunnerException: if Paganin or CTF is requested without any delta/beta value
        """
        algo = self.params['algorithm'].lower()
        if 'paganin' in algo or 'ctf' in algo:
            # delta_beta may also be given inside the algorithm string, e.g. 'delta_beta=...'
            if self.params['delta_beta'] is None and 'delta_beta' not in algo:
                raise PtychoRunnerException('Need delta_beta=... for Paganin and CTF')


def make_parser_sphinx():
    """
    Build the command-line parser used to generate the sphinx documentation.

    :return: the parser of the ID16B near-field runner, built from its default parameters
    """
    runner_cls = PtychoRunnerID16bNF
    parser = runner_cls.make_parser(default_params)
    return parser
