# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr


import os
import sys
import time
import numpy as np
from scipy.spatial import ConvexHull

try:
    # We must import hdf5plugin before h5py, even if it is not used here.
    import hdf5plugin
except:
    pass
import h5py
from ..operator import Operator, OperatorException, has_attr_not_none
from ..version import __version__
from ..utils import phase
from ..utils.history import History


class PtychoData:
    """Class for two-dimensional ptychographic data: observed diffraction and probe positions.
    This may include only part of the data from a larger dataset.
    """

    def __init__(self, iobs=None, positions=None, detector_distance=None, mask=None,
                 pixel_size_detector=None, wavelength=None, near_field=False, padding=0):
        """

        :param iobs: 3d array with (nb_frame, ny,nx) shape observed intensity (assumed to follow Poisson statistics).
                     The frames will be stored fft-shifted so that the center of diffraction lies in the (0,0) corner
                     of each image. The supplied frames should have the diffraction center in the middle of the frames.
                     Intensities must be >=0: negative input values are reset to 0. In the stored array, masked
                     pixels are marked with negative values.
        :param positions: (x, y, z) or (x, y) tuple of 1D arrays (equivalently a (3, N) or (2, N) array)
                          with ptycho probe positions in meters.
                          For 2D data, z is ignored and can be None or missing, e.g. with (x, y)
                          The orthonormal coordinate system follows the CXI/NeXus/McStas convention,
                          with z along the incident beam propagation direction and y vertical towards ceiling.
        :param detector_distance: detector distance in meters
        :param mask: 2D mask (>0 means masked pixel) for the observed data. Can be None. Will be fft-shifted like iobs.
                     Masked pixels are stored as negative values in the iobs array. The supplied array is not modified.
        :param pixel_size_detector: in meters, assuming square pixels
        :param wavelength: wavelength of the experiment, in meters.
        :param near_field: True if using near field ptycho
        :param padding: an integer value indicating the number of zero-padded pixels to be used
                        on each side of the observed frames. This can be used for near field ptychography.
                        The input iobs should already be padded; the corresponding pixels will be added to the mask.
        """
        self.iobs = None  # full observed intensity array (fft-shifted), or None if iobs was not given
        self.iobs_sum = None  # total observed intensity, computed before masked pixels are flagged
        self.scale = None  # Vector of scale factors, one for each frame. Used for floating intensities.
        self.padding = padding

        if self.padding > 0 and iobs is not None:
            if mask is not None:
                # Work on a copy: the caller's mask array must not be modified by the padding
                mask = np.array(mask, copy=True)
                mask[:padding] = 1
                mask[:, :padding] = 1
                mask[-padding:] = 1
                mask[:, -padding:] = 1
            else:
                # Mask everything except the non-padded interior of the frames
                mask = np.ones(iobs.shape[-2:], dtype=np.int8)
                mask[padding:-padding, padding:-padding] = 0

        if iobs is not None:
            self.iobs = np.fft.fftshift(iobs, axes=(-2, -1)).astype(np.float32)
            # Negative input intensities are not valid observations. This should not be necessary
            self.iobs[self.iobs < 0] = 0
            self.iobs_sum = self.iobs.sum()
            if mask is not None:
                self.mask = np.fft.fftshift(np.asarray(mask).astype(np.int8))
                # Masked pixels are flagged with a negative intensity
                self.iobs[:, self.mask > 0] = -100
            self.scale = np.ones(len(self.iobs), dtype=np.float32)
        self.detector_distance = detector_distance
        self.pixel_size_detector = pixel_size_detector
        self.wavelength = wavelength
        if positions is not None:
            # Unpacking is along the first axis: (x, y) or (x, y, z)
            if len(positions) == 2:
                self.posx, self.posy = positions
            else:
                self.posx, self.posy, self.posz = positions
        self.near_field = near_field

    def pixel_size_object(self):
        """
        Get the x and y pixel size in object space after a FFT.
        In near field, this is the detector pixel size. In far field it is
        wavelength * detector_distance / (n * pixel_size_detector), with n the frame size along x or y.

        :return: a tuple (pixel_size_x, pixel_size_y) in meters
        """
        if self.near_field:
            return self.pixel_size_detector, self.pixel_size_detector
        else:
            ny, nx = self.iobs.shape[-2:]
            pixel_size_x = self.wavelength * self.detector_distance / (nx * self.pixel_size_detector)
            pixel_size_y = self.wavelength * self.detector_distance / (ny * self.pixel_size_detector)
            return pixel_size_x, pixel_size_y


class Ptycho:
    """
    Class for two-dimensional ptychographic data: object, probe, and observed diffraction.
    This may include only part of the data from a larger dataset.
    """

    def __init__(self, probe=None, obj=None, background=None, data=None, nb_frame_total=None):
        """

        :param probe: the starting estimate of the probe, as a complex 2D numpy array - can be 3D if modes are used.
                      the probe should be centered in the center of the array.
                      A 2D probe is stored as a (1, ny, nx) array.
        :param obj: the starting estimate of the object, as a complex 2D numpy array - can be 3D if modes are used.
                    A 2D object is stored as a (1, ny, nx) array.
        :param background: 2D array with the incoherent background. If None and data with observed intensities
                           is given, a zero background with the frame shape is created.
        :param data: the PtychoData object with all observed frames, ptycho positions
        :param nb_frame_total: total number of frames (used for normalization). Defaults to the number of frames
                               in data.
        """
        self.data = data
        self._probe = probe
        self._obj = obj
        self._background = background
        self._obj_zero_phase_mask = None

        if self._probe is not None:
            self._probe = self._probe.astype(np.complex64)
            if self._probe.ndim == 2:
                ny, nx = self._probe.shape
                self._probe = self._probe.reshape((1, ny, nx))

        if self._obj is not None:
            self._obj = self._obj.astype(np.complex64)
            if self._obj.ndim == 2:
                ny, nx = self._obj.shape
                self._obj = self._obj.reshape((1, ny, nx))

        if self._background is not None:
            self._background = self._background.astype(np.float32)
        elif data is not None:
            if data.iobs is not None:
                self._background = np.zeros(data.iobs.shape[-2:], dtype=np.float32)

        self.nb_frame_total = nb_frame_total
        if self.nb_frame_total is None and data is not None:
            self.nb_frame_total = len(data.iobs)

        # Placeholder for storage of propagated wavefronts. Only used with CPU
        self._psi = None

        # Stored variables
        if data is not None:
            self.pixel_size_object = np.float32(data.pixel_size_object()[0])
        self.scan_area_obj = None
        self.scan_area_probe = None
        self.scan_area_points = None
        self.llk_poisson = 0
        self.llk_gaussian = 0
        self.llk_euclidian = 0
        self.nb_photons_calc = 0
        if data is not None:
            # Number of non-masked observed pixels (masked pixels are stored as negative values)
            self.nb_obs = (self.data.iobs >= 0).sum()
        else:
            self.nb_obs = 0

        # The timestamp counter record when the data was last altered, either in the host or the GPU memory.
        self._timestamp_counter = 1
        self._cpu_timestamp_counter = 1
        self._cl_timestamp_counter = 0
        self._cu_timestamp_counter = 0

        # The scan area needs both the object and probe shapes: skip it (leaving the scan areas as None)
        # rather than failing when either one is not supplied yet.
        if self.data is not None and self._obj is not None and self._probe is not None:
            self.calc_scan_area()

        # Regularisation scale factors
        self.reg_fac_scale_obj = 0
        self.reg_fac_scale_probe = 0
        self.calc_regularisation_scale()

        # Record the number of cycles (ML, AP, DM, etc...), for history purposes
        self.cycle = 0
        # History record
        self.history = History()

    # def init_obj_probe_mask(self):
    def calc_scan_area(self):
        """
        Compute the scan area for the object and probe, using scipy ConvexHull. The scan area for the object is
        augmented by twice the average distance between scan positions for a more realistic estimation.
        scan_area_points is also computed, corresponding to the outline of the scanned area.

        :return: Nothing. self.scan_area_probe and self.scan_area_obj are updated, as 2D arrays with the same shape 
                 as the object and probe, with False outside the scan area and True inside.
        """
        # TODO: expand scan area by the average distance between neighbours
        px, py = self.data.pixel_size_object()
        y, x = self.data.posy - self.data.posy.mean(), self.data.posx - self.data.posx.mean()
        points = np.array([(x / px, y / py) for x, y in zip(x, y)])
        c = ConvexHull(points)
        vx = points[:, 0]
        vy = points[:, 1]
        # Scan center & average distance between points
        xc = vx.mean()
        yc = vy.mean()
        # Estimated average distance between points with an hexagonal model
        try:
            w = 4 / 3 / np.sqrt(3) * np.sqrt(c.volume / x.size)
        except:
            # c.volume only supported in scipy >=0.17 (2016/02)
            w = 0
        # print("calc_scan_area: scan area = %8g pixels^2, center @(%6.1f, %6.1f), <d>=%6.2f)"%(c.volume, xc, yc, w))
        # Object
        ny, nx = self._obj.shape[-2:]
        xx = np.fft.fftshift(np.fft.fftfreq(nx, d=1. / nx))
        yy = np.fft.fftshift(np.fft.fftfreq(ny, d=1. / ny))[:, np.newaxis]
        self.scan_area_obj = np.ones((ny, nx), dtype=np.float32)
        vd = []
        for i in range(len(c.vertices)):
            # c.vertices are in counterclockwise order
            xx0, yy0 = vx[c.vertices[i - 1]], vy[c.vertices[i - 1]]
            xx1, yy1 = vx[c.vertices[i]], vy[c.vertices[i]]
            if w > 0:
                # Increase distance-to-center for a more realistic scan area
                d0 = np.sqrt((xx0 - xc) ** 2 + (yy0 - yc) ** 2)
                xx0 = xc + (xx0 - xc) * (d0 + 2 * w) / d0
                yy0 = yc + (yy0 - yc) * (d0 + 2 * w) / d0
                d1 = np.sqrt((xx1 - xc) ** 2 + (yy1 - yc) ** 2)
                xx1 = xc + (xx1 - xc) * (d1 + 2 * w) / d1
                yy1 = yc + (yy1 - yc) * (d1 + 2 * w) / d1

            dx, dy = xx1 - xx0, yy1 - yy0
            vd.append(np.sqrt(dx ** 2 + dy ** 2))
            self.scan_area_obj *= ((xx - xx0) * dy - (yy - yy0) * dx) <= 0
        self.scan_area_obj = self.scan_area_obj > 0.5
        # Probe
        ny, nx = self._probe.shape[-2:]
        xx = np.fft.fftshift(np.fft.fftfreq(nx, d=1. / nx))
        yy = np.fft.fftshift(np.fft.fftfreq(ny, d=1. / ny))[:, np.newaxis]
        self.scan_area_probe = np.ones((ny, nx), dtype=np.float32)
        vd = []
        for i in range(len(c.vertices)):
            # c.vertices are in counterclockwise order
            xx0, yy0 = vx[c.vertices[i - 1]], vy[c.vertices[i - 1]]
            dx, dy = vx[c.vertices[i]] - xx0, vy[c.vertices[i]] - yy0
            vd.append(np.sqrt(dx ** 2 + dy ** 2))
            self.scan_area_probe *= ((xx - xx0) * dy - (yy - yy0) * dx) <= 0
        self.scan_area_probe = self.scan_area_probe > 0.5
        x, y = [points[i, 0] for i in c.vertices], [points[i, 1] for i in c.vertices]
        x.append(points[c.vertices[0], 0])
        y.append(points[c.vertices[0], 1])
        self.scan_area_points = np.array(x), np.array(y)

    def calc_regularisation_scale(self):
        """
        Calculate the scale factor for object and probe regularisation.
        Calculated according to Thibault & Guizar-Sicairos 2012
        :return: nothing
        """
        if self.data is not None and self._obj is not None and self._probe is not None:
            probe_size = self._probe[0].size
            obj_size = self._obj[0].size
            data_size = self.data.iobs.size
            nb_photons = self.nb_obs
            # TODO: take into account the object area actually scanned
            self.reg_fac_scale_obj = data_size * nb_photons / (8 * obj_size ** 2)
            self.reg_fac_scale_probe = data_size * nb_photons / (8 * probe_size ** 2)
            if False:
                print("Regularisation scale factors: object %8e probe %8e" % (self.reg_fac_scale_obj,
                                                                              self.reg_fac_scale_probe))
        else:
            self.reg_fac_scale_obj = 0
            self.reg_fac_scale_probe = 0

    def from_pu(self):
        """
        Copy arrays back from the processing unit (OpenCL and/or CUDA) when the host copy is out of date.

        A back-end is read when its timestamp counter equals the global one. The object, probe, floating
        intensity scale and (if a background is already present) background are fetched if the
        corresponding device attributes exist and are not None.
        :return: Nothing
        """
        if self._cpu_timestamp_counter >= self._timestamp_counter:
            # Host memory already holds the latest version
            return
        backends = ((self._cl_timestamp_counter, '_cl_obj', '_cl_probe', '_cl_scale', '_cl_background'),
                    (self._cu_timestamp_counter, '_cu_obj', '_cu_probe', '_cu_scale', '_cu_background'))
        for counter, obj_attr, probe_attr, scale_attr, background_attr in backends:
            if counter != self._timestamp_counter:
                continue
            if has_attr_not_none(self, obj_attr):
                self._obj = getattr(self, obj_attr).get()
            if has_attr_not_none(self, probe_attr):
                self._probe = getattr(self, probe_attr).get()
            if has_attr_not_none(self, scale_attr):
                self.data.scale = getattr(self, scale_attr).get()
            if self._background is not None and has_attr_not_none(self, background_attr):
                self._background = getattr(self, background_attr).get()
        self._cpu_timestamp_counter = self._timestamp_counter

    def get_obj(self):
        """
        Get the object data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :param shift: if True, the data array will be fft-shifted so that the center of the data is in the center
                      of the array, rather than in the corner (the default).
        :return: the 3D numpy data array (nb object modes, nyo, nxo)
        """
        self.from_pu()
        return self._obj

    def set_obj(self, obj):
        """
        Set the object data array.

        :param obj: the object (complex64 numpy array)
        :return: nothing
        """
        self._obj = obj.astype(np.complex64)
        self._timestamp_counter += 1
        self._cpu_timestamp_counter = self._timestamp_counter
        self.calc_regularisation_scale()

    def set_obj_zero_phase_mask(self, mask):
        """
        Set an object mask, which has the same 2D shape as the object, where values of 1 indicate that the area
        corresponds to vacuum (or air), and 0 corresponds to some material. Values between 0 and 1 can be given to
        smooth the transition.
        This mask will be used to restrain the corresponding area to a null phase, dampening the imaginary part
        at every object update.
        :param mask: a floating-point array with the same 2D shape as the object, where values of 1 indicate
        that the area corresponds to vacuum (or air), and 0 corresponds to the sample.
        :return: nothing
        """
        self._obj_zero_phase_mask = mask

    def get_probe(self):
        """
        Get the probe data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :return: the 3D probe numpy data array
        """
        self.from_pu()

        return self._probe

    def set_background(self, background):
        """
        Set the incoherent background data array.

        :param background: the background (float32 numpy array)
        :return: nothing
        """
        self._background = background.astype(np.float32)
        self._timestamp_counter += 1
        self._cpu_timestamp_counter = self._timestamp_counter

    def get_background(self):
        """
        Get the background data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :return: the 2D numpy data array
        """
        self.from_pu()
        return self._background

    def set_probe(self, probe):
        """
        Set the probe data array.

        :param probe: the probe (complex64 numpy array)
        :return: nothing
        """
        self._probe = probe.astype(np.complex64)
        self._timestamp_counter += 1
        self._cpu_timestamp_counter = self._timestamp_counter

    def __rmul__(self, x):
        """
        Placeholder for `x * self` (multiplication of the object by a scalar).

        It is meant to be replaced when CUDA or OpenCL operators are imported. Until then it always raises.
        When x is a scalar, the message suggests importing the operators.

        :param x: the left-hand operand (expected to be a scalar)
        :return: never returns
        :raises OperatorException: always
        """
        if not np.isscalar(x):
            raise OperatorException("ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s." % (str(x), str(self)))
        raise OperatorException(
            "ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s. Did you import operators ?" % (str(x), str(self)))

    def __mul__(self, x):
        """
        Placeholder for `self * x` (multiplication of the object by a scalar).

        It is meant to be replaced when CUDA or OpenCL operators are imported. Until then it always raises.
        When x is a scalar, the message suggests importing the operators.

        :param x: the right-hand operand (expected to be a scalar)
        :return: never returns
        :raises OperatorException: always
        """
        if not np.isscalar(x):
            raise OperatorException("ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s." % (str(self), str(x)))
        raise OperatorException(
            "ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s. Did you import operators ?" % (str(self), str(x)))

    def load_obj_probe_cxi(self, filename, entry=None, verbose=True):
        """
        Load object and probe from a CXI file, result of a previous optimisation. If no data is already present in
        the current object, then the pixel size and energy/wavelength are also loaded, and a dummy (one frame) data
        object is created.

        :param filename: the CXI filename from which to load the data
        :param entry: the entry to be read. By default, the last in the file is loaded. Can be 'entry_1', 'entry_10'...
                      An integer n can also be supplied, in which case 'entry_%d' % n will be read
        :param verbose: if True, print information about what is loaded
        :return: nothing. The file is closed before returning; self.entry is set to the HDF5 path of the
                 loaded entry (e.g. '/entry_1').
        """
        with h5py.File(filename, 'r') as f:
            if entry is None:
                # Use the last consecutive entry_N present in the file
                i = 1
                while 'entry_%d' % i in f:
                    i += 1
                entry = f["entry_%d" % (i - 1)]
            elif isinstance(entry, int):
                entry = f["entry_%d" % entry]
            else:
                entry = f[entry]

            # Debug: only keep the entry path, as the h5py group is invalid once the file is closed
            self.entry = entry.name

            self.set_obj(entry['object/data'][()])
            if verbose:
                print("CXI: Loaded object with shape: ", self.get_obj().shape)

            self.set_probe(entry['probe/data'][()])
            if verbose:
                print("CXI: Loaded probe with shape: ", self.get_probe().shape)
            pixel_size_obj = (np.float32(entry['probe/x_pixel_size'][()])
                              + np.float32(entry['probe/y_pixel_size'][()])) / 2
            if verbose:
                print("CXI: object pixel size (m): ", pixel_size_obj)

            if 'mask' in entry['result_1']:
                self.scan_area_obj = entry['result_1/mask'][()] > 0
                if verbose:
                    print("CXI: Loaded scan_area_obj")

            if 'mask' in entry['result_2']:
                self.scan_area_probe = entry['result_2/mask'][()] > 0
                if verbose:
                    print("CXI: Loaded scan_area_probe")

            if 'background' in entry:
                self.set_background(entry['background/data'][()])
                if verbose:
                    print("CXI: Loaded background")

            if self.data is None:
                ny, nx = self.get_probe().shape[-2:]
                d = np.zeros((1, ny, nx), dtype=np.float32)
                # Energy is stored in J: convert to keV, then to a wavelength in meters
                nrj = np.float32(entry['instrument_1/source_1/energy'][()])
                wavelength = 12.3984 / (nrj / 1.60218e-16) * 1e-10
                detector_distance = np.float32(entry['instrument_1/detector_1/distance'][()])
                x_pixel_size = np.float32(entry['instrument_1/detector_1/x_pixel_size'][()])
                y_pixel_size = np.float32(entry['instrument_1/detector_1/y_pixel_size'][()])
                pxy = (x_pixel_size + y_pixel_size) / 2

                # Check consistency of energy values (bug in versions priors to git 2018-10-12)
                px_obj = wavelength * detector_distance / (nx * pxy)
                if pixel_size_obj / px_obj > 1e9:
                    print("Correcting for incorrect energy stored in file (energy 1e10 too large)")
                    wavelength *= 1e10
                if verbose:
                    print("CXI: wavelength (m): ", wavelength)
                    print("CXI: detector pixel size (m): ", pxy)
                    print("CXI: detector distance (m): ", detector_distance)
                    print("CXI: created a dummy iobs with only one frame")

                self.data = PtychoData(iobs=d, positions=([0], [0], [0]), detector_distance=detector_distance,
                                       pixel_size_detector=pxy, wavelength=wavelength)
                self.pixel_size_object = np.float32(self.data.pixel_size_object()[0])
                self.nb_frame_total = 1
        self.calc_regularisation_scale()

    def save_obj_probe_cxi(self, filename, sample_name=None, experiment_id=None, instrument=None, note=None,
                           process=None, append=False, shift_phase_zero=False, params=None):
        """
        Save the result of the optimisation (object, probe, scan areas) to an HDF5 CXI-like file.
        
        :param filename: the file name to save the data to
        :param sample_name: optional, the sample name
        :param experiment_id: the string identifying the experiment, e.g.: 'HC1234: Siemens star calibration tests'
        :param instrument: the string identifying the instrument, e.g.: 'ESRF id10'
        :param note: a string with a text note giving some additional information about the data, a publication...
        :param process: a dictionary of strings which will be saved in '/entry_N/data_1/process_1'. A dictionary entry
                        can also be a 'note' as keyword and a dictionary as value - all key/values will then be saved
                        as separate notes. Example: process={'program': 'PyNX', 'note':{'llk':1.056, 'nb_photons': 1e8}}
        :param append: by default (append=False), any existing file will be overwritten, and the result will be saved
                       as 'entry_1'. If append==True and the file exists, a new entry_N will be saved instead.
                       This can be used to export the different steps of an optimisation.
        :param shift_phase_zero: if True, remove the linear phase ramp from the object
        :param params: a dictionary of parameters to be saved into process_1/configuration NXcollection
        :return: Nothing. a CXI file is created
        """
        if append:
            f = h5py.File(filename, "a")
            if "cxi_version" not in f:
                f.create_dataset("cxi_version", data=150)
            i = 1
            while True:
                if 'entry_%d' % i not in f:
                    break
                i += 1
            entry = f.create_group("/entry_%d" % i)
            entry_path = "/entry_%d" % i
            f.attrs['default'] = "/entry_%d" % i
            if "/entry_last" in entry:
                del entry["/entry_last"]
            entry["/entry_last"] = h5py.SoftLink("/entry_%d" % i)
        else:
            f = h5py.File(filename, "w")
            f.create_dataset("cxi_version", data=150)
            entry = f.create_group("/entry_1")
            entry_path = "/entry_1"
            f.attrs['default'] = 'entry_1'
            entry["/entry_last"] = h5py.SoftLink("/entry_1")
        f.attrs['creator'] = 'PyNX'
        # f.attrs['NeXus_version'] = '2018.5'  # Should only be used when the NeXus API has written the file
        f.attrs['HDF5_Version'] = h5py.version.hdf5_version
        f.attrs['h5py_version'] = h5py.version.version
        entry.attrs['NX_class'] = 'NXentry'
        entry.attrs['default'] = 'data_1'

        entry.create_dataset("program_name", data="PyNX %s" % (__version__))
        entry.create_dataset("start_time", data=time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime(time.time())))
        if experiment_id is not None:
            entry.create_dataset("experiment_id", data=experiment_id)

        if note is not None:
            note_1 = entry.create_group("note_1")
            note_1.attrs['NX_class'] = 'NXnote'
            note_1.create_dataset("data", data=note)
            note_1.create_dataset("type", data="text/plain")

        if sample_name is not None:
            sample_1 = entry.create_group("sample_1")
            sample_1.attrs['NX_class'] = 'NXsample'
            sample_1.create_dataset("name", data=sample_name)

        if shift_phase_zero:
            # Get the object and center the phase around 0
            obj = phase.shift_phase_zero(self.get_obj(), percent=2, origin=0, mask=self.scan_area_obj)
        else:
            obj = self.get_obj()

        probe = self.get_probe()

        # Store object in result_1
        result_1 = entry.create_group("result_1")
        result_1.attrs['NX_class'] = 'NXdata'
        result_1.attrs['signal'] = 'data'
        result_1.attrs['note'] = 'Reconstructed object'
        result_1.attrs['interpretation'] = 'image'
        entry["object"] = h5py.SoftLink(entry_path + '/result_1')  # Unorthodox, departs from specification ?
        result_1.create_dataset("data", data=obj, chunks=True, shuffle=True, compression="gzip")
        result_1.create_dataset("data_type", data="electron density")
        result_1.create_dataset("data_space", data="real")
        ny, nx = obj.shape[-2:]
        px, py = self.data.pixel_size_object()
        result_1.create_dataset("image_size", data=[px * nx, py * ny])
        # Store object pixel size (not in CXI specification)
        result_1.create_dataset("x_pixel_size", data=px)
        result_1.create_dataset("y_pixel_size", data=py)

        if self.scan_area_obj is not None:
            # Using 'mask' from CXI specification: 0x00001000 == CXI_PIXEL_HAS_SIGNAL 'pixel signal above background'
            s = (self.scan_area_obj > 0).astype(np.int) * 0x00001000
            result_1.create_dataset("mask", data=s, chunks=True, shuffle=True, compression="gzip")

        # Store probe in result_2
        result_2 = entry.create_group("result_2")
        result_2.attrs['NX_class'] = 'NXdata'
        result_2.attrs['signal'] = 'data'
        result_2.attrs['note'] = 'Reconstructed probe'
        result_2.attrs['interpretation'] = 'image'
        entry["probe"] = h5py.SoftLink(entry_path + '/result_2')  # Unorthodox, departs from specification ?
        result_2.create_dataset("data", data=probe, chunks=True, shuffle=True, compression="gzip")
        result_2.create_dataset("data_space", data="real")
        ny, nx = probe.shape[-2:]
        result_2.create_dataset("image_size", data=[px * nx, py * ny])
        # Store probe pixel size (not in CXI specification)
        result_2.create_dataset("x_pixel_size", data=px)
        result_2.create_dataset("y_pixel_size", data=py)

        if self.scan_area_probe is not None:
            # Using 'mask' from CXI specification: 0x00001000 == CXI_PIXEL_HAS_SIGNAL 'pixel signal above background'
            s = (self.scan_area_probe > 0).astype(np.int) * 0x00001000
            result_2.create_dataset("mask", data=s, chunks=True, shuffle=True, compression="gzip")

        if self.get_background() is not None:
            result_3 = entry.create_group("result_3")
            result_3.attrs['NX_class'] = 'NXdata'
            result_3.attrs['signal'] = 'data'
            result_3.attrs['note'] = 'Incoherent background'
            result_3.attrs['interpretation'] = 'image'
            entry["background"] = h5py.SoftLink(entry_path + '/result_3')  # Unorthodox, departs from specification ?
            result_3.create_dataset("data", data=np.fft.fftshift(self.get_background()), chunks=True, shuffle=True,
                                    compression="gzip")
            result_3.create_dataset("data_space", data="diffraction")
            result_3.create_dataset("data_type", data="intensity")

        if self.data is not None:
            if self.data.scale is not None:
                if np.allclose(self.data.scale, 1) is False:
                    result_4 = entry.create_group("result_4")
                    result_4.attrs['NX_class'] = 'NXdata'
                    result_4.attrs['signal'] = 'data'
                    result_4.attrs['note'] = 'FLoating intensities for each frame'
                    result_4.attrs['interpretation'] = 'spectrum'
                    entry["floating_intensity"] = h5py.SoftLink(entry_path + '/result_4')
                    result_4.create_dataset("data", data=self.data.scale, compression="gzip")
                    result_4.create_dataset("data_space", data="diffraction")
                    result_4.create_dataset("data_type", data="scale")

        instrument_1 = entry.create_group("instrument_1")
        instrument_1.attrs['NX_class'] = 'NXinstrument'
        if instrument is not None:
            instrument_1.create_dataset("name", data=instrument)

        nrj = 12.3984 / (self.data.wavelength * 1e10)
        source_1 = instrument_1.create_group("source_1")
        source_1.attrs['NX_class'] = 'NXsource'
        source_1.attrs['note'] = 'Incident photon energy (instead of source energy), for CXI compatibility'
        source_1.create_dataset("energy", data=nrj * 1.60218e-16)  # in J
        source_1["energy"].attrs['units'] = 'J'

        beam_1 = instrument_1.create_group("beam_1")
        beam_1.attrs['NX_class'] = 'NXbeam'
        beam_1.create_dataset("incident_energy", data=nrj * 1.60218e-16)
        beam_1["incident_energy"].attrs['units'] = 'J'
        beam_1.create_dataset("incident_wavelength", data=self.data.wavelength)
        beam_1["incident_wavelength"].attrs['units'] = 'm'

        detector_1 = instrument_1.create_group("detector_1")
        detector_1.attrs['NX_class'] = 'NXdetector'
        detector_1.create_dataset("distance", data=self.data.detector_distance)
        detector_1["distance"].attrs['units'] = 'm'

        detector_1.create_dataset("x_pixel_size", data=self.data.pixel_size_detector)
        detector_1["x_pixel_size"].attrs['units'] = 'm'
        detector_1.create_dataset("y_pixel_size", data=self.data.pixel_size_detector)
        detector_1["y_pixel_size"].attrs['units'] = 'm'

        # Add shortcut to the main data
        data_1 = entry.create_group("data_1")
        data_1.attrs['NX_class'] = 'NXdata'
        data_1.attrs['signal'] = 'data'
        data_1.attrs['interpretation'] = 'image'
        data_1["data"] = h5py.SoftLink(entry_path + '/result_1/data')

        command = ""
        for arg in sys.argv:
            command += arg + " "
        process_1 = entry.create_group("process_1")
        process_1.attrs['NX_class'] = 'NXprocess'
        process_1.create_dataset("program", data='PyNX')  # NeXus spec
        process_1.create_dataset("version", data="%s" % __version__)  # NeXus spec
        process_1.create_dataset("command", data=command)  # CXI spec

        if process is not None:
            for k, v in process.items():
                if isinstance(v, str) and k not in process_1:
                    process_1.create_dataset(k, data=v)
                elif isinstance(v, dict) and k == 'note':
                    # Save this as notes:
                    for kk, vv in v.items():
                        i = 1
                        while True:
                            note_s = 'note_%d' % i
                            if note_s not in process_1:
                                break
                            i += 1
                        note = process_1.create_group(note_s)
                        note.create_dataset("data", data=str(vv))
                        note.create_dataset("description", data=kk)
                        note.create_dataset("type", data="text/plain")
        # Configuration of process: custom ESRF data policy
        # see https://gitlab.esrf.fr/sole/data_policy/blob/master/ESRF_NeXusImplementation.rst
        if params is not None or self._obj_zero_phase_mask is not None:
            config = process_1.create_group("configuration")
            config.attrs['NX_class'] = 'NXcollection'
            if params is not None:
                for k, v in params.items():
                    if v is not None:
                        if type(v) is dict:
                            # This can happen if complex configuration is passed on
                            if len(v):
                                kd = config.create_group(k)
                                kd.attrs['NX_class'] = 'NXcollection'
                                for kk, vv in v.items():
                                    kd.create_dataset(kk, data=vv)
                        else:
                            config.create_dataset(k, data=v)

            if self._obj_zero_phase_mask is not None:
                config.create_dataset("obj_zero_phase_mask", data=self._obj_zero_phase_mask, chunks=True, shuffle=True,
                                      compression="gzip")
                config["obj_zero_phase_mask"].attrs['note'] = 'Weighted mask of region restrained to real values'

        # Configuration & results of process: custom ESRF data policy
        # see https://gitlab.esrf.fr/sole/data_policy/blob/master/ESRF_NeXusImplementation.rst
        results = process_1.create_group("results")
        results.attrs['NX_class'] = 'NXcollection'
        results.create_dataset('llk_poisson', data=self.llk_poisson / self.nb_obs)
        results.create_dataset('llk_gaussian', data=self.llk_poisson / self.nb_obs)
        results.create_dataset('llk_euclidian', data=self.llk_poisson / self.nb_obs)
        results.create_dataset('nb_photons_calc', data=self.nb_photons_calc)
        results.create_dataset('cycle_history', data=self.history.as_numpy_record_array())
        for k in self.history.keys():
            results.create_dataset('cycle_history_%s' % k, data=self.history[k].as_numpy_record_array())

        f.close()

    def reset_history(self):
        """
        Start a fresh, empty History record and rewind the cycle counter to zero.

        :return: nothing
        """
        self.history, self.cycle = History(), 0

    def update_history(self, mode='llk', update_obj=False, update_probe=False, verbose=False, **kwargs):
        """
        Append a record to the history at the current cycle.

        :param mode: 'llk' records the log-likelihoods (each divided by self.nb_obs), the calculated
                     number of photons, the number of object and probe modes, and all kwargs.
                     Any other mode string containing 'algo' (e.g. 'algorithm') only records
                     kwargs['algorithm'], and only if it was given. Other modes record nothing.
        :param update_obj: True if the object is being updated - only used for the printed algorithm tag
        :param update_probe: True if the probe is being updated - only used for the printed algorithm tag
        :param verbose: if True and mode == 'llk', print a one-line summary of the current cycle
        :param kwargs: extra values stored in the record, e.g. probe_inertia=xx, dt=xx, algorithm='DM'.
                       'algorithm' and 'dt' are also used in the printed summary.
        :return: nothing
        """
        if mode != 'llk':
            # Algorithm-only record: no log-likelihood values are stored
            if 'algo' in mode and 'algorithm' in kwargs:
                self.history.insert(self.cycle, algorithm=kwargs['algorithm'])
            return

        # Normalise the log-likelihoods by the number of observed points once, for print and record
        norm = self.nb_obs
        llk_p = self.llk_poisson / norm
        llk_g = self.llk_gaussian / norm
        llk_e = self.llk_euclidian / norm

        if verbose:
            tag = algo_string(kwargs.get('algorithm', ''), self, update_obj, update_probe, False, False)
            print("%-10s #%3d LLK= %8.2f(p) %8.2f(g) %8.2f(e), nb photons=%e, dt/cycle=%5.3fs"
                  % (tag, self.cycle, llk_p, llk_g, llk_e, self.nb_photons_calc, kwargs.get('dt', 0)))

        self.history.insert(self.cycle, llk_poisson=llk_p, llk_gaussian=llk_g, llk_euclidian=llk_e,
                            nb_photons_calc=self.nb_photons_calc, nb_obj=len(self._obj),
                            nb_probe=len(self._probe), **kwargs)


def save_ptycho_data_cxi(file_name, iobs, pixel_size, wavelength, detector_distance, x, y, z=None, monitor=None,
                         mask=None, instrument="", overwrite=False, scan=None, params=None, verbose=False, **kwargs):
    """
    Save the Ptychography scan data using the CXI format (see http://cxidb.org)

    :param file_name: the file name (including relative or full path) to save the data to.
                      Missing parent directories are created.
    :param iobs: the observed intensity, with shape (nb_frame, ny, nx)
    :param pixel_size: the detector pixel size in meters
    :param wavelength: the experiment wavelength in meters
    :param detector_distance: the sample-detector distance in meters
    :param x: the x scan positions
    :param y: the y scan positions
    :param z: the z scan positions (default=None, in which case z is stored as 0)
    :param monitor: the monitor values, stored as entry_1/data_1/monitor (default=None: not stored)
    :param mask: the mask for the observed frames (only stored if it has at least one non-zero value)
    :param instrument: a string with the name of the instrument (e.g. 'ESRF id16A')
    :param overwrite: if True, will overwrite an existing file
    :param scan: the scan number, stored as 'scan' in the process configuration. If given, it
                 replaces any params['scan'] value.
    :param params: a dictionary of parameters which will be saved as a NXcollection
    :param verbose: currently unused (the creation message is always printed)
    :param kwargs: currently ignored
    :return: nothing
    """
    import subprocess

    path = os.path.split(file_name)[0]
    if len(path):
        os.makedirs(path, exist_ok=True)
    if os.path.isfile(file_name) and overwrite is False:
        print("CXI file already exists, not overwriting: ", file_name)
        try:
            # Informative listing only. An argument list (no shell) avoids breaking on, or
            # interpreting, spaces and shell metacharacters in the file name.
            subprocess.run(['ls', '-la', file_name], check=False)
        except OSError:
            # 'ls' may be unavailable (e.g. on Windows): the listing is best-effort
            pass
        return
    print('Creating CXI file: %s' % file_name)

    # Context manager: the HDF5 file is closed even if writing one of the datasets fails
    with h5py.File(file_name, "w") as f:
        f.attrs['file_name'] = file_name
        f.attrs['file_time'] = time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime(time.time()))
        if instrument is not None:
            f.attrs['instrument'] = instrument
        f.attrs['creator'] = 'PyNX'
        # f.attrs['NeXus_version'] = '2018.5'  # Should only be used when the NeXus API has written the file
        f.attrs['HDF5_Version'] = h5py.version.hdf5_version
        f.attrs['h5py_version'] = h5py.version.version
        f.attrs['default'] = 'entry_1'
        f.create_dataset("cxi_version", data=140)

        entry_1 = f.create_group("entry_1")
        entry_1.create_dataset("program_name", data="PyNX %s" % (__version__))
        entry_1.create_dataset("start_time", data=time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime(time.time())))
        entry_1.attrs['NX_class'] = 'NXentry'
        entry_1.attrs['default'] = 'data_1'

        sample_1 = entry_1.create_group("sample_1")
        sample_1.attrs['NX_class'] = 'NXsample'

        geometry_1 = sample_1.create_group("geometry_1")
        # The class is set on the geometry group - the sample group itself stays NXsample
        geometry_1.attrs['NX_class'] = 'NXgeometry'  # Deprecated NeXus class, move to NXtransformations
        # Positions are stored as a (3, nb_frame) array with rows x, y, z (z=0 when not supplied)
        x = np.ravel(x)
        xyz = np.zeros((3, x.size), dtype=np.float32)
        xyz[0] = x
        xyz[1] = np.ravel(y)
        if z is not None:
            xyz[2] = np.ravel(z)
        geometry_1.create_dataset("translation", data=xyz)

        data_1 = entry_1.create_group("data_1")
        data_1.attrs['NX_class'] = 'NXdata'
        data_1.attrs['signal'] = 'data'
        data_1.attrs['interpretation'] = 'image'
        data_1["translation"] = h5py.SoftLink('/entry_1/sample_1/geometry_1/translation')
        if monitor is not None:
            # 'data=' is required: the second positional argument of create_dataset is the shape
            data_1.create_dataset("monitor", data=monitor)

        instrument_1 = entry_1.create_group("instrument_1")
        instrument_1.attrs['NX_class'] = 'NXinstrument'
        if instrument is not None:
            instrument_1.create_dataset("name", data=instrument)

        source_1 = instrument_1.create_group("source_1")
        source_1.attrs['NX_class'] = 'NXsource'
        # Photon energy in J: E[eV] = hc / lambda with hc = 12398.4 eV.Angstrom, then 1 eV = 1.60218e-19 J
        nrj = 12398.4e-10 / wavelength * 1.60218e-19
        source_1.create_dataset("energy", data=nrj)  # in J
        source_1["energy"].attrs['units'] = 'J'
        source_1["energy"].attrs['note'] = 'Incident photon energy (instead of source energy), for CXI compatibility'

        detector_1 = instrument_1.create_group("detector_1")
        detector_1.attrs['NX_class'] = 'NXdetector'

        nz, ny, nx = iobs.shape
        # One chunk per frame, so that single frames can be read efficiently
        detector_1.create_dataset("data", data=iobs, chunks=(1, ny, nx), shuffle=True,
                                  compression="gzip")
        detector_1.create_dataset("distance", data=detector_distance)
        detector_1["distance"].attrs['units'] = 'm'
        detector_1.create_dataset("x_pixel_size", data=pixel_size)
        detector_1["x_pixel_size"].attrs['units'] = 'm'
        detector_1.create_dataset("y_pixel_size", data=pixel_size)
        detector_1["y_pixel_size"].attrs['units'] = 'm'
        if mask is not None:
            if mask.sum() != 0:
                detector_1.create_dataset("mask", data=mask, chunks=True, shuffle=True, compression="gzip")
                detector_1["mask"].attrs['note'] = "Mask of invalid pixels, applying to each frame"
        # Basis vector - this is the default CXI convention, so could be skipped
        # This corresponds to a 'top, left' origin convention
        basis_vectors = np.zeros((2, 3), dtype=np.float32)
        basis_vectors[0, 1] = -pixel_size
        basis_vectors[1, 0] = -pixel_size
        detector_1.create_dataset("basis_vectors", data=basis_vectors)

        detector_1["translation"] = h5py.SoftLink('/entry_1/sample_1/geometry_1/translation')
        data_1["data"] = h5py.SoftLink('/entry_1/instrument_1/detector_1/data')

        # Remember how import was done
        command = " ".join(sys.argv) + " "
        process_1 = data_1.create_group("process_1")
        process_1.attrs['NX_class'] = 'NXprocess'
        process_1.create_dataset("program", data='PyNX')  # NeXus spec
        process_1.create_dataset("version", data="%s" % __version__)  # NeXus spec
        process_1.create_dataset("command", data=command)  # CXI spec
        config = process_1.create_group("configuration")
        config.attrs['NX_class'] = 'NXcollection'
        if params is not None:
            for k, v in params.items():
                # An explicit 'scan' argument takes precedence over params['scan']
                if k == 'scan' and scan is not None:
                    continue
                if v is None:
                    continue
                if type(v) is dict:
                    # This can happen if complex configuration is passed on
                    if len(v):
                        kd = config.create_group(k)
                        kd.attrs['NX_class'] = 'NXcollection'
                        for kk, vv in v.items():
                            kd.create_dataset(kk, data=vv)
                else:
                    config.create_dataset(k, data=v)
        if scan is not None:
            config.create_dataset('scan', data=scan)


class OperatorPtycho(Operator):
    """
    Base class for operators acting on Ptycho2D objects.
    """

    def timestamp_increment(self, p):
        """
        Advance the ptycho object's timestamp counter and record that the CPU copy is the latest one.

        This default suits CPU operators. Operators which do not affect the ptycho object,
        such as display operators, are expected not to increment the counter.

        :param p: the ptycho object the operator was applied to
        :return: nothing
        """
        new_count = p._timestamp_counter + 1
        p._timestamp_counter = new_count
        p._cpu_timestamp_counter = new_count


def algo_string(algo_base, p, update_object, update_probe, update_background=False, update_pos=False):
    """
    Build a short tag describing the algorithm being run, e.g. 'DM/o/3p' for difference map
    updating 1 object mode and 3 probe modes.

    :param algo_base: the base algorithm name, e.g. 'AP', 'ML' or 'DM'
    :param p: the ptycho object; len(p._obj) and len(p._probe) give the number of modes
    :param update_object: True if updating the object: appends '/o', or '/No' when there are N>1 object modes
    :param update_probe: True if updating the probe: appends '/p', or '/Np' when there are N>1 probe modes
    :param update_background: True if updating the background: appends '/b'
    :param update_pos: accepted for call compatibility; it does not currently change the string
    :return: a short string for the algorithm
    """

    def _mode_tag(nb_modes, letter):
        # The mode count is only shown when there is more than one mode
        count = "%d" % nb_modes if nb_modes > 1 else ""
        return "/" + count + letter

    tag = algo_base
    if update_object:
        tag += _mode_tag(len(p._obj), "o")
    if update_probe:
        tag += _mode_tag(len(p._probe), "p")
    if update_background:
        tag += "/b"
    return tag
