#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

from __future__ import division

import sys
import time
import locale
import timeit

try:
    # Importing hdf5plugin registers the hdf5 compression filters (e.g. used by Eiger files) with h5py
    import hdf5plugin
except ImportError:
    # Optional dependency: only warn when it is actually missing
    print("WARNING: hdf5plugin not found. This may prevent reading Eiger hdf5 files. Hint: 'pip3 install hdf5plugin'")
import fabio
import h5py
import numpy as np

try:
    import silx

    # Spec access requires silx >= 0.3: compare (major, minor) so that silx 1.x, 2.x... are accepted
    if (silx.version_info.major, silx.version_info.minor) < (0, 3):
        silx = None
    else:
        from silx.io.specfile import SpecFile
        print("Using SILX for spec access!")
except ImportError:
    silx = None

if silx is None:
    # Fallback spec reader - optional, as it is a private module which may not be installed
    try:
        from Vincent import spec
    except ImportError:
        spec = None

from .runner import PtychoRunner, PtychoRunnerScan, PtychoRunnerException, params_generic

helptext_beamline = """
Script to perform a ptychography analysis on data from 25A@TPS

Example:
    pynx-tps25apty.py data=pty_41_61_master.h5 scanfile=scan_41.txt scan=41
      ptychomotors=pix,piz,-x,y probe=60e-6x200e-6,0.09
      algorithm=analysis,ML**100,DM**200,nbprobe=3,probe=1 verbose=10 save=all saveplot liveplot

Command-line arguments (beamline-specific):
    scanfile=/some/dir/to/scan.txt: path to text file with motor positions [mandatory]
        (comma-separated, one header line, x and y positions in microns in the 2nd and 3rd columns)

    data=path/to/data.h5: path to hdf5 data, with images, wavelength, detector distance [mandatory]
"""

# Beamline-specific defaults; they take precedence over the generic runner parameters
params_beamline = {'scanfile': None, 'data': None, 'object': 'random,0.9,1,-.2,.2', 'instrument': 'TPS 25A'}

params = params_generic.copy()
params.update(params_beamline)


class PtychoRunnerScanTPS25A(PtychoRunnerScan):
    """
    Scan loader for ptychography data recorded at the TPS 25A beamline.

    Motor positions are read from a comma-separated text file (params['scanfile']), and the
    frames, pixel size, detector distance and wavelength from an hdf5 file (params['data']),
    where frames are split in consecutive packs 'entry/data/data_000001', 'data_000002', ...
    """

    def __init__(self, params, scan):
        super(PtychoRunnerScanTPS25A, self).__init__(params, scan)

    def load_data(self):
        """
        Load motor positions and detector frames, then call load_data_post_process().

        Sets self.x, self.y (positions in meters), self.h5 (the opened hdf5 file),
        self.raw_data (stacked frames) and updates self.raw_mask from detector flags.
        self.params['pixelsize'], ['detectordistance'] and ['nrj'] are read from the hdf5
        file when they are not given.

        Raises:
            PtychoRunnerException: if there are fewer than 4 scan positions, if no frame is
                selected, if the stored wavelength is invalid or if an hdf5 data pack is missing.
        """
        print("reading motor positions from: ", self.params['scanfile'])
        # Columns 1 and 2 hold x and y, after one header line. ndmin=2 keeps x and y as
        # 1D arrays even when the file contains a single position.
        self.x, self.y = np.loadtxt(self.params['scanfile'], delimiter=',', skiprows=1, unpack=True,
                                    usecols=(1, 2), ndmin=2)

        self.h5 = h5py.File(self.params['data'], 'r')
        self._read_h5_parameters()

        if len(self.x) < 4:
            raise PtychoRunnerException("Less than 4 scan positions, is this a ptycho scan ?")
        # Scan file positions are in microns, convert to meters
        self.x *= 1e-6
        self.y *= 1e-6

        imgn, read_all_frames = self._select_frames()
        self.raw_data = self._read_frames(imgn, read_all_frames)
        self.load_data_post_process()

    def _read_h5_parameters(self):
        """Fill pixel size, detector distance and energy from self.h5 when not given by the user."""
        if self.params['pixelsize'] is None:
            self.params['pixelsize'] = np.array(self.h5.get('entry/instrument/detector/x_pixel_size'))
            print("Pixelsize?", self.params['pixelsize'])

        # Same policy as pixelsize and nrj: a value given on the command-line is not overwritten
        if self.params.get('detectordistance') is None:
            self.params['detectordistance'] = np.array(self.h5.get('entry/instrument/detector/detector_distance'))

        if self.params['nrj'] is None:
            wavelength = self.h5.get("entry/instrument/beam/incident_wavelength")[()]
            # NOTE(review): the wavelength is presumably stored in Angstroms (12.3984/lambda -> keV);
            # a value <= 0.1 is rejected as invalid.
            if not wavelength > 0.1:
                raise PtychoRunnerException("Invalid incident wavelength in hdf5 file: %s" % str(wavelength))
            self.params['nrj'] = 12.3984 / wavelength

    def _select_frames(self):
        """
        Apply the 'moduloframe' and 'maxframe' parameters, trimming self.x and self.y accordingly.

        Returns:
            (imgn, read_all_frames): the array of selected frame indices, and True if all
            frames are selected (so that whole hdf5 packs can be read at once).
        """
        imgn = np.arange(len(self.x), dtype=np.int64)
        read_all_frames = True

        if self.params['moduloframe'] is not None:
            n1, n2 = self.params['moduloframe']
            idx = np.where(imgn % n1 == n2)[0]
            if len(idx) < len(imgn):
                read_all_frames = False
            imgn = imgn.take(idx)
            self.x = self.x.take(idx)
            self.y = self.y.take(idx)

        if self.params['maxframe'] is not None:
            nmax = self.params['maxframe']
            if len(imgn) > nmax:
                print("MAXFRAME: only using first %d frames" % (nmax))
                imgn = imgn[:nmax]
                self.x = self.x[:nmax]
                self.y = self.y[:nmax]
                read_all_frames = False

        if len(imgn) == 0:
            raise PtychoRunnerException("No frame selected, check the moduloframe and maxframe parameters")
        return imgn, read_all_frames

    def _open_frame_pack(self, entry, read_all_frames):
        """Return the pack of frames stored in hdf5 dataset number `entry` (numbered from 1)."""
        h5entry = 'entry/data/data_%06d' % entry
        if h5entry not in self.h5:
            raise PtychoRunnerException("Missing h5 data entry: %s (not enough frames in %s ?)"
                                        % (h5entry, self.params['data']))
        print("\nReading h5 data entry: %s" % (h5entry))
        if read_all_frames:
            # Read all the data frames in one go - faster
            return np.array(self.h5[h5entry][()])
        # Read frame-by-frame - more efficient, at least memory-wise
        return self.h5[h5entry]

    def _read_frames(self, imgn, read_all_frames):
        """
        Read the selected frames from the hdf5 packs and update self.raw_mask from detector flags.

        Args:
            imgn: increasing array of global frame indices to read
            read_all_frames: if True, load each hdf5 pack fully in memory, else read frame by frame
        Returns:
            the 3D array of stacked frames, with the dtype of the stored data
        """
        t0 = timeit.default_timer()
        vimg = None
        sys.stdout.write('Reading HDF5 frames: ')
        i0 = 0  # global index of the first frame of the current pack
        entry = 0  # number of the current pack (0: none opened yet)
        h5d = []  # empty, so that the first pack is opened on the first iteration
        for ii, i in enumerate(imgn):
            if (i - imgn[0]) % 20 == 0:
                sys.stdout.write('%d ' % (i - imgn[0]))
                sys.stdout.flush()
            # Advance to the pack containing frame i - 'while' as packs may be skipped entirely
            while i >= i0 + len(h5d):
                i0 += len(h5d)
                entry += 1
                h5d = self._open_frame_pack(entry, read_all_frames)
            frame = h5d[i - i0]
            # Values of 2**32-1 and 2**32-2 are invalid (module gaps or invalidated pixels)
            # Need to test on all frames, as 'invalid' pixels may differ from frame to frame...
            invalid = frame > 2 ** 32 - 3
            if invalid.any():
                if self.raw_mask is None:
                    self.raw_mask = invalid.astype(np.int8)
                else:
                    self.raw_mask[invalid] = 1
            # Masked pixels are not zeroed here, this is done in center_crop_data()
            if vimg is None:
                vimg = np.empty((len(imgn), frame.shape[0], frame.shape[1]), dtype=frame.dtype)
            vimg[ii] = frame
        print("\n")
        dt = timeit.default_timer() - t0
        print('Time to read all frames: %4.1fs [%5.2f Mpixel/s]' % (dt, vimg.size / 1e6 / dt))
        if self.raw_mask is not None:
            if self.raw_mask.sum() > 0:
                print("\nMASKING %d pixels from detector flags" % (self.raw_mask.sum()))
        return vimg


class PtychoRunnerTPS25A(PtychoRunner):
    """
    Runner applying a series of algorithms to a series of TPS 25A scans, as specified on the command-line.
    """

    def __init__(self, argv, params, ptycho_runner_scan_class):
        super(PtychoRunnerTPS25A, self).__init__(argv, params, ptycho_runner_scan_class)
        # Append the beamline-specific usage to the generic help
        self.help_text += helptext_beamline

    def parse_arg_beamline(self, k, v):
        """
        Interpret a single TPS 25A-specific command-line argument.

        Args:
            k: the argument name
            v: the argument value
        Returns:
            True if the argument ('scanfile' or 'data') was recognized and stored in self.params,
            False otherwise.
        """
        recognized = k in ('scanfile', 'data')
        if recognized:
            self.params[k] = v
        return recognized

    def check_params_beamline(self):
        """
        Make sure the mandatory TPS 25A parameters have been supplied.

        Raises:
            PtychoRunnerException: if 'scanfile' or 'data' is missing (checked in that order).
        """
        required = (('scanfile', 'Missing argument: no scanfile given'),
                    ('data', 'Missing argument: no data (hdf5 master) given'))
        for key, message in required:
            if self.params[key] is None:
                raise PtychoRunnerException(message)
