# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

__all__ = ['CDI', 'OperatorCDI', 'save_cdi_data_cxi']

import time
import sys
import warnings
import numpy as np
from ..operator import Operator, OperatorException, has_attr_not_none
from ..version import __version__

try:
    # We must import hdf5plugin before h5py, even if it is not used here.
    import hdf5plugin
except:
    pass
import h5py


class CDI:
    """ Reconstruction class for two or three-dimensional coherent diffraction imaging data
    """

    def __init__(self, iobs, support=None, obj=None, mask=None,
                 pixel_size_detector=None, wavelength=None, detector_distance=None):
        """
        Constructor. All arrays are assumed to be centered in (0,0) to avoid fftshift

        Args:
            iobs: 2D/3D observed diffraction data (intensity).
                  Assumed to be corrected and following Poisson statistics, will be converted to float32.
                  Dimensions should be divisible by 4 and have a prime factor decomposition up to 7
            support: initial support in real space (1 = inside support, 0 = outside)
            obj: initial object. If None, will be initialized to a complex object
                 with uniform random 0<modulus<1 and 0<phase<2pi, multiplied by the support.
                 The data is assumed to be centered in (0,0) to avoid fft shifts.
            mask: mask for the diffraction data (0: valid pixel, >0: masked)
            pixel_size_detector: detector pixel size (meters)
            wavelength: experiment wavelength (meters)
            detector_distance: detector distance (meters)
        """
        # astype() returns a copy, so masking below never alters the caller's array
        self.iobs = iobs.astype(np.float32)

        if support is not None:
            self._support = support.copy().astype(np.int8)
        else:
            # Support will be later updated
            self._support = np.ones_like(iobs, dtype=np.int8)

        if obj is None:
            # self._support is always set at this point, so the random object is restricted to it
            a, p = np.random.uniform(0, 1, iobs.shape), np.random.uniform(0, 2 * np.pi, iobs.shape)
            self._obj = (a * np.exp(1j * p) * self._support).astype(np.complex64)
        else:
            self._obj = obj.copy().astype(np.complex64)

        self._is_in_object_space = False

        if mask is not None:
            # Avoid needing to store/transfer a mask in GPU. Use negative iobs to flag masked pixels
            self.iobs[mask > 0] = -100

        # reciprocal space kernel for the point spread function. This is used for partial coherence correction.
        self._k_psf = None

        self.pixel_size_detector = pixel_size_detector
        self.wavelength = wavelength
        self.detector_distance = detector_distance
        self.pixel_size_object = None
        self.lambdaz = None

        # Experimental parameters
        if self.wavelength is not None and self.detector_distance is not None:
            self.lambdaz = self.wavelength * self.detector_distance
        if self.lambdaz is not None and self.pixel_size_detector is not None:
            # Real-space pixel size, computed from the last (horizontal) dimension of the data
            self.pixel_size_object = self.lambdaz / (self.pixel_size_detector * self.iobs.shape[-1])

        # Variables for log-likelihood statistics
        self.llk_poisson = 0
        self.llk_gaussian = 0
        self.llk_euclidian = 0
        self.nb_photons_calc = 0

        if mask is not None:
            self.nb_observed_points = (mask == 0).sum()
        else:
            self.nb_observed_points = self.iobs.size

        self.llk_support = None  # Current value of the support log-likelihood (support regularization)
        self.llk_support_reg_fac = None  # Regularization factor for the support log-likelihood
        if support is not None:
            self.nb_point_support = support.sum()
        else:
            self.nb_point_support = 0

        # Max amplitude reported during support update (after smoothing)
        self._obj_max = 0

        # The timestamp counter record when the cdi or support data was last altered, either in the host or the
        # GPU memory.
        self._timestamp_counter = 1
        self._cl_timestamp_counter = 0
        self._cu_timestamp_counter = 0

        # Record the number of cycles (RAAR, HIO, ER, CF, etc...), which can be used to make some parameters
        # evolve, e.g. for support update
        self.cycle = 0

    @staticmethod
    def _fft_centered_coords(n):
        """
        Integer pixel coordinates along an axis of length n, with the origin at index 0 (FFT ordering),
        i.e. [0, 1, ..., -2, -1]. This equals np.fft.fftfreq(n) * n, for both even and odd n.

        :param n: the number of pixels along the axis
        :return: a 1D float32 numpy array of length n
        """
        return np.fft.ifftshift(np.arange(-(n // 2), n - n // 2, dtype=np.float32))

    def get_x_y_z(self):
        """
        Get 1D arrays of x and y (z if 3d) coordinates, taking into account the pixel size. The arrays are centered
        at (0,0) - i.e. with the origin in the corner for FFT purposes. x is an horizontal vector and y vertical,
        and (if 3d) z along third dimension.

        :return: a tuple (x, y) or (x, y, z) of numpy arrays, with shapes (nx,), (ny, 1) and (nz, 1, 1)
        """
        coords = []
        # The last array axis is x, the one before is y, then z: iterate from the last axis backwards
        for i, n in enumerate(reversed(self.iobs.shape)):
            c = self._fft_centered_coords(n).reshape((n,) + (1,) * i)
            coords.append(c * self.pixel_size_object)
        return tuple(coords)

    def get_x_y(self):
        """
        Alias of get_x_y_z().

        :return: a tuple (x, y) or (x, y, z) of coordinate arrays
        """
        return self.get_x_y_z()

    def copy(self, copy_obj=True):
        """
        Creates a copy (without any reference passing) of this object, unless copy_obj is False.
        The latest object and support data are used, i.e. they are first retrieved from GPU memory
        if they were last modified there.

        :param copy_obj: if False, the new object will be a shallow copy, with the object array shared
                         as a reference rather than copied.
        :return: a copy of the object.
        """
        # get_obj()/get_support() synchronise from GPU memory if needed
        obj = self.get_obj()
        support = self.get_support()
        new = CDI(iobs=self.iobs, support=support, obj=obj, mask=self.iobs < 0,
                  pixel_size_detector=self.pixel_size_detector, wavelength=self.wavelength,
                  detector_distance=self.detector_distance)
        if not copy_obj:
            new._obj = obj
        return new

    def in_object_space(self):
        """
        :return: True if the current obj array is in object space, False otherwise.
        """
        return self._is_in_object_space

    def _from_gpu(self):
        """
        Internal function to get relevant arrays from GPU memory, if the OpenCL or CUDA copy
        has a more recent timestamp than the host copy.
        :return: Nothing
        """
        if self._timestamp_counter < self._cl_timestamp_counter:
            self._obj = self._cl_obj.get()
            self._support = self._cl_support.get()
            self._timestamp_counter = self._cl_timestamp_counter
        if self._timestamp_counter < self._cu_timestamp_counter:
            self._obj = self._cu_obj.get()
            self._support = self._cu_support.get()
            self._timestamp_counter = self._cu_timestamp_counter

    def _mark_host_modified(self):
        """
        Make the host timestamp strictly newer than both GPU timestamps, after host data was altered.
        """
        if self._timestamp_counter <= self._cl_timestamp_counter:
            self._timestamp_counter = self._cl_timestamp_counter + 1
        if self._timestamp_counter <= self._cu_timestamp_counter:
            self._timestamp_counter = self._cu_timestamp_counter + 1

    def get_obj(self, shift=False):
        """
        Get the object data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :param shift: if True, the data array will be fft-shifted so that the center of the data is in the center
                      of the array, rather than in the corner (the default).
        :return: the 2D or 3D CDI numpy data array
        """
        self._from_gpu()
        if shift:
            return np.fft.fftshift(self._obj)
        return self._obj

    def get_support(self, shift=False):
        """
        Get the support array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :param shift: if True, the data array will be fft-shifted so that the center of the data is in the center
                      of the array, rather than in the corner (the default).
        :return: the 2D or 3D CDI numpy support array
        """
        self._from_gpu()
        if shift:
            return np.fft.fftshift(self._support)
        return self._support

    def get_iobs(self, shift=False):
        """
        Get the observed intensity data array. Masked pixels have negative values.

        :param shift: if True, the data array will be fft-shifted so that the center of the data is in the center
                      of the array, rather than in the corner (the default).
        :return: the 2D or 3D CDI numpy data array
        """
        if shift:
            return np.fft.fftshift(self.iobs)
        return self.iobs

    def set_obj(self, obj, shift=False):
        """
        Set the object data array.

        :param obj: the 2D or 3D CDI numpy data array (complex64 numpy array)
        :param shift: if True, the data array will be fft-shifted so that the center of the stored data is
                      in the corner of the array. [default: the array is already shifted]
        :return: nothing
        """
        if shift:
            self._obj = np.fft.fftshift(obj).astype(np.complex64)
        else:
            self._obj = obj.astype(np.complex64)
        self._mark_host_modified()

    def set_support(self, support, shift=False):
        """
        Set the support data array.

        :param support: the 2D or 3D CDI support array (1 inside the support, 0 outside), stored as int8
        :param shift: if True, the data array will be fft-shifted so that the center of the stored data is
                      in the corner of the array. [default: the array is already shifted]
        :return: nothing
        """
        if shift:
            self._support = np.fft.fftshift(support).astype(np.int8)
        else:
            self._support = support.astype(np.int8)
        self._mark_host_modified()
        self.nb_point_support = self._support.sum()

    def set_iobs(self, iobs, shift=False):
        """
        Set the observed intensity data array.

        :param iobs: the 2D or 3D CDI numpy data array (float32 numpy array)
        :param shift: if True, the data array will be fft-shifted so that the center of the stored data is
                      in the corner of the array. [default: the array is already shifted]
        :return: nothing
        """
        if shift:
            self.iobs = np.fft.fftshift(iobs).astype(np.float32)
        else:
            self.iobs = iobs.astype(np.float32)
        self._mark_host_modified()

    def set_mask(self, mask, shift=False):
        """
        Set the mask data array. Note that since the mask is stored by setting observed intensities of masked
        pixels to negative values, it is not possible to un-mask pixels. The number of observed points used
        for the log-likelihood normalisation is updated accordingly.

        :param mask: the 2D or 3D CDI mask array (0: valid pixel, >0: masked). If None, nothing is done.
        :param shift: if True, the data array will be fft-shifted so that the center of the stored data is
                      in the corner of the array. [default: the array is already shifted]
        :return: nothing
        """
        if mask is None:
            return
        if shift:
            mask = np.fft.fftshift(mask)
        iobs = self.get_iobs()
        iobs[np.asarray(mask) > 0] = -100
        self.set_iobs(iobs)
        # Same convention as copy(): every pixel with a negative intensity is masked
        self.nb_observed_points = self.iobs.size - np.count_nonzero(self.iobs < 0)

    def save_data_cxi(self, filename, sample_name=None, experiment_id=None, instrument=None, note=None):
        """
        Save the diffraction data (observed intensity, mask) to an HDF5 CXI file.
        :param filename: the file name to save the data to
        :param sample_name: optional, the sample name
        :param experiment_id: the string identifying the experiment, e.g.: 'HC1234: Siemens star calibration tests'
        :param instrument: the string identifying the instrument, e.g.: 'ESRF id10'
        :param note: optional, a string with a text note giving some additional information about the data
        :return: Nothing. a CXI file is created
        """
        save_cdi_data_cxi(filename, iobs=self.iobs, wavelength=self.wavelength,
                          detector_distance=self.detector_distance,
                          pixel_size_detector=self.pixel_size_detector, mask=self.iobs < 0, sample_name=sample_name,
                          experiment_id=experiment_id, instrument=instrument, note=note, iobs_is_fft_shifted=True)

    @staticmethod
    def _crop_slices(sup, margin=10):
        """
        Compute slices which crop an array around the non-zero region of a support array.

        :param sup: the support array (any number of dimensions)
        :param margin: the number of pixels kept on each side of the support, clipped to the array bounds
        :return: a tuple of slices (one per axis), or None if the support is empty
        """
        slices = []
        for axis in range(sup.ndim):
            other_axes = tuple(i for i in range(sup.ndim) if i != axis)
            idx = np.nonzero(sup.sum(axis=other_axes))[0]
            if idx.size == 0:
                return None
            start = max(int(idx[0]) - margin, 0)
            # +1 because the slice end is exclusive
            stop = min(int(idx[-1]) + margin + 1, sup.shape[axis])
            slices.append(slice(start, stop))
        return tuple(slices)

    def save_obj_cxi(self, filename, sample_name=None, experiment_id=None, instrument=None, note=None, crop=False):
        """
        Save the result of the optimisation (object, support) to an HDF5 CXI file.
        :param filename: the file name to save the data to
        :param sample_name: optional, the sample name
        :param experiment_id: the string identifying the experiment, e.g.: 'HC1234: Siemens star calibration tests'
        :param instrument: the string identifying the instrument, e.g.: 'ESRF id10'
        :param note: a string with a text note giving some additional information about the data, a publication...
        :param crop: if True, the saved object and support are cropped around the support, with a margin of
                     10 pixels on each side (clipped to the array bounds). No cropping if the support is empty.
        :return: Nothing. a CXI file is created
        """
        obj = self.get_obj(shift=True)
        sup = self.get_support(shift=True)
        if crop:
            slices = self._crop_slices(sup, margin=10)
            if slices is not None:
                obj = obj[slices]
                sup = sup[slices]

        # Context manager ensures the file is closed even if writing fails
        with h5py.File(filename, "w") as f:
            f.create_dataset("cxi_version", data=140)
            entry_1 = f.create_group("entry_1")
            entry_1.create_dataset("program_name", data="PyNX %s" % (__version__))
            entry_1.create_dataset("start_time",
                                   data=time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime(time.time())))
            if experiment_id is not None:
                entry_1.create_dataset("experiment_id", data=experiment_id)

            if note is not None:
                note_1 = entry_1.create_group("note_1")
                note_1.create_dataset("data", data=note)
                note_1.create_dataset("type", data="text/plain")

            if sample_name is not None:
                sample_1 = entry_1.create_group("sample_1")
                sample_1.create_dataset("name", data=sample_name)

            image_1 = entry_1.create_group("image_1")
            image_1.create_dataset("data", data=obj, chunks=True, shuffle=True, compression="gzip")
            image_1.create_dataset("mask", data=sup, chunks=True, shuffle=True, compression="gzip")
            image_1["support"] = h5py.SoftLink('/entry_1/image_1/mask')
            image_1.create_dataset("data_type", data="electron density")
            image_1.create_dataset("data_space", data="real")
            if self.pixel_size_object is not None:
                s = self.pixel_size_object * np.array(obj.shape)
                image_1.create_dataset("image_size", data=s)

            instrument_1 = image_1.create_group("instrument_1")
            if instrument is not None:
                instrument_1.create_dataset("name", data=instrument)

            if self.wavelength is not None:
                # wavelength is in meters: convert to Angstroems to get the energy in keV
                nrj = 12.3984 / (self.wavelength * 1e10)
                source_1 = instrument_1.create_group("source_1")
                source_1.create_dataset("energy", data=nrj * 1.60218e-16)  # in J

            detector_1 = instrument_1.create_group("detector_1")
            if self.detector_distance is not None:
                detector_1.create_dataset("distance", data=self.detector_distance)

            if self.pixel_size_detector is not None:
                detector_1.create_dataset("x_pixel_size", data=self.pixel_size_detector)
                detector_1.create_dataset("y_pixel_size", data=self.pixel_size_detector)

            if self._k_psf is not None:
                # Likely this is due to partial coherence and not the detector, but it is still analysed as a PSF
                detector_1.create_dataset("point_spread_function", data=self._k_psf)

            # Add shortcut to the main data
            data_1 = entry_1.create_group("data_1")
            data_1["data"] = h5py.SoftLink('/entry_1/image_1/data')

    def get_llkn(self):
        """
        Get the poisson normalized log-likelihood, which should converge to 1 for a statistically ideal fit
        to a Poisson-noise dataset.

        :return: the normalized log-likelihood, for Poisson noise.
        """
        warnings.warn("cdi.get_llkn is deprecated. Use cdi.get_llk instead", DeprecationWarning)
        return self.get_llk(noise='poisson', normalized=True)

    def get_llk(self, noise=None, normalized=True):
        """
        Get the normalized log-likelihoods, which should converge to 1 for a statistically ideal fit.

        :param noise: either 'gaussian', 'poisson' or 'euclidian', will return the corresponding log-likelihood.
        :param normalized: if True, will return normalized values so that the llk from a statistically ideal model
                           should converge to 1. The normalisation divides by the number of observed points minus
                           the number of points in the support.
        :return: the log-likelihood, or if noise=None, a tuple of the three (poisson, gaussian, euclidian)
                 log-likelihoods. None is returned if noise is not recognised.
        """
        n = 1
        if normalized:
            n = 1 / max((self.nb_observed_points - self.nb_point_support), 1e-10)

        if noise is None:
            return self.llk_poisson * n, self.llk_gaussian * n, self.llk_euclidian * n
        noise = str(noise).lower()
        if 'poiss' in noise:
            return self.llk_poisson * n
        if 'gauss' in noise:
            return self.llk_gaussian * n
        if 'eucl' in noise:
            return self.llk_euclidian * n
        return None

    def __rmul__(self, x):
        """
        Multiply object (by a scalar).

        This is a placeholder for a function which will be replaced when importing either CUDA or OpenCL operators.
        If called before being replaced, will raise an error

        :param x: the scalar by which the CDI object will be multiplied
        :return: never returns, raises OperatorException
        """
        if np.isscalar(x):
            raise OperatorException(
                "ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s. Did you import operators ?" % (str(x), str(self)))
        else:
            raise OperatorException("ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s." % (str(x), str(self)))

    def __mul__(self, x):
        """
        Multiply object (by a scalar).

        This is a placeholder for a function which will be replaced when importing either CUDA or OpenCL operators.
        If called before being replaced, will raise an error

        :param x: the scalar by which the CDI object will be multiplied
        :return: never returns, raises OperatorException
        """
        if np.isscalar(x):
            raise OperatorException(
                "ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s. Did you import operators ?" % (str(self), str(x)))
        else:
            raise OperatorException("ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s." % (str(self), str(x)))

    def __str__(self):
        return "CDI"


def save_cdi_data_cxi(filename, iobs, wavelength=None, detector_distance=None, pixel_size_detector=None, mask=None,
                      sample_name=None, experiment_id=None, instrument=None, note=None, iobs_is_fft_shifted=False):
    """
    Save the diffraction data (observed intensity, mask) to an HDF5 CXI file.
    :param filename: the file name to save the data to
    :param iobs: the observed intensity
    :param wavelength: the wavelength of the experiment (in meters)
    :param detector_distance: the detector distance (in meters)
    :param pixel_size_detector: the pixel size of the detector (in meters)
    :param mask: the mask indicating valid (=0) and bad pixels (>0). It is only saved if at least one pixel is masked.
    :param sample_name: optional, the sample name
    :param experiment_id: the string identifying the experiment, e.g.: 'HC1234: Siemens star calibration tests'
    :param instrument: the string identifying the instrument, e.g.: 'ESRF id10'
    :param note: optional, a string with a text note giving some additional information about the data
    :param iobs_is_fft_shifted: if true, input iobs (and mask if any) have their origin in (0,0[,0]) and will be shifted
        back to centered-versions before being saved.
    :return: Nothing. a CXI file is created
    """
    if iobs_is_fft_shifted:
        iobs = np.fft.fftshift(iobs)
        if mask is not None:
            mask = np.fft.fftshift(mask)

    # Context manager ensures the file is closed even if writing fails
    with h5py.File(filename, "w") as f:
        f.create_dataset("cxi_version", data=140)
        entry_1 = f.create_group("entry_1")
        entry_1.create_dataset("program_name", data="PyNX %s" % (__version__))
        entry_1.create_dataset("start_time", data=time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime(time.time())))
        if experiment_id is not None:
            entry_1.create_dataset("experiment_id", data=experiment_id)

        if note is not None:
            note_1 = entry_1.create_group("note_1")
            note_1.create_dataset("data", data=note)
            note_1.create_dataset("type", data="text/plain")

        if sample_name is not None:
            sample_1 = entry_1.create_group("sample_1")
            sample_1.create_dataset("name", data=sample_name)

        instrument_1 = entry_1.create_group("instrument_1")
        if instrument is not None:
            instrument_1.create_dataset("name", data=instrument)

        if wavelength is not None:
            # wavelength is in meters: convert to Angstroems to get the energy in keV
            nrj = 12.3984 / (wavelength * 1e10)
            source_1 = instrument_1.create_group("source_1")
            source_1.create_dataset("energy", data=nrj * 1.60218e-16)  # in J

        detector_1 = instrument_1.create_group("detector_1")
        if detector_distance is not None:
            detector_1.create_dataset("distance", data=detector_distance)
        if pixel_size_detector is not None:
            detector_1.create_dataset("x_pixel_size", data=pixel_size_detector)
            detector_1.create_dataset("y_pixel_size", data=pixel_size_detector)
        detector_1.create_dataset("data", data=iobs, chunks=True, shuffle=True, compression="gzip")

        if mask is not None and mask.sum() != 0:
            detector_1.create_dataset("mask", data=mask, chunks=True, shuffle=True, compression="gzip")

        # No basis_vectors dataset is written: the default CXI convention ('top, left' origin) applies.

        data_1 = entry_1.create_group("data_1")
        data_1["data"] = h5py.SoftLink('/entry_1/instrument_1/detector_1/data')

        # Remember how import was done (each argument followed by a space)
        command = "".join(arg + " " for arg in sys.argv)
        process_1 = data_1.create_group("process_1")
        process_1.create_dataset("command", data=command)


class OperatorCDI(Operator):
    """
    Base class for operators acting on CDI objects, which do not need a processing unit.
    """

    def timestamp_increment(self, cdi):
        """
        Increase the host-side timestamp counter of a CDI object by one.

        CDI compares this counter with its OpenCL/CUDA counters to decide where the most
        recent object/support data is stored.

        :param cdi: the CDI object which was modified
        :return: nothing
        """
        cdi._timestamp_counter = cdi._timestamp_counter + 1
