# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

# Public names exported by 'from <this module> import *'
__all__ = ['Bragg2DPtychoData', 'Bragg2DPtycho', 'OperatorBragg2DPtycho']

import numpy as np
from scipy.interpolate import RegularGridInterpolator
from ...operator import Operator, has_attr_not_none


# TODO: Merge as much as possible with the 3D Bragg case (almost identical)

def rotation_matrix(axis, angle):
    """
    Creates a numpy rotation matrix. The convention is the NeXus one (right-handed rotations) so that a
    positive rotation of +pi/2:
    - with axis='x', transforms +y into z
    - with axis='y', transforms +z into x
    - with axis='z', transforms +x into y
    :param axis: the rotation axis, either 'x', 'y' or 'z' (case-insensitive)
    :param angle: the rotation angle in radians
    :return: the rotation matrix, as a 3x3 float32 numpy.matrix (so that '*' between two such matrices is a
             matrix product)
    :raises ValueError: if the axis is not one of 'x', 'y' or 'z'
    """
    c, s = np.cos(angle), np.sin(angle)
    ax = axis.lower()
    if ax == 'x':
        # Right-handed rotation around x: +y -> +z for angle=+pi/2 (consistent with the 'y' and 'z' cases)
        return np.matrix([[1, 0, 0], [0, c, -s], [0, s, c]], dtype=np.float32)
    if ax == 'y':
        return np.matrix([[c, 0, s], [0, 1, 0], [-s, 0, c]], dtype=np.float32)
    if ax == 'z':
        return np.matrix([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)
    raise ValueError("rotation_matrix(): unknown rotation axis '%s', should be 'x', 'y' or 'z'" % axis)


def rotate(m, x, y, z):
    """
    Apply a 3x3 (rotation) matrix to x, y, z coordinates, i.e. compute m . (x, y, z).

    :param m: the matrix, indexable as m[row, column] (numpy array or matrix)
    :param x: x coordinate(s), scalar or numpy array
    :param y: y coordinate(s), scalar or numpy array
    :param z: z coordinate(s), scalar or numpy array
    :return: a tuple of x, y, z coordinates after rotation
    """
    # Each output coordinate is the dot product of one matrix row with (x, y, z)
    return tuple(m[row, 0] * x + m[row, 1] * y + m[row, 2] * z for row in range(3))


class Bragg2DPtychoData(object):
    """Class for two-dimensional ptychographic data: observed diffraction and probe positions.
    This may include only part of the data from a larger dataset.
    """

    def __init__(self, iobs=None, positions=None, mask=None, wavelength=None, detector=None, scattering_vector=None):
        """
        Init function. The orthonormal coordinate system to be used for the sample position follows the
        NeXus/CXI/McStas convention:
        - z parallel to the incident beam, downstream
        - y: perpendicular to the direct beam and up
        - x: perpendicular to the direct beam and towards the left when seen from the source.

        :param iobs: 3d array with (nb_frame, ny, nx) shape observed intensity (assumed to follow Poisson
                     statistics). Data is assumed to be centered on the detector, and will be fft-shifted to be
                     centered in (0,0).
        :param positions: (z, y, x) tuple or 2d array with ptycho probe positions in meters. The coordinates
                          must be in the laboratory reference frame given above. The positions are centered
                          (their mean is subtracted) and stored as new arrays: the caller's arrays are not
                          modified.
        :param mask: 2D mask (>0 means masked pixel) for the observed data. Can be None. Masked pixels are
                     stored in self.iobs with a value of -100.
        :param wavelength: wavelength of the experiment, in meters.
        :param detector= {rotation_axes:(('x', 0), ('y', pi/4)), 'pixel_size':55e-6, 'distance':1}:
               parameters for the detector as a dictionary. The rotation axes giving the detector orientation
               will be applied in order, i.e. to find the detector direction, a vector with x=y=0, z=1 is rotated
               by the axes in the given order.
        :param scattering_vector=(hz, hy, hx): the coordinates of the central scattering vector for all frames,
               taking into account any tilt for multi-angle Bragg projection ptychography. The scattering vector
               coordinates units should be inverse meters (without 2pi multiplication), i.e. the norm should be
               2*sin(theta)/lambda, where theta is the Bragg angle. Note that only the difference between the
               scattering vector and the average (intensity-weighted) scattering vector is used for
               multi-orientation back-projection.
               [Default: None, all frames are taken with the same orientation]
        """
        # TODO: alternatively, instead of supplying one detector geometry and scattering vectors which can vary from
        # frame to frame, rather allow some geometry parameters to be vectors to be more general ? This
        # would require including detector crop coordinates and rotation axes...

        if scattering_vector is not None:
            hz, hy, hx = (np.asarray(h) for h in scattering_vector)
            hx0 = hx.mean()
            hy0 = hy.mean()
            hz0 = hz.mean()
            if iobs is not None:
                # Intensity-weighted average scattering vector
                iobs_sum = iobs.sum(axis=(1, 2))
                s = iobs_sum.sum()
                if s > 1000:  # Avoid the special case with null data for simulation
                    hx0 = (hx * iobs_sum).sum() / s
                    hy0 = (hy * iobs_sum).sum() / s
                    hz0 = (hz * iobs_sum).sum() / s
            self.s0 = (hx0, hy0, hz0)  # Average scattering vector
            self.ds = (hx - hx0, hy - hy0, hz - hz0)  # Difference with average scattering vector
        else:
            self.s0 = (0, 0, 0)
            # Null scattering vector difference for every frame. Without observed data, the number of
            # frames can only be deduced from the positions.
            nb_frame = len(iobs) if iobs is not None else len(positions[0])
            tmp = np.zeros(nb_frame, dtype=np.float32)
            self.ds = tmp.copy(), tmp.copy(), tmp.copy()

        if iobs is not None:
            self.iobs = np.fft.fftshift(iobs, axes=(-2, -1)).astype(np.float32)
            # Total nb of photons is used for regularization
            self.iobs_sum = self.iobs.sum()
            if mask is not None:
                # tile mask to be a 3D array with (nb_frame, ny, nx) shape
                mask3D = np.repeat(mask[np.newaxis, :, :], iobs.shape[0], axis=0)
                self.mask = np.fft.fftshift(mask3D.astype(np.int8), axes=(-2, -1))
                self.iobs[self.mask > 0] = -100
            else:
                self.mask = None
        else:
            self.iobs = None
            # Always define the mask attribute, it is read by Bragg2DPtycho
            self.mask = None

        self.wavelength = wavelength
        posz, posy, posx = (np.asarray(p) for p in positions)
        # Center the positions, creating new arrays: an in-place subtraction would modify the caller's arrays
        # (and fail for integer arrays).
        self.posx = posx - posx.mean()
        self.posy = posy - posy.mean()
        self.posz = posz - posz.mean()
        self.detector = detector
        # Detector orientation: product of the successive rotations (np.matrix, so '*=' is a matrix product)
        self.m = np.matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
        for ax, angle in detector['rotation_axes']:
            self.m *= rotation_matrix(ax, angle)
        self.im = np.linalg.inv(self.m)

        # Calculate ds coordinate in detector reference frame
        self.ds1 = rotate(self.im, self.ds[0], self.ds[1], self.ds[2])


class Bragg2DPtycho(object):
    """ Class for 2D Bragg ptychography data: object, probe, and observed diffraction.
    This may include only part of the data from a larger dataset
    """

    def __init__(self, probe=None, data=None, support=None, background=None):
        """
        Constructor.
        :param probe: the starting estimate of the probe, as a pynx wavefront object - can be 3D if modes are used.
        :param data: the Bragg2DPtychoData object with all observed frames, ptycho positions
        :param support: the support of the object (1 inside, 0 outside) the object will be constrained to
        :param background: the incoherent background array, or None
        """
        self._probe2d = probe  # The original 2D probe as a Wavefront object
        self._probe = None  # This will hold the probe as projected onto the 3D object
        self._obj = None
        self.support = support
        self.data = data
        self._background = background

        # Matrix transformation from array indices of the array obtained by inverse Fourier Transform
        # to xyz in the laboratory frame. Different from self.data.m, which is a pure rotation matrix
        self.m = None
        # Inverse of self.m
        self.im = None

        # Voxel size in object and probe space
        self.px1 = None
        self.py1 = None
        self.pz1 = None

        self.llk_poisson = 0
        self.llk_gaussian = 0
        self.llk_euclidian = 0
        self.nb_photons_calc = 0
        # Number of observed (non-masked) intensities
        self.nb_obs = self.data.iobs.size
        if self.data.mask is not None:
            self.nb_obs *= (self.data.mask == 0).sum() / float(self.data.mask.size)

        # The timestamp counters record when the data was last altered, either in the host or the GPU memory.
        self._timestamp_counter = 1
        self._cpu_timestamp_counter = 1
        self._cl_timestamp_counter = 0
        self._cu_timestamp_counter = 0
        self.prepare()

    def from_pu(self):
        """
        Get all relevant arrays from processing unit (OpenCL or CUDA), if the host copy is outdated.
        :return: Nothing
        """
        if self._cpu_timestamp_counter < self._timestamp_counter:
            if self._timestamp_counter == self._cl_timestamp_counter:
                if has_attr_not_none(self, '_cl_obj'):
                    self._obj = self._cl_obj.get()
                if has_attr_not_none(self, '_cl_probe'):
                    self._probe = self._cl_probe.get()
                if self._background is not None:
                    if has_attr_not_none(self, '_cl_background'):
                        self._background = self._cl_background.get()
            if self._timestamp_counter == self._cu_timestamp_counter:
                if has_attr_not_none(self, '_cu_obj'):
                    self._obj = self._cu_obj.get()
                if has_attr_not_none(self, '_cu_probe'):
                    self._probe = self._cu_probe.get()
                if self._background is not None:
                    if has_attr_not_none(self, '_cu_background'):
                        self._background = self._cu_background.get()
            self._cpu_timestamp_counter = self._timestamp_counter

    def get_probe(self):
        """
        Get the probe data array (computed on the object grid). This will automatically get the latest data,
        either from GPU or from the host memory, depending where the last changes were made.

        :return: the 4D numpy probe array (nb probe modes, nz, ny, nx), as created by init_probe()
        """
        self.from_pu()
        return self._probe

    def get_obj(self):
        """
        Get the object data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :return: the 4D numpy object array (nb object modes, nzo, nyo, nxo)
        """
        self.from_pu()
        return self._obj

    def set_obj(self, obj):
        """
        Set the object data array. This should either be a 3D array of the correct shape, or a 4D array where
        the first dimension are the object modes.

        :param obj: the object (complex64 numpy array)
        :return: nothing
        """
        if obj.ndim == 3:
            nz, ny, nx = obj.shape
            self._obj = obj.reshape((1, nz, ny, nx)).astype(np.complex64)
        else:
            self._obj = obj.astype(np.complex64)
        self._timestamp_counter += 1
        self._cpu_timestamp_counter = self._timestamp_counter

        # Test if object and support have the same dimensions
        if self.support is not None:
            assert self.support.shape == self._obj.shape[-3:]

    def set_support(self, sup, shrink_object_around_support=True):
        """
        Set the support data array. This should be a 3D array of the correct shape.

        :param sup: the support array. 0 outside support, 1 inside. If None, nothing is changed.
        :param shrink_object_around_support: if True, will shrink the object and support array around the support
                                             volume. Note that the xyz coordinate may be shifted as a result, if
                                             the support was not centered.
        :return: nothing
        """
        if sup is not None:
            self.support = sup.astype(np.int8)
            self.support_sum = self.support.sum()
            self._timestamp_counter += 1
            self._cpu_timestamp_counter = self._timestamp_counter
            if shrink_object_around_support:
                self.shrink_object_around_support()

    def shrink_object_around_support(self):
        """
        Shrink the support (and the object, if it exists) around the tight support bounding box, to minimise
        the 3D object & support volume.
        :return: nothing
        """
        # Bounding box of the non-zero support voxels along each axis
        z0, z1 = np.nonzero(self.support.sum(axis=(1, 2)))[0].take([0, -1])
        y0, y1 = np.nonzero(self.support.sum(axis=(0, 2)))[0].take([0, -1])
        x0, x1 = np.nonzero(self.support.sum(axis=(0, 1)))[0].take([0, -1])
        self.set_support(self.support[z0:z1 + 1, y0:y1 + 1, x0:x1 + 1], shrink_object_around_support=False)
        if self._obj is not None:
            self.set_obj(self._obj[:, z0:z1 + 1, y0:y1 + 1, x0:x1 + 1])

    def set_background(self, background):
        """
        Set the incoherent background data array.

        :param background: the background (float32 numpy array)
        :return: nothing
        """
        self._background = background.astype(np.float32)
        self._timestamp_counter += 1
        self._cpu_timestamp_counter = self._timestamp_counter

    def get_background(self, shift=False):
        """
        Get the background data array. This will automatically get the latest data, either from GPU or from the host
        memory, depending where the last changes were made.

        :param shift: currently ignored - the array is returned as stored.
        :return: the 2D numpy data array
        """
        self.from_pu()
        return self._background

    def prepare(self):
        """
        Calculate projection parameters: orthonormalisation matrix, then the probe and object on the object grid,
        and finally apply the support.
        :return: nothing
        """
        self.calc_orthonormalisation_matrix()
        self.init_probe()
        self.init_obj()
        self.set_support(self.support)

    def calc_orthonormalisation_matrix(self):
        """
        Calculate the orthonormalisation matrix to convert probe/object array coordinates (pixel coordinates in the
        detector reference frame) to/from orthonormal ones in the laboratory reference frame.
        This also initialises the voxel sizes in object and probe space
        :return: nothing
        """
        npos, ny, nx = self.data.iobs.shape
        # Pixel size in object space, from the detector pixel size and the distance
        d = self.data.detector['distance']
        lambdaz = self.data.wavelength * d
        p = self.data.detector['pixel_size']
        py, px = lambdaz / (p * ny), lambdaz / (p * nx)
        # TODO: Find method to estimate optimal pixel size along z. Would depend on probe shape and multiple angles..
        pz = max(px, py)
        self.m = self.data.m.copy()
        self.m[:, 0] *= -px  # - sign because image pixels origin is 'top, left' seen from sample
        self.m[:, 1] *= -py  # - sign because image pixels origin is 'top, left' seen from sample
        self.m[:, 2] *= pz
        self.im = np.linalg.inv(self.m)
        self.px1 = px
        self.py1 = py
        self.pz1 = pz

    def init_obj(self):
        """
        Initialize the object array
        :return: nothing. The object is created as an empty (uninitialised) array
        """
        nzo, nyo, nxo = self.calc_obj_shape()
        print("Initialised object with %dx%dx%d voxels" % (nzo, nyo, nxo))
        self.set_obj(np.empty((1, nzo, nyo, nxo), dtype=np.complex64))

    def calc_probe_shape(self):
        """
        Calculate the probe shape, given the 2D probe and detector characteristics.
        The number of points along z (the back-projected detector axis) is calculated using the intersection
        of the back-projected 2D detector frame with the extent of the 2D probe in the laboratory frame.
        :return: the 3D probe shape (nz, ny, nx). nz is returned as computed, without rounding.
        """
        ny, nx = self.data.iobs.shape[-2:]
        pixel_size_probe = self._probe2d.pixel_size
        # Half-extent of the probe along x and y in the laboratory reference frame
        nyp, nxp = self._probe2d.get().shape[-2:]
        dyp, dxp = nyp // 2 * pixel_size_probe, nxp // 2 * pixel_size_probe
        # How far must we extend the back-propagated 2D wavefront along the sample-detector axis to go beyond
        # the probe path ? Use the corners of the projected wavefront to test
        nx2 = nx / 2
        ny2 = ny / 2
        x, y, z = self.xyz_from_obj(np.array([-nx2, -nx2, nx2, nx2]), np.array([-ny2, ny2, -ny2, ny2]), 0)
        # Detector axis direction (third column of the detector rotation matrix)
        m02, m12, m22 = self.data.m[0, 2], self.data.m[1, 2], self.data.m[2, 2]
        # Avoid divisions by zero when the detector axis is (nearly) perpendicular to x or y
        d20 = m02
        if abs(d20) < 1e-10:
            d20 = 1e-10
        d21 = m12
        if abs(d21) < 1e-10:
            d21 = 1e-10

        def iz_on_border(i, coord, border, d):
            # Move corner i along the detector axis until `coord` reaches `border`, and return the
            # corresponding z index in the object array
            alpha = (border - coord[i]) / d
            return self.xyz_to_obj(x[i] + alpha * m02, y[i] + alpha * m12, z[i] + alpha * m22)[2]

        if d20 > 0:
            # corner with lowest x reaches the upper x probe border, corner with highest x the lower one
            izmax_x = iz_on_border(x.argmin(), x, dxp, d20)
            izmin_x = iz_on_border(x.argmax(), x, -dxp, d20)
        else:
            # corner with highest x reaches the lower x probe border, corner with lowest x the upper one
            izmax_x = iz_on_border(x.argmax(), x, -dxp, d20)
            izmin_x = iz_on_border(x.argmin(), x, dxp, d20)

        # Same along y
        if d21 > 0:
            izmax_y = iz_on_border(y.argmin(), y, dyp, d21)
            izmin_y = iz_on_border(y.argmax(), y, -dyp, d21)
        else:
            izmax_y = iz_on_border(y.argmax(), y, -dyp, d21)
            izmin_y = iz_on_border(y.argmin(), y, dyp, d21)

        # Final interval is smallest between x and y
        izmax = min(izmax_x, izmax_y)
        izmin = max(izmin_x, izmin_y)
        nz = izmax - izmin
        print("Calculated probe shape: ", nz, ny, nx)
        return nz, ny, nx

    def calc_obj_shape(self, margin=8, multiple=2):
        """
        Calculate the 3D object shape, given the detector, probe and scan characteristics.
        This must be called after the 3D probe has been initialized. Note that the final object shape will be shrunk
        around the support once it is given.
        :param margin: margin to extend the object area, in case the positions will change (optimization)
        :param multiple: the shape must be a multiple of that number. >=2
        :return: the 3D object shape (nzo, nyo, nxo)
        """
        probe_shape = self._probe.shape[1:]
        ix, iy, iz = self.xyz_to_obj(self.data.posx, self.data.posy, self.data.posz)

        nz = int(2 * (abs(np.ceil(iz)) + 1).max() + probe_shape[0])
        ny = int(2 * (abs(np.ceil(iy)) + 1).max() + probe_shape[1])
        nx = int(2 * (abs(np.ceil(ix)) + 1).max() + probe_shape[2])

        if margin is not None:
            nz += margin
            ny += margin
            nx += margin

        if multiple is not None:
            # Round each dimension up to the next multiple
            nz, ny, nx = (n + (-n) % multiple for n in (nz, ny, nx))

        print("Calculated object shape: ", nz, ny, nx)
        return nz, ny, nx

    def xyz_to_obj(self, x, y, z):
        """
        Convert x,y,z coordinates from the laboratory reference frame to indices in the object array.
        :param x, y, z: laboratory frame coordinates in meters
        :return: (ix, iy, iz) coordinates in the array in the back-projected detector frame
        """
        ix = self.im[0, 0] * x + self.im[0, 1] * y + self.im[0, 2] * z
        iy = self.im[1, 0] * x + self.im[1, 1] * y + self.im[1, 2] * z
        iz = self.im[2, 0] * x + self.im[2, 1] * y + self.im[2, 2] * z
        return ix, iy, iz

    def xyz_from_obj(self, ix, iy, iz):
        """
        Convert x,y,z coordinates to the laboratory reference frame from indices in the 3D object array.
        :param ix, iy, iz: coordinates in the 3D object array.
        :return: (x, y, z) laboratory frame coordinates in meters
        """
        x = self.m[0, 0] * ix + self.m[0, 1] * iy + self.m[0, 2] * iz
        y = self.m[1, 0] * ix + self.m[1, 1] * iy + self.m[1, 2] * iz
        z = self.m[2, 0] * ix + self.m[2, 1] * iy + self.m[2, 2] * iz
        return x, y, z

    def get_xyz(self, rotation=None, domain='object'):
        """
        Get x,y,z orthonormal coordinates corresponding to the object grid.
        :param domain: the domain over which the xyz coordinates should be returned. It should either
                       be 'object' (the default, 'obj' is also accepted) or 'probe', the only difference being
                       that the object is extended to cover all the volume scanned by the shifted probe.
        :param rotation=('z',np.deg2rad(-20)): optionally, the coordinates can be obtained after a rotation of the
                                               object. This is useful if the object or support is to be defined as a
                                               parallelepiped, before being rotated to be in diffraction condition.
                                               The rotation can be given as a tuple of a rotation axis name (x, y or z)
                                               and a counter-clockwise rotation angle in radians.
        :return: a tuple of (x,y,z) coordinates, each a 3D array, centered on 0.
        :raises Exception: if the domain or the rotation axis is unknown
        """
        if domain == "probe":
            # TODO: once probe has been calculated, use its shape rather than recalculate it ?
            nz, ny, nx = self.calc_probe_shape()
        elif domain == 'object' or domain == 'obj':
            nz, ny, nx = self._obj.shape[1:]
        else:
            raise Exception("BraggPtycho.get_xyz(): unknown domain '%s', should be 'object' or 'probe'" % domain)

        iz, iy, ix = np.meshgrid(np.arange(nz), np.arange(ny), np.arange(nx), indexing='ij')
        x, y, z = self.xyz_from_obj(ix, iy, iz)

        if rotation is not None:
            # TODO: allow multiple axis
            ax, ang = rotation
            c, s = np.cos(ang), np.sin(ang)
            if ax == 'x':
                y, z = c * y - s * z, c * z + s * y
            elif ax == 'y':
                z, x = c * z - s * x, c * x + s * z
            elif ax == 'z':
                x, y = c * x - s * y, c * y + s * x
            else:
                raise Exception("BraggPtycho.get_xyz_obj(): unknown rotation axis '%s'" % ax)

        # Assume the probe is centered on the object grid
        x -= x.mean()
        y -= y.mean()
        z -= z.mean()

        return x, y, z

    def voxel_size_object(self):
        """
        Get the object voxel size, i.e. the lengths of the lab-frame vectors corresponding to a one-index step
        along each array axis.
        :return: the voxel size in meters as (pz, py, px)
        """
        tmp = self.xyz_from_obj(1, 0, 0)
        px = np.sqrt(tmp[0] ** 2 + tmp[1] ** 2 + tmp[2] ** 2)
        tmp = self.xyz_from_obj(0, 1, 0)
        py = np.sqrt(tmp[0] ** 2 + tmp[1] ** 2 + tmp[2] ** 2)
        tmp = self.xyz_from_obj(0, 0, 1)
        pz = np.sqrt(tmp[0] ** 2 + tmp[1] ** 2 + tmp[2] ** 2)
        return pz, py, px

    def init_probe(self):
        """
        Calculate the probe over the object volume, given the 2D probe, the data and the object coordinates,
        assuming that the probe is invariant along z.
        :return: Nothing. Creates self._probe, with shape (1, nz, ny, nx)
        :raises Exception: if the probe and data wavelengths differ
        """
        # TODO: move this to an operator and kernel for faster conversion ?
        print("Calculating probe on object grid")
        x, y, z = self.get_xyz(domain='probe')
        nz, ny, nx = x.shape
        z0, z1 = z.min(), z.max()

        # Create 3D probe array in the laboratory frame
        dz = z1 - z0
        pr2d = self._probe2d.get(shift=True)
        # TODO: take into account all probe modes
        if pr2d.ndim == 3:
            pr2d = pr2d[0]
        nyp, nxp = pr2d.shape
        # np.isclose returns a numpy bool, which is never the 'False' singleton: use 'not' rather than 'is False'
        if not np.isclose(self.data.wavelength, self._probe2d.wavelength):
            raise Exception('BraggPtycho: probe and data wavelength are different !')
        pixel_size_probe = self._probe2d.pixel_size
        # The 2D probe is repeated along z (invariant probe along the beam)
        pr = np.empty((nz, nyp, nxp), dtype=np.complex64)
        pr[:] = pr2d

        # Original probe coordinates
        zp, yp, xp = np.arange(nz), np.arange(nyp), np.arange(nxp)
        zp = (zp - zp.mean()) * (dz / nz)
        yp = (yp - yp.mean()) * pixel_size_probe
        xp = (xp - xp.mean()) * pixel_size_probe

        # Interpolate probe to object grid
        rgi = RegularGridInterpolator((zp, yp, xp), pr, method='linear', bounds_error=False, fill_value=0)
        self._probe = rgi(np.concatenate((z.reshape(1, z.size), y.reshape(1, y.size),
                                          x.reshape(1, x.size))).transpose()).reshape((1, nz, ny, nx)).astype(
            np.complex64)


class OperatorBragg2DPtycho(Operator):
    """
    Base class for an operator on Bragg2DPtycho objects.
    """

    def timestamp_increment(self, p):
        """
        Record that the host (CPU) copy of the data has been modified by this operator.

        By default CPU operators increment the global timestamp counter and mark the CPU copy as the latest one.
        Operators which do not affect the Bragg2DPtycho object (like display operators) are expected to
        override this.
        :param p: the Bragg2DPtycho object the operator was applied to
        :return: nothing
        """
        p._timestamp_counter += 1
        p._cpu_timestamp_counter = p._timestamp_counter
