# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr


import time
import numpy as np
from scipy.spatial import ConvexHull

try:
    # We must import hdf5plugin before h5py, even if it is not used here.
    import hdf5plugin
except:
    pass
import h5py
from ..operator import Operator, OperatorException
from ..version import __version__
from ..utils import phase


class PtychoData:
    """Class for two-dimensional ptychographic data: observed diffraction and probe positions. 
    This may include only part of the data from a larger dataset.
    """

    def __init__(self, iobs=None, positions=None, detector_distance=None, mask=None,
                 pixel_size_detector=None, wavelength=None, near_field=False):
        """
        
        :param iobs: 3d array with (nb_frame, ny,nx) shape observed intensity (assumed to follow Poisson statistics).
                     The frames will be stored fft-shifted so that the center of diffraction lies in the (0,0) corner
                     of each image. The supplied frames should have the diffraction center in the middle of the frames.
                     Intensities must be >=0. Negative values will be used to mark masked pixels.
        :param positions: (z, y, x) tuple or 2-column array with ptycho probe positions in meters. 
                          For 2D data, z is ignored and can be None or missing, e.g. with (y,x)
                          The orthonormal coordinate system follows the CXI/NeXus/McStas convention, 
                          with z along the incident beam propagation direction and y vertical towards ceiling.
        :param detector_distance: detector distance in meters
        :param mask: 2D mask (>0 means masked pixel) for the observed data. Can be None. Will be fft-shifted like iobs
                     (over the last two axes) and stored as self.mask (int8).
                     Masked pixels are stored as negative values (-100) in the iobs array, in every frame.
        :param pixel_size_detector: in meters, assuming square pixels
        :param wavelength: wavelength of the experiment, in meters.
        :param near_field: True if using near field ptycho
        """
        if iobs is not None:
            self.iobs = np.fft.fftshift(iobs, axes=(-2, -1)).astype(np.float32)
            # This should not be necessary
            self.iobs[self.iobs < 0] = 0
            if mask is not None:
                # The mask must be shifted exactly like the frames (last two axes only)...
                self.mask = np.fft.fftshift(np.asarray(mask).astype(np.int8), axes=(-2, -1))
                # ...and the *shifted* mask must be used to flag pixels. The Ellipsis broadcasts a 2D mask
                # over all frames of a 3D iobs stack (a plain iobs[mask] would raise IndexError).
                self.iobs[..., self.mask > 0] = -100
        else:
            self.iobs = None
        self.detector_distance = detector_distance
        self.pixel_size_detector = pixel_size_detector
        self.wavelength = wavelength
        if positions is not None:
            if len(positions) == 2:
                self.posy, self.posx = positions
            else:
                self.posz, self.posy, self.posx = positions
        self.near_field = near_field

    def pixel_size_object(self):
        """
        Get the y and x pixel size in object space after a FFT.

        In near field, this is simply the detector pixel size. In far field, it is
        wavelength * detector_distance / (n * pixel_size_detector) along each axis, with n the frame size.

        :return: a tuple (pixel_size_y, pixel_size_x) in meters
        """
        if self.near_field:
            return self.pixel_size_detector, self.pixel_size_detector
        else:
            ny, nx = self.iobs.shape[-2:]
            pixel_size_x = self.wavelength * self.detector_distance / (nx * self.pixel_size_detector)
            pixel_size_y = self.wavelength * self.detector_distance / (ny * self.pixel_size_detector)
            return pixel_size_y, pixel_size_x


class Ptycho:
    """
    Class for two-dimensional ptychographic data: object, probe, and observed diffraction.
    This may include only part of the data from a larger dataset.

    The object, probe and background may have a more recent copy on a GPU. Timestamp counters record where the
    latest version lives. The get_*() accessors copy it back from the _cl_* (OpenCL) or _cu_* (CUDA) arrays when
    needed. Those arrays are not created in this class; presumably the GPU operators attach them.
    """

    def __init__(self, probe=None, obj=None, background=None, data=None, nb_frame_total=None):
        """

        :param probe: the starting estimate of the probe, as a complex 2D numpy array - can be 3D if modes are used.
                      the probe should be centered in the center of the array.
        :param obj: the starting estimate of the object, as a complex 2D numpy array - can be 3D if modes are used. 
        :param background: 2D array with the incoherent background.
        :param data: the PtychoData object with all observed frames, ptycho positions. If None, the data-derived
                     attributes (pixel size, scan areas, number of observed pixels...) are left empty.
        :param nb_frame_total: total number of frames (used for normalization). If None, the number of frames
                               in data.iobs is used.
        """
        self.data = data
        self._probe = probe
        self._obj = obj
        self._background = background

        if self._probe is not None:
            self._probe = self._probe.astype(np.complex64)
            if self._probe.ndim == 2:
                # Modes are always stored along the first axis: (nb_mode, ny, nx)
                ny, nx = self._probe.shape
                self._probe = self._probe.reshape((1, ny, nx))

        if self._obj is not None:
            self._obj = self._obj.astype(np.complex64)
            if self._obj.ndim == 2:
                ny, nx = self._obj.shape
                self._obj = self._obj.reshape((1, ny, nx))

        has_iobs = data is not None and data.iobs is not None

        if self._background is not None:
            self._background = self._background.astype(np.float32)
        elif has_iobs:
            self._background = np.zeros(data.iobs.shape[-2:], dtype=np.float32)

        self.nb_frame_total = nb_frame_total
        if self.nb_frame_total is None and has_iobs:
            self.nb_frame_total = len(data.iobs)

        # Placeholder for storage of propagated wavefronts. Only used with CPU
        self._psi = None

        # Stored variables
        self.pixel_size_object = None if data is None else np.float32(data.pixel_size_object()[0])
        self.scan_area_obj = None
        self.scan_area_probe = None
        self.scan_area_points = None
        self.llk_poisson = 0
        self.llk_gaussian = 0
        self.llk_euclidian = 0
        self.nb_photons_calc = 0
        # Number of observed (non-masked) pixels: masked pixels are stored as negative values in iobs
        self.nb_obs = (data.iobs >= 0).sum() if has_iobs else 0

        # The timestamp counter record when the data was last altered, either in the host or the GPU memory.
        self._timestamp_counter = 1
        self._cpu_timestamp_counter = 1
        self._cl_timestamp_counter = 0
        self._cu_timestamp_counter = 0

        if self.data is not None:
            self.calc_scan_area()

    @staticmethod
    def _scan_area_mask(shape, vx, vy, vertices, xc, yc, expand=0):
        """
        Compute a boolean mask of the convex polygon given by the hull vertices.

        The grid has the given (ny, nx) shape, with pixel coordinates centred on the array centre
        (fft-shifted integer frequencies).

        :param shape: (ny, nx) shape of the mask
        :param vx, vy: x and y coordinates (in pixels) of all scan points
        :param vertices: indices of the hull vertices, in counterclockwise order
        :param xc, yc: centre of the scan, used for the radial expansion
        :param expand: if > 0, every vertex is moved radially away from (xc, yc) by this distance (in pixels)
        :return: a (ny, nx) boolean array, True inside the polygon
        """
        ny, nx = shape
        xx = np.fft.fftshift(np.fft.fftfreq(nx, d=1. / nx))
        yy = np.fft.fftshift(np.fft.fftfreq(ny, d=1. / ny))[:, np.newaxis]
        mask = np.ones((ny, nx), dtype=bool)
        # Each edge goes from the previous vertex to the current one (the first edge wraps around)
        for i0, i1 in zip(np.roll(vertices, 1), vertices):
            x0, y0 = vx[i0], vy[i0]
            x1, y1 = vx[i1], vy[i1]
            if expand > 0:
                # Increase distance-to-center for a more realistic scan area
                d0 = np.sqrt((x0 - xc) ** 2 + (y0 - yc) ** 2)
                x0 = xc + (x0 - xc) * (d0 + expand) / d0
                y0 = yc + (y0 - yc) * (d0 + expand) / d0
                d1 = np.sqrt((x1 - xc) ** 2 + (y1 - yc) ** 2)
                x1 = xc + (x1 - xc) * (d1 + expand) / d1
                y1 = yc + (y1 - yc) * (d1 + expand) / d1
            dx, dy = x1 - x0, y1 - y0
            # With counterclockwise vertices, inside points lie on the left of every edge
            mask &= ((xx - x0) * dy - (yy - y0) * dx) <= 0
        return mask

    def calc_scan_area(self):
        """
        Compute the scan area for the object and probe, using scipy ConvexHull. The scan area for the object is
        augmented by twice the average distance between scan positions for a more realistic estimation.
        scan_area_points is also computed, corresponding to the outline of the scanned area.

        :return: Nothing. self.scan_area_probe and self.scan_area_obj are updated, as 2D arrays with the same shape 
                 as the object and probe, with False outside the scan area and True inside. A mask is only
                 computed if the corresponding object/probe array exists.
        """
        py, px = self.data.pixel_size_object()
        x = np.asarray(self.data.posx, dtype=np.float64)
        y = np.asarray(self.data.posy, dtype=np.float64)
        # Centred scan positions in object pixel units, as (x, y) columns
        points = np.column_stack(((x - x.mean()) / px, (y - y.mean()) / py))
        hull = ConvexHull(points)
        vx = points[:, 0]
        vy = points[:, 1]
        # Scan center
        xc = vx.mean()
        yc = vy.mean()
        # Estimated average distance between points with an hexagonal model
        try:
            w = 4 / 3 / np.sqrt(3) * np.sqrt(hull.volume / len(points))
        except AttributeError:
            # ConvexHull.volume only supported in scipy >=0.17 (2016/02)
            w = 0

        if self._obj is not None:
            self.scan_area_obj = self._scan_area_mask(self._obj.shape[-2:], vx, vy, hull.vertices, xc, yc,
                                                      expand=2 * w)
        if self._probe is not None:
            self.scan_area_probe = self._scan_area_mask(self._probe.shape[-2:], vx, vy, hull.vertices, xc, yc)

        # Closed outline of the hull (first vertex repeated at the end).
        # NOTE(review): the first array holds column 1 (y) and the second column 0 (x) of points, i.e. the
        # (y, x) order - confirm this is what the plotting callers expect.
        outline = np.append(hull.vertices, hull.vertices[0])
        self.scan_area_points = points[outline, 1], points[outline, 0]

    def _sync_from_gpu(self):
        """
        Copy the object, probe and (if present) background back to host memory when the latest changes were made
        on a GPU (OpenCL or CUDA). Then mark the host copy as up-to-date.

        :return: nothing
        """
        if self._cpu_timestamp_counter >= self._timestamp_counter:
            return
        if self._timestamp_counter == self._cl_timestamp_counter:
            self._obj = self._cl_obj.get()
            self._probe = self._cl_probe.get()
            if self._background is not None:
                self._background = self._cl_background.get()
        if self._timestamp_counter == self._cu_timestamp_counter:
            self._obj = self._cu_obj.get()
            self._probe = self._cu_probe.get()
            if self._background is not None:
                self._background = self._cu_background.get()
        self._cpu_timestamp_counter = self._timestamp_counter

    def get_obj(self):
        """
        Get the object data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :return: the 3D numpy data array (nb object modes, nyo, nxo)
        """
        self._sync_from_gpu()
        return self._obj

    def set_obj(self, obj):
        """
        Set the object data array.

        :param obj: the object (complex64 numpy array)
        :return: nothing
        """
        self._obj = obj.astype(np.complex64)
        self._timestamp_counter += 1
        self._cpu_timestamp_counter = self._timestamp_counter

    def get_probe(self):
        """
        Get the probe data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :return: the 3D probe numpy data array (nb probe modes, ny, nx)
        """
        self._sync_from_gpu()
        return self._probe

    def set_background(self, background):
        """
        Set the incoherent background data array.

        :param background: the background (float32 numpy array)
        :return: nothing
        """
        self._background = background.astype(np.float32)
        self._timestamp_counter += 1
        self._cpu_timestamp_counter = self._timestamp_counter

    def get_background(self, shift=False):
        """
        Get the background data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :param shift: if True, the data array will be fft-shifted so that the center of the data is in the center
                      of the array, rather than in the corner (the default).
        :return: the 2D numpy data array, or None if there is no background
        """
        if self._background is None:
            return None
        self._sync_from_gpu()
        if shift:
            return np.fft.fftshift(self._background)
        return self._background

    def set_probe(self, probe):
        """
        Set the probe data array.

        :param probe: the probe (complex64 numpy array)
        :return: nothing
        """
        self._probe = probe.astype(np.complex64)
        self._timestamp_counter += 1
        self._cpu_timestamp_counter = self._timestamp_counter

    def __rmul__(self, x):
        """
        Multiply object (by a scalar).

        This is a placeholder for a function which will be replaced when importing either CUDA or OpenCL operators.
        If called before being replaced, will raise an error

        :param x: the scalar by which the wavefront will be multiplied
        :return:
        """
        if np.isscalar(x):
            raise OperatorException(
                "ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s. Did you import operators ?" % (str(x), str(self)))
        else:
            raise OperatorException("ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s." % (str(x), str(self)))

    def __mul__(self, x):
        """
        Multiply object (by a scalar).

        This is a placeholder for a function which will be replaced when importing either CUDA or OpenCL operators.
        If called before being replaced, will raise an error

        :param x: the scalar by which the wavefront will be multiplied
        :return:
        """
        if np.isscalar(x):
            raise OperatorException(
                "ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s. Did you import operators ?" % (str(self), str(x)))
        else:
            raise OperatorException("ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s." % (str(self), str(x)))

    def save_obj_probe_cxi(self, filename, sample_name=None, experiment_id=None, instrument=None, note=None,
                           process=None, append=False, shift_phase_zero=False):
        """
        Save the result of the optimisation (object, probe, scan areas) to an HDF5 CXI-like file.
        :param filename: the file name to save the data to
        :param sample_name: optional, the sample name
        :param experiment_id: the string identifying the experiment, e.g.: 'HC1234: Siemens star calibration tests'
        :param instrument: the string identifying the instrument, e.g.: 'ESRF id10'
        :param note: a string with a text note giving some additional information about the data, a publication...
        :param process: a dictionary of strings which will be saved in '/entry_N/process_1'. A dictionary entry
                        can also be a 'note' as keyword and a dictionary as value - all key/values will then be saved
                        as separate notes. Example: process={'program': 'PyNX', 'note':{'llk':1.056, 'nb_photons': 1e8}}
                        Other (non-string) values are ignored.
        :param append: by default (append=False), any existing file will be overwritten, and the result will be saved
                       as 'entry_1'. If append==True and the file exists, a new entry_N will be saved instead.
                       This can be used to export the different steps of an optimisation.
        :param shift_phase_zero: currently ignored - the object phase is always centred around 0 before saving.
        :return: Nothing. a CXI file is created
        """
        # The context manager guarantees the file is closed (and flushed) even if writing fails
        with h5py.File(filename, "a" if append else "w") as f:
            if append:
                if "cxi_version" not in f:
                    f.create_dataset("cxi_version", data=150)
                # Use the first free entry_N
                i = 1
                while 'entry_%d' % i in f:
                    i += 1
            else:
                f.create_dataset("cxi_version", data=140)
                i = 1
            entry_path = "/entry_%d" % i
            entry = f.create_group(entry_path)
            self._write_cxi_entry(entry, entry_path, sample_name=sample_name, experiment_id=experiment_id,
                                  instrument=instrument, note=note, process=process)

    def _write_cxi_entry(self, entry, entry_path, sample_name=None, experiment_id=None, instrument=None, note=None,
                         process=None):
        """
        Fill an already-created CXI entry group with metadata, object, probe, background, instrument and process.

        :param entry: the h5py group of the entry
        :param entry_path: absolute path of the entry in the file (e.g. '/entry_1'), used to build soft links
        :param sample_name, experiment_id, instrument, note, process: see save_obj_probe_cxi
        :return: nothing
        """
        entry.create_dataset("program_name", data="PyNX %s" % (__version__))
        entry.create_dataset("start_time", data=time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime(time.time())))
        if experiment_id is not None:
            entry.create_dataset("experiment_id", data=experiment_id)

        if note is not None:
            note_1 = entry.create_group("note_1")
            note_1.create_dataset("data", data=note)
            note_1.create_dataset("type", data="text/plain")

        if sample_name is not None:
            sample_1 = entry.create_group("sample_1")
            sample_1.create_dataset("name", data=sample_name)

        # Get the object and center the phase around 0
        obj = phase.shift_phase_zero(self.get_obj(), percent=2, origin=0, mask=self.scan_area_obj)
        probe = self.get_probe()
        py, px = self.data.pixel_size_object()

        # Store object in result_1
        result_1 = entry.create_group("result_1")
        entry["object"] = h5py.SoftLink(entry_path + '/result_1')  # Unorthodox, departs from specification ?
        result_1.create_dataset("data", data=obj, chunks=True, shuffle=True, compression="gzip")
        result_1.create_dataset("data_type", data="electron density")
        result_1.create_dataset("data_space", data="real")
        ny, nx = obj.shape[-2:]
        result_1.create_dataset("image_size", data=[px * nx, py * ny])
        # Store object pixel size (not in CXI specification)
        result_1.create_dataset("x_pixel_size", data=px)
        result_1.create_dataset("y_pixel_size", data=py)

        if self.scan_area_obj is not None:
            # Using 'mask' from CXI specification: 0x00001000 == CXI_PIXEL_HAS_SIGNAL 'pixel signal above background'
            # (builtin int: np.int was only an alias for it and was removed in numpy 1.24)
            s = (self.scan_area_obj > 0).astype(int) * 0x00001000
            result_1.create_dataset("mask", data=s, chunks=True, shuffle=True, compression="gzip")

        # Store probe in result_2
        result_2 = entry.create_group("result_2")
        entry["probe"] = h5py.SoftLink(entry_path + '/result_2')  # Unorthodox, departs from specification ?
        result_2.create_dataset("data", data=probe, chunks=True, shuffle=True, compression="gzip")
        result_2.create_dataset("data_space", data="real")
        ny, nx = probe.shape[-2:]
        result_2.create_dataset("image_size", data=[px * nx, py * ny])
        # Store probe pixel size (not in CXI specification)
        result_2.create_dataset("x_pixel_size", data=px)
        result_2.create_dataset("y_pixel_size", data=py)

        if self.scan_area_probe is not None:
            # Using 'mask' from CXI specification: 0x00001000 == CXI_PIXEL_HAS_SIGNAL 'pixel signal above background'
            s = (self.scan_area_probe > 0).astype(int) * 0x00001000
            result_2.create_dataset("mask", data=s, chunks=True, shuffle=True, compression="gzip")

        background = self.get_background(shift=True)
        if background is not None:
            # Store background in result_3 - the link must point to result_3, not to the probe (result_2)
            result_3 = entry.create_group("result_3")
            entry["background"] = h5py.SoftLink(entry_path + '/result_3')  # Unorthodox, departs from specification ?
            result_3.create_dataset("data", data=background, chunks=True, shuffle=True, compression="gzip")
            result_3.create_dataset("data_space", data="diffraction")
            result_3.create_dataset("data_type", data="intensity")

        instrument_1 = entry.create_group("instrument_1")
        if instrument is not None:
            instrument_1.create_dataset("name", data=instrument)

        # E[keV] = 12.3984 / lambda[Angstrom]. The wavelength is stored in meters, hence the 1e-10 factor.
        nrj = 12.3984e-10 / self.data.wavelength
        source_1 = instrument_1.create_group("source_1")
        source_1.create_dataset("energy", data=nrj * 1.60218e-16)  # keV -> J

        detector_1 = instrument_1.create_group("detector_1")
        detector_1.create_dataset("distance", data=self.data.detector_distance)

        detector_1.create_dataset("x_pixel_size", data=self.data.pixel_size_detector)
        detector_1.create_dataset("y_pixel_size", data=self.data.pixel_size_detector)

        # Add shortcut to the main data
        data_1 = entry.create_group("data_1")
        data_1["data"] = h5py.SoftLink(entry_path + '/result_1/data')

        if process is not None:
            process_1 = entry.create_group("process_1")
            for k, v in process.items():
                if isinstance(v, str):
                    process_1.create_dataset(k, data=v)
                elif isinstance(v, dict) and k == 'note':
                    # Save each key/value pair as a separate note_N group, using the first free index
                    i = 1
                    for kk, vv in v.items():
                        while 'note_%d' % i in process_1:
                            i += 1
                        note_group = process_1.create_group('note_%d' % i)
                        note_group.create_dataset("data", data=str(vv))
                        note_group.create_dataset("description", data=kk)
                        note_group.create_dataset("type", data="text/plain")


class OperatorPtycho(Operator):
    """
    Base class for an operator on Ptycho2D objects.
    """

    def timestamp_increment(self, p):
        """
        Record that the Ptycho object p has just been modified in host (CPU) memory.

        This bumps the global timestamp and marks the CPU copy as the latest one. Operators which do not modify the
        object (e.g. display operators) are expected to override this.

        :param p: the Ptycho object the operator was applied to
        :return: nothing
        """
        new_stamp = p._timestamp_counter + 1
        p._timestamp_counter = new_stamp
        p._cpu_timestamp_counter = new_stamp


def algo_string(algo_base, p, update_object, update_probe, update_background=False):
    """
    Build a short label for the algorithm being run, e.g. 'DM/o/3p' for difference map with 1 object and 3 probe
    modes.

    The number of modes is only written when there is more than one mode.

    :param algo_base: 'AP' or 'ML' or 'DM'
    :param p: the ptycho object (its _obj and _probe arrays give the number of modes)
    :param update_object: True if updating the object (adds '/o' or '/No')
    :param update_probe: True if updating the probe (adds '/p' or '/Np')
    :param update_background: True if updating the background (adds '/b')
    :return: a short string for the algorithm
    """

    def mode_tag(nb_mode, letter):
        # e.g. '/o' for a single mode, '/3p' for three probe modes
        count = "%d" % nb_mode if nb_mode > 1 else ""
        return "/" + count + letter

    pieces = [algo_base]
    if update_object:
        pieces.append(mode_tag(len(p._obj), "o"))
    if update_probe:
        pieces.append(mode_tag(len(p._probe), "p"))
    if update_background:
        pieces.append("/b")
    return "".join(pieces)
