#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

from __future__ import division


import sys
import time
import locale
import timeit
# hdf5plugin is optional and imported only for its side effect: per the hdf5plugin
# documentation, importing it makes its HDF5 compression filters available to h5py.
try:
    import hdf5plugin  # noqa: F401
except ImportError:
    # Keep going without it, but tell the user why some hdf5 files may be unreadable.
    hdf5plugin = None
    print("WARNING: hdf5plugin not found. This may prevent reading Eiger hdf5 files. Hint: 'pip3 install hdf5plugin'")
import fabio
import h5py
import numpy as np
# Select the spec file reader: silx (>= 0.3) when available, otherwise the legacy 'spec' module.
# 'silx' is set to None when silx cannot be used, and the rest of the module tests that.
try:
    import silx
    # Compare (major, minor) rather than the minor number alone, so that later major
    # releases with a small minor number (e.g. 1.0) are not rejected.
    if (silx.version_info.major, silx.version_info.minor) < (0, 3):
        # start using silx with 0.3-dev
        silx = None
    else:
        from silx.io.specfile import SpecFile
        print("Using SILX for spec access!")
except ImportError:
    silx = None

if silx is None:
    from Vincent import spec

from .runner import PtychoRunner, PtychoRunnerScan, PtychoRunnerException, params_generic

# Beamline-specific help text, appended to the generic help by PtychoRunnerID13.
helptext_beamline = """
Script to perform a ptychography analysis on data from id13@ESRF

Example:
    pynx-id13pty.py h5file=pty_41_61_master.h5 specfile=scan.dat scan=41 detectordistance=1.3
      ptychomotors=pix,piz,-x,y probe=60e-6x200e-6,0.09
      algorithm=analysis,ML**100,DM**200,nbprobe=3,probe=1 verbose=10 save=all saveplot liveplot

Command-line arguments (beamline-specific):
    specfile=/some/dir/to/specfile.spec: path to specfile [mandatory]

    scan=56: scan number in specfile [mandatory].
             Alternatively a list or range of scans can be given:
                scan=12,23,45 or scan="range(12,25)" (note the quotes)

    h5file=pty_41_61_master.h5: path to the hdf5 (master) file with the detector data [mandatory]

    h5data=entry/data/data_%06d: generic hdf5 path to stack of 2d images inside hdf5 file
                                 [default=entry/data/data_%06d]

    nrj=8: energy in keV [default: will be extracted from the HDF5 file,
                          entry/instrument/beam/incident_wavelength]

    detectordistance=1.3: detector distance in meters [mandatory]

    pixelsize=75e-6: pixel size on detector in meters
                     [default: will read from hdf5 data: entry/instrument/detector/x_pixel_size]

    ptychomotors=nnp2,nnp3,-x,y: name of the two motors used for ptychography, optionally followed
                                 by a mathematical expression to be used to calculate the actual
                                 motor positions (axis convention, angle..). Values will be
                                 extracted from the spec file, and are assumed to be in microns.
                                 Example 1: ptychomotors=nnp2,nnp3
                                 Example 2: ptychomotors=nnp2,nnp3,-x,y
                                 Note that if the 'xy=-y,x' command-line argument is used, it is
                                 applied after this, using 'ptychomotors=nnp2,nnp3,-x,y' is
                                 equivalent to 'ptychomotors=nnp2,nnp3 xy=-x,y'
                                 [Mandatory]

    monitor=opt2: spec name for the monitor counter. The frames will be normalised by the ratio of
                  the counter value divided by the median value of the counter over the entire scan
                  (so as to remain close to Poisson statistics). A monitor intensity lower than 10%
                  of the median value will be interpreted as an image taken without beam and will be
                  skipped.
                  [default = None]

    kmapfile=detector/kmap/kmap_00000.edf.gz: if images are saved in a multiframe data file.
                                              If given, the frames are read from this file rather
                                              than from the hdf5 data, and nrj= (in keV) must
                                              also be given.

    gpu=Titan: GPU name for OpenCL calculation [default = Titan]
"""

# Beamline-specific parameters and their default values.
# NB: for id13 we start from a flat object
params_beamline = {
    'specfile': None,
    'h5file': None,
    'h5data': "entry/data/data_%06d",
    'nrj': None,
    'detectordistance': None,
    'pixelsize': None,
    'kmapfile': None,
    'monitor': None,
    'ptychomotors': None,
    'object': 'random,1,1,0,0',
    'instrument': 'ESRF id13',
    'gpu': 'Titan',
}

# Full parameter set: the generic defaults, overridden/extended by the beamline-specific ones.
params = params_generic.copy()
params.update(params_beamline)


class PtychoRunnerScanID13(PtychoRunnerScan):
    """
    Load the data of one id13 scan: positions (and optional monitor counter) from a spec file,
    frames either from the HDF5 file or from a multi-frame KMAP file.
    """

    def __init__(self, params, scan):
        super(PtychoRunnerScanID13, self).__init__(params, scan)

    def load_data(self):
        """
        Read the scan positions and the frames for this scan.

        Sets self.h (spec scan header) and self.s or self.d (spec scan data), self.h5 (the opened
        HDF5 file, left open), self.x and self.y (positions, converted from microns to meters),
        self.validframes (when a monitor is used) and self.raw_data (3D stack of frames).
        self.raw_mask is updated with pixels flagged as invalid in the HDF5 frames.
        params['pixelsize'] and params['nrj'] are filled from the HDF5 file when they are None.
        Finally self.load_data_post_process() is called.

        Raises:
            PtychoRunnerException: if the motors are not found in the scan, if there are less than
                4 positions, if no frame remains after filtering, if the pixel size or wavelength
                cannot be read from the HDF5 file, or if the KMAP file has too few frames.
        """
        if silx is not None:
            self.s = SpecFile(self.params['specfile'])['%d.1' % (self.scan)]
            self.h = self.s.scan_header_dict
        else:
            self.h, self.d = spec.ReadSpec(self.params['specfile'], self.scan)
        print('Read scan:', self.h['S'])

        self._check_spec_date(self.h['D'])

        self.h5 = h5py.File(self.params['h5file'], 'r')

        if self.params['pixelsize'] is None:
            pixel_size = self.h5.get('entry/instrument/detector/x_pixel_size')
            if pixel_size is None:
                raise PtychoRunnerException(
                    'Could not read the pixel size from the hdf5 file (entry/instrument/detector/x_pixel_size):'
                    ' please give pixelsize= (in meters)')
            self.params['pixelsize'] = np.array(pixel_size)
            print("Pixelsize?", self.params['pixelsize'])

        if self.params['nrj'] is None:
            wavelength = self.h5.get("entry/instrument/beam/incident_wavelength")
            # Dataset.value was removed in h5py 3: read the content with [()]
            if wavelength is not None:
                wavelength = wavelength[()]
            # Explicit check rather than assert, which is stripped when running with -O
            if wavelength is None or not wavelength > 0.1:
                raise PtychoRunnerException(
                    'Invalid or missing wavelength in the hdf5 file (entry/instrument/beam/incident_wavelength):'
                    ' please give nrj= (in keV)')
            # Energy in keV; NOTE(review): this assumes the wavelength is stored in Angstroems - confirm
            self.params['nrj'] = 12.3984 / wavelength

        m = self.params['ptychomotors'].split(',')
        if len(m) < 2:
            raise PtychoRunnerException(
                "ptychomotors=%s: two motor names are required" % (self.params['ptychomotors']))
        xmot, ymot = m[0:2]

        if silx is not None:
            if xmot not in self.s.labels or ymot not in self.s.labels:
                raise PtychoRunnerException(
                    'Ptycho motors (%s, %s) not found in scan #%d of specfile:%s' % (xmot, ymot, self.scan, self.params['specfile']))

            self.x, self.y = self.s.data_column_by_name(xmot), self.s.data_column_by_name(ymot)
        else:
            if self.d.get(xmot) is None or self.d.get(ymot) is None:
                raise PtychoRunnerException(
                    'Ptycho motors (%s, %s) not found in scan #%d of specfile:%s' % (xmot, ymot, self.scan, self.params['specfile']))
            self.x, self.y = self.d[xmot], self.d[ymot]
        if len(m) == 4:
            # NOTE(security): the two expressions come straight from the command line and are
            # evaluated with eval(). This is only acceptable because the user runs the script
            # with his own arguments: never feed untrusted text to 'ptychomotors'.
            # The expressions see the module globals (e.g. np) and the raw positions as x and y.
            namespace = {'x': self.x, 'y': self.y}
            self.x, self.y = eval(m[2], globals(), namespace), eval(m[3], globals(), namespace)
        if len(self.x) < 4:
            raise PtychoRunnerException("Less than 4 scan positions, is this a ptycho scan ?")
        # Spec values are in microns, convert to meters. Not done in-place, so that integer
        # columns are accepted and the arrays read from the spec file are left untouched.
        self.x = self.x * 1e-6
        self.y = self.y * 1e-6

        # Index of the frame corresponding to each position
        imgn = np.arange(len(self.x), dtype=int)

        if self.params['monitor'] is not None:
            if silx is not None:
                mon = self.s.data_column_by_name(self.params['monitor'])
            else:
                mon = self.d[self.params['monitor']]
            # Monitor relative to its median value (not in-place: the column may hold integers)
            mon = mon / np.median(mon)
            # self.validframes keeps the tuple returned by np.where; the index array itself
            # is its first element, and is what must be used to select positions and frames.
            self.validframes = np.where(mon > 0.1)
            valid = self.validframes[0]
            if len(valid) != len(mon):
                print('WARNING: The following frames have a monitor value < 0.1 the median value and will be ignored (no beam ?)')
                print(np.where(np.logical_not(mon > 0.1)))
            self.x = np.take(self.x, valid)
            self.y = np.take(self.y, valid)
            imgn = np.take(imgn, valid)

        if self.params['moduloframe'] is not None:
            # Only keep frames for which (frame index % n1) == n2
            n1, n2 = self.params['moduloframe']
            idx = np.where(imgn % n1 == n2)[0]
            imgn = imgn.take(idx)
            self.x = self.x.take(idx)
            self.y = self.y.take(idx)

        if self.params['maxframe'] is not None:
            N = self.params['maxframe']
            if len(imgn) > N:
                print("MAXFRAME: only using first %d frames" % (N))
                imgn = imgn[:N]
                self.x = self.x[:N]
                self.y = self.y[:N]

        if len(imgn) == 0:
            raise PtychoRunnerException("No frame left to load after applying monitor/moduloframe/maxframe")

        # Load all frames
        t0 = timeit.default_timer()
        if self.params['kmapfile'] is not None:
            vimg = self._load_frames_kmap(imgn)
        else:
            vimg = self._load_frames_hdf5(imgn)
        print("\n")
        dt = timeit.default_timer() - t0
        # vimg.size is the total number of pixels read (number of frames * pixels per frame)
        print('Time to read all frames: %4.1fs [%5.2f Mpixel/s]' % (dt, vimg.size / 1e6 / dt))
        if self.raw_mask is not None and self.raw_mask.sum() > 0:
            print("\nMASKING %d pixels from detector flags" % (self.raw_mask.sum()))

        self.raw_data = vimg
        self.load_data_post_process()

    @staticmethod
    def _check_spec_date(date_string):
        """
        Parse the date of the spec scan header, e.g. 'Wed Mar 23 14:41:56 2016', and print
        a message if the format is not recognized.

        Returns:
            the date as an integer number of seconds (time.mktime), or None if parsing failed.
        """
        # Skip the first word (week day)
        date_string = date_string[date_string.find(' ') + 1:]
        pattern = '%b %d %H:%M:%S %Y'
        # Parse with the 'C' locale so that the month abbreviation is matched independently of
        # the user's locale, and always restore the previous locale, even when parsing fails.
        saved_locale = locale.setlocale(locale.LC_ALL)
        try:
            locale.setlocale(locale.LC_ALL, 'C')
            return int(time.mktime(time.strptime(date_string, pattern)))
        except ValueError:
            print("Could not extract time from spec header, unrecognized format: %s, expected:" % (date_string) + pattern)
            return None
        finally:
            locale.setlocale(locale.LC_ALL, saved_locale)

    @staticmethod
    def _print_progress(n):
        """Write the frame counter n to stdout, once every 20 frames."""
        if n % 20 == 0:
            sys.stdout.write('%d ' % (n))
            sys.stdout.flush()

    def _load_frames_kmap(self, imgn):
        """
        Read the frames with indices imgn from the multi-frame file params['kmapfile'].

        Returns:
            a 3D array with one frame per index in imgn.
        Raises:
            PtychoRunnerException: if the file holds less frames than requested.
        """
        sys.stdout.write("Reading frames from KMAP file (this WILL take a while)...")
        sys.stdout.flush()
        kfile = fabio.open(self.params['kmapfile'])
        if kfile.getNbFrames() < len(imgn):
            raise PtychoRunnerException("KMAP: only %d frames instead of %d in data file (%s) ! Did you save all frames in a single file ?" %
                                        (kfile.getNbFrames(), len(imgn), self.params['kmapfile']))
        vimg = None
        for ii, i in enumerate(imgn):
            self._print_progress(i - imgn[0])
            frame = kfile.getframe(i).data
            if vimg is None:
                # Allocate the stack once the shape and type of a frame are known
                vimg = np.empty((len(imgn), frame.shape[0], frame.shape[1]), dtype=frame.dtype)
            vimg[ii] = frame
        return vimg

    def _load_frames_hdf5(self, imgn):
        """
        Read the frames with indices imgn from self.h5. Frames are grouped in successive data
        entries, params['h5data'] % 1, params['h5data'] % 2, ..., which are read one at a time.
        Pixels flagged as invalid in any frame are added to self.raw_mask.

        Returns:
            a 3D array with one frame per index in imgn.
        """
        # TODO: use parallel process to load files and import frames ? (not possible using hdf5 ?)
        sys.stdout.write('Reading HDF5 frames: ')
        # frames are grouped in different subentries
        i0 = 0  # index of the first frame of the current data pack
        entry0 = 1
        h5entry = self.params["h5data"] % entry0
        print("\nReading h5 data entry: %s" % (h5entry))
        # Dataset.value was removed in h5py 3: read the whole pack with [()]
        h5d = self.h5[h5entry][()]
        vimg = None
        for ii, i in enumerate(imgn):
            self._print_progress(i - imgn[0])
            # 'while' rather than 'if': filtered frames may skip more than one data pack
            while i >= (i0 + len(h5d)):
                # Read next data pack
                i0 += len(h5d)
                entry0 += 1
                h5entry = self.params["h5data"] % entry0
                print("\nReading h5 data entry: %s" % (h5entry))
                h5d = self.h5[h5entry][()]
            frame = h5d[i - i0]
            # Values of 2**32-1 and -2 are invalid (module gaps or invalidated pixels)
            # Need to test on all frames, as 'invalid' pixels may differ from frame to frame...
            tmp = (frame > 2 ** 32 - 3)
            if tmp.any():
                if self.raw_mask is None:
                    self.raw_mask = tmp.astype(np.int8)
                else:
                    self.raw_mask[tmp] = 1
            # Masked pixels are not set to zero here, this will be done in center_crop_data()
            if vimg is None:
                # Allocate the stack once the shape and type of a frame are known
                vimg = np.empty((len(imgn), frame.shape[0], frame.shape[1]), dtype=frame.dtype)
            vimg[ii] = frame
        return vimg


class PtychoRunnerID13(PtychoRunner):
    """
    Class to process a series of scans with a series of algorithms, given from the command-line
    """
    def __init__(self, argv, params, ptycho_runner_scan_class):
        super(PtychoRunnerID13, self).__init__(argv, params, ptycho_runner_scan_class)
        # Append the id13-specific arguments to the help text set up by the parent class
        self.help_text += helptext_beamline

    def parse_arg_beamline(self, k, v):
        """
        Parse argument in a beamline-specific way. This function only parses single arguments.
        If an argument is recognized and interpreted, the corresponding value is added to self.params

        Args:
            k: name of the argument (the 'key' in key=value)
            v: value of the argument, as a string
        Returns:
            True if the argument is interpreted, false otherwise
        """
        # Arguments stored as strings. 'h5data' and 'monitor' are included as they are
        # documented beamline-specific arguments, which are read when loading the data.
        if k in ('specfile', 'h5file', 'h5data', 'kmapfile', 'monitor', 'ptychomotors'):
            self.params[k] = v
            return True
        # Arguments converted to floating-point values
        if k in ('nrj', 'detectordistance', 'pixelsize'):
            self.params[k] = float(v)
            return True
        return False

    def check_params_beamline(self):
        """
        Check if self.params includes a minimal set of valid parameters, specific to a beamline
        Returns: Nothing. Will raise an exception if necessary
        Raises:
            PtychoRunnerException: if a mandatory parameter is missing
        """
        if self.params['specfile'] is None:
            raise PtychoRunnerException('Missing argument: no specfile given')
        if self.params['h5file'] is None:
            raise PtychoRunnerException('Missing argument: no h5file given')
        if self.params['detectordistance'] is None:
            raise PtychoRunnerException('Missing argument: detector distance')
        if self.params['ptychomotors'] is None:
            raise PtychoRunnerException('Missing argument: ptychomotors')
        if self.params['kmapfile'] is not None and self.params['nrj'] is None:
            raise PtychoRunnerException('Missing argument: for KMAP data, nrj= (in keV) is required')

