# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import types
import timeit
import gc
import numpy as np
import pycuda.driver as cu_drv
import pycuda.gpuarray as cua
from pycuda.elementwise import ElementwiseKernel as CU_ElK
from pycuda.reduction import ReductionKernel as CU_RedK
from pycuda.compiler import SourceModule
import skcuda.fft as cu_fft
import pycuda.tools as cu_tools

from ..processing_unit import default_processing_unit as main_default_processing_unit
from ..processing_unit.cu_processing_unit import CUProcessingUnit
from ..processing_unit.kernel_source import get_kernel_source as getks
from ..operator import has_attr_not_none, OperatorSum, OperatorPower, OperatorException
from .ptycho import Ptycho, OperatorPtycho, algo_string
from . import cpu_operator as cpuop
from .shape import get_view_coord

# Structured dtype made of four float32 fields (a, b, c, d), registered with pycuda under the
# C type name "my_float4". It is the output dtype of several reduction kernels in this module.
my_float4 = cu_tools.get_or_register_dtype(
    "my_float4", np.dtype([(field, '<f4') for field in ('a', 'b', 'c', 'd')]))


################################################################################################
# Patch the Ptycho class so that we can use 5*w to scale it.
# This is admittedly ugly: there will definitely be issues if several types of operators
# are imported (e.g. OpenCL and CUDA).
# Possible solution: in a different sub-module, implement dynamic type-checking to decide which
# Scale() operator to call.

def patch_method(cls):
    """
    Add scalar multiplication to a class, so that both ``5 * w`` and ``w * 5`` apply the
    ``Scale`` operator to ``w``.

    :param cls: the class to patch (e.g. Ptycho). Its ``__rmul__`` and ``__mul__`` are replaced.
    :return: nothing
    """

    def __rmul__(self, x):
        # x * self: x must be a scalar
        if not np.isscalar(x):
            raise OperatorException("ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s" % (str(x), str(self)))
        return Scale(x) * self

    def __mul__(self, x):
        # self * x: x must be a scalar. Here self is Op1 and x is Op2.
        if not np.isscalar(x):
            raise OperatorException("ERROR: attempted Op1 * Op2, with Op1=%s, Op2=%s" % (str(self), str(x)))
        # Scaling is commutative, so apply the operator to the object. Returning
        # `self * Scale(x)` would re-enter this __mul__ with a non-scalar operand and always raise.
        return Scale(x) * self

    cls.__rmul__ = __rmul__
    cls.__mul__ = __mul__


# Install the scalar-multiplication operators defined in patch_method() on the Ptycho class,
# so that multiplying a Ptycho object by a scalar uses this module's Scale operator.
patch_method(Ptycho)


################################################################################################


class CUProcessingUnitPtycho(CUProcessingUnit):
    """
    Processing unit in CUDA space, for 2D Ptycho operations.

    Handles initializing the context and kernels. All kernels are compiled in
    cu_init_kernels() and stored as attributes of this object.
    """

    def __init__(self):
        super(CUProcessingUnitPtycho, self).__init__()
        # Size of the stack size used on the GPU - can be any integer, optimal values between 10 to 30
        # Should be chosen smaller for large frame sizes.
        self.cu_stack_size = np.int32(16)
        # Memory pool, created at the end of cu_init_kernels()
        self.cu_mem_pool = None

    def set_stack_size(self, s):
        """
        Change the number of frames which are stacked to perform all operations in parallel.
        If it is smaller than the total number of frames, operators like AP, DM, ML will loop
        over all the stacks.

        :param s: an integer number (default=16)
        :return: nothing
        """
        self.cu_stack_size = np.int32(s)

    def cu_init_kernels(self):
        """
        Initialize CUDA kernels: elementwise kernels, reduction kernels, the Gaussian convolution
        kernels for block sizes 16, 32 and 64, and finally the device memory pool.

        All kernels are compiled with self.cu_options.

        :return: nothing
        """
        # TODO: delay initialization, on-demand for each type of operator ?

        # Elementwise kernels
        # d *= scale (real scale factor)
        self.cu_scale = CU_ElK(name='cu_scale',
                               operation="d[i] = complexf(d[i].real() * scale, d[i].imag() * scale )",
                               preamble=getks('cuda/complex.cu'),
                               options=self.cu_options, arguments="pycuda::complex<float> *d, const float scale")

        # dest += src
        self.cu_sum = CU_ElK(name='cu_sum',
                             operation="dest[i] += src[i]",
                             preamble=getks('cuda/complex.cu'),
                             options=self.cu_options,
                             arguments="pycuda::complex<float> *src, pycuda::complex<float> *dest")

        # d *= s (complex scale factor)
        self.cu_scale_complex = CU_ElK(name='cu_scale_complex',
                                       operation="d[i] = complexf(d[i].real() * s.real() - d[i].imag() * s.imag(), d[i].real() * s.imag() + d[i].imag() * s.real())",
                                       preamble=getks('cuda/complex.cu'),
                                       options=self.cu_options,
                                       arguments="pycuda::complex<float> *d, const pycuda::complex<float> s")

        self.cu_quad_phase = CU_ElK(name='cu_quad_phase',
                                    operation="QuadPhase(i, d, f, scale, nx, ny)",
                                    preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/quad_phase_elw.cu'),
                                    options=self.cu_options,
                                    arguments="pycuda::complex<float> *d, const float f, const float scale, const int nx, const int ny")

        # Linear combination with 2 complex arrays and 2 float coefficients: dest = a * dest + b * src
        self.cu_linear_comb_fcfc = CU_ElK(name='cu_linear_comb_fcfc',
                                          operation="dest[i] = complexf(a * dest[i].real() + b * src[i].real(), a * dest[i].imag() + b * src[i].imag())",
                                          options=self.cu_options,
                                          preamble=getks('cuda/complex.cu'),
                                          arguments="const float a, pycuda::complex<float> *dest, const float b, pycuda::complex<float> *src")

        # Linear combination with 2 float arrays and 2 float coefficients: dest = a * dest + b * src
        self.cu_linear_comb_4f = CU_ElK(name='cu_linear_comb_4f',
                                        operation="dest[i] = a * dest[i] + b * src[i]",
                                        options=self.cu_options,
                                        preamble=getks('cuda/complex.cu'),
                                        arguments="const float a, float *dest, const float b, float *src")

        self.cu_projection_amplitude = CU_ElK(name='cu_projection_amplitude',
                                              operation="ProjectionAmplitude(i, iobs, dcalc, background, nbmode, nxy, nxystack, npsi, scale_in, scale_out)",
                                              preamble=getks('cuda/complex.cu') + getks(
                                                  'ptycho/cuda/projection_amplitude_elw.cu'),
                                              options=self.cu_options,
                                              arguments="float *iobs, pycuda::complex<float> *dcalc, float *background, const int nbmode, const int nxy, const int nxystack, const int npsi, const float scale_in, const float scale_out")

        self.cu_projection_amplitude_update_background = CU_ElK(name='cu_projection_amplitude_update_background',
                                                                operation="ProjectionAmplitudeUpdateBackground(i, iobs, dcalc, background, vd, vd2, vz2, vdz2, nbmode, nxy, nxystack, npsi, first_pass, scale_in, scale_out)",
                                                                preamble=getks('cuda/complex.cu') + getks(
                                                                    'ptycho/cuda/projection_amplitude_elw.cu'),
                                                                options=self.cu_options,
                                                                arguments="float *iobs, pycuda::complex<float> *dcalc, float *background, float *vd, float *vd2, float *vz2, float *vdz2, const int nbmode, const int nxy, const int nxystack, const int npsi, const char first_pass, const float scale_in, const float scale_out")

        # eta = max(0.8, vdz2 / vd2), then background = max(0, background + (vd - vz2 / eta) / nframes)
        self.cu_background_update = CU_ElK(name='cu_background_update',
                                           operation="const float eta = fmaxf(0.8f, vdz2[i]/vd2[i]);"
                                                     "background[i] = fmaxf(0.0f, background[i] + (vd[i] - vz2[i] / eta) / nframes);",
                                           options=self.cu_options, preamble=getks('cuda/complex.cu'),
                                           arguments="float* background, float* vd, float* vd2, float* vz2, float* vdz2, const int nframes")

        self.cu_calc2obs = CU_ElK(name='cu_calc2obs',
                                  operation="Calc2Obs(i, iobs, dcalc, nbmode, nxystack)",
                                  preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/calc2obs_elw.cu'),
                                  options=self.cu_options,
                                  arguments="float *iobs, pycuda::complex<float> *dcalc, const int nbmode, const int nxystack")

        self.cu_object_probe_mult = CU_ElK(name='cu_object_probe_mult',
                                           operation="ObjectProbeMultQuadPhase(i, psi, obj, probe, cx, cy, pixel_size, f, npsi, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe)",
                                           preamble=getks('cuda/complex.cu') + getks(
                                               'ptycho/cuda/obj_probe_mult_elw.cu'),
                                           options=self.cu_options,
                                           arguments="pycuda::complex<float>* psi, pycuda::complex<float> *obj, pycuda::complex<float>* probe, int* cx, int* cy, const float pixel_size, const float f, const int npsi, const int stack_size, const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe")

        self.cu_2object_probe_psi_dm1 = CU_ElK(name='cu_2object_probe_psi_dm1',
                                               operation="ObjectProbePsiDM1(i, psi, obj, probe, cx, cy, pixel_size, f, npsi, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe)",
                                               preamble=getks('cuda/complex.cu') + getks(
                                                   'ptycho/cuda/obj_probe_dm_elw.cu'),
                                               options=self.cu_options,
                                               arguments="pycuda::complex<float>* psi, pycuda::complex<float> *obj, pycuda::complex<float>* probe, int* cx, int* cy, const float pixel_size, const float f, const int npsi, const int stack_size, const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe")

        self.cu_2object_probe_psi_dm2 = CU_ElK(name='cu_2object_probe_psi_dm2',
                                               operation="ObjectProbePsiDM2(i, psi, psi_fourier, obj, probe, cx, cy, pixel_size, f, npsi, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe)",
                                               preamble=getks('cuda/complex.cu') + getks(
                                                   'ptycho/cuda/obj_probe_dm_elw.cu'),
                                               options=self.cu_options,
                                               arguments="pycuda::complex<float>* psi, pycuda::complex<float>* psi_fourier, pycuda::complex<float> *obj, pycuda::complex<float>* probe, int* cx, int* cy, const float pixel_size, const float f, const int npsi, const int stack_size, const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe")

        # Unlike the *N / *_atomic variants below, this kernel takes scalar cx, cy (const int).
        self.cu_psi_to_obj = CU_ElK(name='psi_to_obj',
                                    operation="UpdateObjQuadPhase(i, psi, objnew, probe, objnorm, cx, cy, px, f, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe)",
                                    preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/psi_to_obj_probe_elw.cu'),
                                    options=self.cu_options,
                                    arguments="pycuda::complex<float>* psi, pycuda::complex<float> *objnew, pycuda::complex<float>* probe, float* objnorm, const int cx, const int cy, const float px, const float f, const int stack_size, const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe")

        self.cu_psi_to_objN = CU_ElK(name='psi_to_objN',
                                     operation="UpdateObjQuadPhaseN(i, psi, objnewN, probe, objnormN, cx, cy, px, f, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe, npsi, padding)",
                                     preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/psi_to_obj_probe_elw.cu'),
                                     options=self.cu_options,
                                     arguments="pycuda::complex<float>* psi, pycuda::complex<float> *objnewN, pycuda::complex<float>* probe, float* objnormN, int* cx, int* cy, const float px, const float f, const int stack_size, const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe, const int npsi, const int padding")

        self.cu_psi_to_obj_atomic = CU_ElK(name='psi_to_obj_atomic',
                                           operation="UpdateObjQuadPhaseAtomic(i, psi, objnew, probe, objnorm, cx, cy,"
                                                     "px, f, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe, npsi,"
                                                     "padding)",
                                           preamble=getks('cuda/complex.cu') +
                                                    getks('ptycho/cuda/psi_to_obj_probe_elw.cu'),
                                           options=self.cu_options,
                                           arguments="pycuda::complex<float>* psi, pycuda::complex<float> *objnew,"
                                                     "pycuda::complex<float>* probe, float* objnorm, int* cx, int* cy,"
                                                     "const float px, const float f, const int stack_size,"
                                                     "const int nx, const int ny, const int nxo, const int nyo,"
                                                     "const int nbobj, const int nbprobe, const int npsi,"
                                                     "const int padding")

        # NOTE(review): 'objnorm' is declared in `arguments` but is not passed to SumN() in `operation`,
        # so callers must still supply it positionally. Confirm against psi_to_obj_probe_elw.cu.
        self.cu_sum_n = CU_ElK(name='sum_n',
                               operation="SumN(i, objnewN, objnormN, stack_size, nxyo, nbobj)",
                               preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/psi_to_obj_probe_elw.cu'),
                               options=self.cu_options,
                               arguments="pycuda::complex<float>* objnewN, float* objnorm, float* objnormN, const int stack_size, const int nxyo, const int nbobj")

        self.cu_sum_n_norm = CU_ElK(name='sum_n_norm',
                                    operation="SumNnorm(i, objnormN, stack_size, nxyo)",
                                    preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/psi_to_obj_probe_elw.cu'),
                                    options=self.cu_options,
                                    arguments="float* objnormN, const int stack_size, const int nxyo")

        self.cu_psi_to_probe = CU_ElK(name='psi_to_probe',
                                      operation="UpdateProbeQuadPhase(i, obj, probe_new, psi, probenorm, cx, cy, px, f, firstpass, npsi, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe)",
                                      preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/psi_to_obj_probe_elw.cu'),
                                      options=self.cu_options,
                                      arguments="pycuda::complex<float>* psi, pycuda::complex<float> *obj, pycuda::complex<float>* probe_new, float* probenorm, int* cx, int* cy, const float px, const float f, const char firstpass, const int npsi, const int stack_size, const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe")

        self.cu_obj_norm = CU_ElK(name='obj_norm',
                                  operation="ObjNorm(i, objnorm, obj_unnorm, obj, regmax, reg, nxyo, nbobj)",
                                  preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/psi_to_obj_probe_elw.cu'),
                                  options=self.cu_options,
                                  arguments="float* objnorm, pycuda::complex<float>* obj_unnorm,"
                                            "pycuda::complex<float>* obj, float* regmax, const float reg,"
                                            "const int nxyo, const int nbobj")

        self.cu_obj_norm_n = CU_ElK(name='obj_norm_n',
                                    operation="ObjNormN(i, obj_norm, obj_newN, obj, regmax, reg, nxyo, nbobj, stack_size)",
                                    preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/psi_to_obj_probe_elw.cu'),
                                    options=self.cu_options,
                                    arguments="float* obj_norm, pycuda::complex<float>* obj_newN, pycuda::complex<float>* obj, float* regmax, const float reg, const int nxyo, const int nbobj, const int stack_size")

        self.cu_obj_norm_zero_phase_mask_n = CU_ElK(name='obj_norm_zero_phase_n',
                                                    operation="ObjNormZeroPhaseMaskN(i, obj_norm, obj_newN, obj, zero_phase_mask, regmax, reg, nxyo, nbobj, stack_size)",
                                                    preamble=getks('cuda/complex.cu') + getks(
                                                        'ptycho/cuda/psi_to_obj_probe_elw.cu'),
                                                    options=self.cu_options,
                                                    arguments="float* obj_norm, pycuda::complex<float>* obj_newN, pycuda::complex<float>* obj, float* zero_phase_mask, float* regmax, const float reg, const int nxyo, const int nbobj, const int stack_size")

        # Gradient kernels (ptycho/cuda/grad_elw.cu)
        self.cu_grad_poisson_fourier = CU_ElK(name='grad_poisson_fourier',
                                              operation="GradPoissonFourier(i, iobs, psi, background, nbmode, nx, ny, "
                                                        "nxy, nxystack, hann_filter, scale_in, scale_out)",
                                              preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/grad_elw.cu'),
                                              options=self.cu_options,
                                              arguments="float *iobs, pycuda::complex<float> *psi, float *background,"
                                                        "const int nbmode, const int nx, const int ny, const int nxy, "
                                                        "const int nxystack, const char hann_filter,"
                                                        "const float scale_in, const float scale_out")

        self.cu_psi_to_obj_grad = CU_ElK(name='psi_to_obj_grad',
                                         operation="GradObj(i, psi, obj_grad, probe, cx, cy, px, f, stack_size, nx, ny,"
                                                   "nxo, nyo, nbobj, nbprobe)",
                                         preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/grad_elw.cu'),
                                         options=self.cu_options,
                                         arguments="pycuda::complex<float>* psi, pycuda::complex<float> *obj_grad,"
                                                   "pycuda::complex<float>* probe, const int cx, const int cy,"
                                                   "const float px, const float f, const int stack_size, const int nx,"
                                                   "const int ny, const int nxo, const int nyo, const int nbobj,"
                                                   "const int nbprobe")

        self.cu_psi_to_obj_grad_atomic = CU_ElK(name='psi_to_obj_grad_atomic',
                                                operation="GradObjAtomic(i, psi, obj_grad, probe, cx, cy, px, f,"
                                                          "stack_size, nx, ny, nxo, nyo, nbobj, nbprobe, npsi)",
                                                preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/grad_elw.cu'),
                                                options=self.cu_options,
                                                arguments="pycuda::complex<float>* psi,"
                                                          "pycuda::complex<float> *obj_grad,"
                                                          "pycuda::complex<float>* probe, int* cx, int* cy,"
                                                          "const float px, const float f, const int stack_size,"
                                                          "const int nx, const int ny, const int nxo, const int nyo,"
                                                          "const int nbobj, const int nbprobe, const int npsi")

        self.cu_psi_to_probe_grad = CU_ElK(name='psi_to_probe_grad',
                                           operation="GradProbe(i, psi, probe_grad, obj, cx, cy, px, f, firstpass, npsi, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe)",
                                           preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/grad_elw.cu'),
                                           options=self.cu_options,
                                           arguments="pycuda::complex<float>* psi, pycuda::complex<float>* probe_grad, pycuda::complex<float> *obj, int* cx, int* cy, const float px, const float f, const char firstpass, const int npsi, const int stack_size, const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe")

        self.cu_reg_grad = CU_ElK(name='reg_grad',
                                  operation="GradReg(i, grad, v, alpha, nx, ny)",
                                  preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/grad_elw.cu'),
                                  options=self.cu_options,
                                  arguments="pycuda::complex<float>* grad, pycuda::complex<float>* v,"
                                            "const float alpha, const int nx, const int ny")

        self.cu_circular_shift = CU_ElK(name='cu_circular_shift',
                                        operation="circular_shift(i, source, dest, dx, dy, dz, nx, ny, nz)",
                                        preamble=getks('cuda/complex.cu') + getks('cuda/circular_shift.cu'),
                                        options=self.cu_options,
                                        arguments="pycuda::complex<float>* source, pycuda::complex<float>* dest,"
                                                  "const int dx, const int dy, const int dz,"
                                                  "const int nx, const int ny, const int nz")

        # Reduction kernels
        # Maximum of a float array (neutral element -1e32)
        self.cu_max_red = CU_RedK(np.float32, neutral="-1e32", reduce_expr="max(a,b)",
                                  options=self.cu_options, arguments="float *in")

        self.cu_norm_complex_n = CU_RedK(np.float32, neutral="0", reduce_expr="a+b", name='norm_complex_n_red',
                                         map_expr="ComplexNormN(d[i], nn)",
                                         options=self.cu_options,
                                         preamble=getks('cuda/complex.cu'),
                                         arguments="pycuda::complex<float> *d, const int nn")

        # This will compute Poisson, Gaussian, Euclidian LLK as well as the sum of the calculated intensity
        self.cu_llk = CU_RedK(my_float4, neutral="my_float4(0)", reduce_expr="a+b", name='llk_red',
                              preamble=getks('cuda/complex.cu') + getks('cuda/float_n.cu') + getks(
                                  'ptycho/cuda/llk_red.cu'),
                              options=self.cu_options,
                              map_expr="LLKAll(i, iobs, psi, background, nbmode, nxy, nxystack, scale)",
                              arguments="float *iobs, pycuda::complex<float> *psi, float *background, const int nbmode,"
                                        "const int nxy, const int nxystack, const float scale")

        self.cu_cg_polak_ribiere_red = CU_RedK(np.complex64, neutral="complexf(0,0)", name='polak_ribiere_red',
                                               reduce_expr="a+b",
                                               map_expr="PolakRibiereComplex(grad[i], lastgrad[i])",
                                               preamble=getks('cuda/complex.cu') + getks(
                                                   'cuda/cg_polak_ribiere_red.cu'),
                                               options=self.cu_options,
                                               arguments="pycuda::complex<float> *grad, pycuda::complex<float> *lastgrad")

        self._cu_cg_poisson_gamma_red = CU_RedK(np.complex64, neutral="complexf(0,0)", name='cg_poisson_gamma_red',
                                                reduce_expr="a+b",
                                                map_expr="CG_Poisson_Gamma(i, obs, PO, PdO, dPO, dPdO, scale, "
                                                         "scale_fft, nxy, nxystack, nbmode)",
                                                preamble=getks('cuda/complex.cu') + getks('cuda/float_n.cu') + getks(
                                                    'ptycho/cuda/cg_gamma_red.cu'),
                                                options=self.cu_options,
                                                arguments="float *obs, pycuda::complex<float> *PO, "
                                                          "pycuda::complex<float> *PdO, pycuda::complex<float> *dPO,"
                                                          "pycuda::complex<float> *dPdO, const float scale,"
                                                          "const float scale_fft, const int nxy, const int nxystack,"
                                                          "const int nbmode")

        # 4th order LLK(gamma) approximation
        self._cu_cg_poisson_gamma4_red = CU_RedK(my_float4, neutral="my_float4(0)", name='cg_poisson_gamma4_red',
                                                 reduce_expr="a+b",
                                                 map_expr="CG_Poisson_Gamma4(i, obs, PO, PdO, dPO, dPdO, scale, nxy, nxystack, nbmode)",
                                                 preamble=getks('cuda/complex.cu') + getks('cuda/float_n.cu') + getks(
                                                     'ptycho/cuda/cg_gamma_red.cu'),
                                                 options=self.cu_options,
                                                 arguments="float *obs, pycuda::complex<float> *PO, pycuda::complex<float> *PdO, pycuda::complex<float> *dPO, pycuda::complex<float> *dPdO, const float scale, const int nxy, const int nxystack, const int nbmode")

        # options=self.cu_options added for consistency with every other kernel in this class
        self._cu_cg_gamma_reg_red = CU_RedK(np.complex64, neutral="complexf(0,0)", name='cg_gamma_reg_red',
                                            reduce_expr="a+b",
                                            map_expr="GammaReg(i, v, dv, nx, ny)",
                                            preamble=getks('cuda/complex.cu') + getks('cuda/cg_gamma_reg_red.cu'),
                                            options=self.cu_options,
                                            arguments="pycuda::complex<float> *v, pycuda::complex<float> *dv,"
                                                      "const int nx, const int ny")

        # options=self.cu_options added for consistency with every other kernel in this class
        self.cu_scale_intensity = CU_RedK(np.complex64, neutral="complexf(0,0)", name='scale_intensity',
                                          reduce_expr="a+b", map_expr="scale_intensity(i, obs, calc, nxystack,"
                                                                      "nb_mode)",
                                          preamble=getks('cuda/complex.cu') + getks('ptycho/cuda/scale_red.cu'),
                                          options=self.cu_options,
                                          arguments="float *obs, pycuda::complex<float> *calc, const int nxystack,"
                                                    "const int nb_mode")

        self.cu_center_mass_complex = CU_RedK(my_float4, neutral="my_float4(0)", name="cu_center_mass_complex",
                                              reduce_expr="a+b",
                                              preamble=getks('cuda/complex.cu') + getks('cuda/float_n.cu')
                                                       + getks('cuda/center_mass_red.cu'),
                                              options=self.cu_options,
                                              map_expr="center_mass_complex(i, d, nx, ny, nz, power)",
                                              arguments="pycuda::complex<float> *d, const int nx, const int ny,"
                                                        "const int nz, const int power")

        # custom kernels
        # Gaussian convolution kernels, compiled once per BLOCKSIZE (16, 32, 64) with HALFBLOCK = BLOCKSIZE/2 - 1,
        # each providing convolution along x, y and z.
        opt = "#define BLOCKSIZE 16\n#define HALFBLOCK 7\n"
        conv16_mod = SourceModule(opt + getks('cuda/complex.cu') + getks('cuda/convolution_complex.cu'),
                                  options=self.cu_options)
        self.gauss_convol_complex_16x = conv16_mod.get_function("gauss_convol_complex_x")
        self.gauss_convol_complex_16y = conv16_mod.get_function("gauss_convol_complex_y")
        self.gauss_convol_complex_16z = conv16_mod.get_function("gauss_convol_complex_z")

        opt = "#define BLOCKSIZE 32\n#define HALFBLOCK 15\n"
        conv32_mod = SourceModule(opt + getks('cuda/complex.cu') + getks('cuda/convolution_complex.cu'),
                                  options=self.cu_options)
        self.gauss_convol_complex_32x = conv32_mod.get_function("gauss_convol_complex_x")
        self.gauss_convol_complex_32y = conv32_mod.get_function("gauss_convol_complex_y")
        self.gauss_convol_complex_32z = conv32_mod.get_function("gauss_convol_complex_z")

        opt = "#define BLOCKSIZE 64\n#define HALFBLOCK 31\n"
        conv64_mod = SourceModule(opt + getks('cuda/complex.cu') + getks('cuda/convolution_complex.cu'),
                                  options=self.cu_options)
        self.gauss_convol_complex_64x = conv64_mod.get_function("gauss_convol_complex_x")
        self.gauss_convol_complex_64y = conv64_mod.get_function("gauss_convol_complex_y")
        self.gauss_convol_complex_64z = conv64_mod.get_function("gauss_convol_complex_z")

        # Init memory pool
        self.cu_mem_pool = cu_tools.DeviceMemoryPool()


# The default processing unit: used by CUDA ptycho operators created without an explicit
# processing_unit argument.
default_processing_unit = CUProcessingUnitPtycho()


class CUObsDataStack:
    """
    Container for one stack (e.g. 16 frames) of observed data held in CUDA space, together
    with the frame positions (on the GPU and as host copies).
    """

    def __init__(self, cu_obs, cu_x, cu_y, i, npsi):
        """
        :param cu_obs: pycuda array of observed data, with N frames
        :param cu_x, cu_y: pycuda arrays of the positions (in pixels) of the different frames
        :param i: index of the first frame
        :param npsi: number of valid frames (others are filled with zeros)
        """
        # GPU-side arrays, stored as given
        self.cu_obs, self.cu_x, self.cu_y = cu_obs, cu_x, cu_y
        # Indices stored as 32-bit integers (presumably so they can be passed directly as kernel arguments)
        self.i, self.npsi = np.int32(i), np.int32(npsi)
        # Host copies of the frame positions, fetched from the GPU arrays
        self.x, self.y = cu_x.get(), cu_y.get()


class CUOperatorPtycho(OperatorPtycho):
    """
    Base class for operators on Ptycho objects using CUDA.

    Besides holding the CUDA processing unit, this class manages the GPU copies of the Ptycho data
    (observed intensities, object, probe, Psi, background) and the 'views' mechanism (view_register,
    view_copy, view_swap, view_sum, view_purge) which allows algorithms to keep several copies of
    (object, probe, Psi, Psi_v). View 0 always refers to the main arrays (_cu_obj, _cu_probe, ...).
    """

    def __init__(self, processing_unit=None):
        """

        :param processing_unit: the processing unit to use. If None, the module-level default_processing_unit
                                is used. If its CUDA context is not initialized yet (cu_ctx is None), it is
                                initialized using the GPU selected by the main default processing unit.
        """
        super(CUOperatorPtycho, self).__init__()

        self.Operator = CUOperatorPtycho
        self.OperatorSum = CUOperatorPtychoSum
        self.OperatorPower = CUOperatorPtychoPower

        if processing_unit is None:
            self.processing_unit = default_processing_unit
        else:
            self.processing_unit = processing_unit
        if self.processing_unit.cu_ctx is None:
            # CUDA kernels have not been prepared yet, use a default initialization
            if main_default_processing_unit.cu_device is None:
                main_default_processing_unit.select_gpu(language='cuda')

            self.processing_unit.init_cuda(cu_device=main_default_processing_unit.cu_device,
                                           test_fft=False, verbose=False)

    def apply_ops_mul(self, pty):
        """
        Apply the series of operators stored in self.ops to a Ptycho object.
        The operators are applied one after the other.

        :param pty: the Ptycho object to which the operators will be applied.
        :return: the Ptycho object, after application of all the operators in sequence
        :raises OperatorException: if pty is not a Ptycho object
        """
        if isinstance(pty, Ptycho) is False:
            raise OperatorException(
                "ERROR: tried to apply operator:\n    %s \n  to:\n    %s\n  which is not a Ptycho object" % (
                    str(self), str(pty)))
        return super(CUOperatorPtycho, self).apply_ops_mul(pty)

    def prepare_data(self, p: Ptycho):
        """
        Make sure the data to be used is in GPU memory for the operator:

        - observed intensities and positions are transferred once (or again if the stack size changed),
        - object, probe, zero-phase mask and background are (re-)uploaded when the host data is newer
          than the GPU copy (p._timestamp_counter > p._cu_timestamp_counter),
        - the Psi array is (re-)allocated if missing or if its mode/stack dimensions changed; in that
          case the _cu_psi_v dictionary is also reset.

        :param p: the Ptycho object the operator will be applied to.
        :return: nothing
        """
        pu = self.processing_unit
        if has_attr_not_none(p, "_cu_obs_v") is False:
            # Assume observed intensity is immutable, so transfer only once
            self.init_cu_vobs(p)
        elif len(p._cu_obs_v[0].cu_obs) != pu.cu_stack_size:
            # This should not happen, but if tests are being made on the speed vs the stack size, this can be useful.
            self.init_cu_vobs(p)

        if p._timestamp_counter > p._cu_timestamp_counter:
            # Host data is newer than the GPU copy: move object, probe (and mask, background) to the GPU
            p._cu_obj = cua.to_gpu(p._obj, allocator=pu.cu_mem_pool.allocate)
            p._cu_probe = cua.to_gpu(p._probe, allocator=pu.cu_mem_pool.allocate)
            if p._obj_zero_phase_mask is not None:
                p._cu_obj_zero_phase_mask = cua.to_gpu(p._obj_zero_phase_mask, allocator=pu.cu_mem_pool.allocate)
            else:
                p._cu_obj_zero_phase_mask = None
            p._cu_timestamp_counter = p._timestamp_counter
            if p._background is None:
                p._cu_background = cua.zeros(p.data.iobs.shape[-2:], dtype=np.float32,
                                             allocator=pu.cu_mem_pool.allocate)
            else:
                p._cu_background = cua.to_gpu(p._background, allocator=pu.cu_mem_pool.allocate)

        need_init_psi = False
        if has_attr_not_none(p, "_cu_psi") is False:
            need_init_psi = True
        elif p._cu_psi.shape[0:3] != (len(p._obj), len(p._probe), pu.cu_stack_size):
            need_init_psi = True
        if need_init_psi:
            ny, nx = p._probe.shape[-2:]
            p._cu_psi = cua.empty(shape=(len(p._obj), len(p._probe), pu.cu_stack_size, ny, nx),
                                  dtype=np.complex64, allocator=pu.cu_mem_pool.allocate)

        if has_attr_not_none(p, "_cu_psi_v") is False or need_init_psi:
            # _cu_psi_v is used to hold the complete copy of Psi projections for all stacks, for algorithms
            # such as DM which need them.
            p._cu_psi_v = {}

        if has_attr_not_none(p, '_cu_view') is False:
            p._cu_view = {}

    def init_cu_vobs(self, p):
        """
        Initialize observed intensity and scan positions in GPU space, as a list of CUObsDataStack
        (one per stack of cu_stack_size frames) stored in p._cu_obs_v. If the number of frames is not
        a multiple of the stack size, the last stack is padded with null frames.

        :param p: the Ptycho object the operator will be applied to.
        :return: nothing
        """
        pu = self.processing_unit
        p._cu_obs_v = []
        nb_frame, ny, nx = p.data.iobs.shape
        nyo, nxo = p._obj.shape[-2:]
        cu_stack_size = pu.cu_stack_size
        px, py = p.data.pixel_size_object()
        for i in range(0, nb_frame, cu_stack_size):
            # Zero-initialized: padding frames beyond nb_frame remain null
            vcx = np.zeros((cu_stack_size), dtype=np.int32)
            vcy = np.zeros((cu_stack_size), dtype=np.int32)
            vobs = np.zeros((cu_stack_size, ny, nx), dtype=np.float32)
            if nb_frame < (i + cu_stack_size):
                print("Number of frames is not a multiple of %d, adding %d null frames" %
                      (cu_stack_size, i + cu_stack_size - nb_frame))
            for j in range(cu_stack_size):
                ij = i + j
                if ij < nb_frame:
                    dy, dx = p.data.posy[ij] / py, p.data.posx[ij] / px
                    cx, cy = get_view_coord((nyo, nxo), (ny, nx), dx, dy)
                    vcx[j] = np.int32(round(cx))
                    vcy[j] = np.int32(round(cy))
                    vobs[j] = p.data.iobs[ij]
                else:
                    # Null frames re-use the position of the first frame of the stack
                    vcx[j] = vcx[0]
                    vcy[j] = vcy[0]
            cu_vcx = cua.to_gpu(vcx, allocator=pu.cu_mem_pool.allocate)
            cu_vcy = cua.to_gpu(vcy, allocator=pu.cu_mem_pool.allocate)
            cu_vobs = cua.to_gpu(vobs, allocator=pu.cu_mem_pool.allocate)
            p._cu_obs_v.append(CUObsDataStack(cu_vobs, cu_vcx, cu_vcy, i, np.int32(min(cu_stack_size, nb_frame - i))))
        # Initialize the size and index of current stack
        p._cu_stack_i = 0
        p._cu_stack_nb = len(p._cu_obs_v)

    def timestamp_increment(self, p):
        """
        Increment the Ptycho object's timestamp, and mark the GPU copy as up-to-date so that
        prepare_data() does not re-upload object and probe.
        """
        p._timestamp_counter += 1
        p._cu_timestamp_counter = p._timestamp_counter

    def view_register(self, obj):
        """
        Creates a new unique view key in an object. When finished with this view, it should be de-registered
        using view_purge. Note that it only reserves the key, but does not create the view.
        :return: an integer value (>=1), which corresponds to yet-unused key in the object's view.
        """
        i = 1
        while i in obj._cu_view:
            i += 1
        obj._cu_view[i] = None
        return i

    def view_copy(self, pty, i_source, i_dest):
        """
        Copy view i_source into view i_dest (0 being the main arrays). New GPU arrays are allocated
        for the destination, and the data is copied device-to-device.

        :param pty: the Ptycho object
        :param i_source: index of the source view
        :param i_dest: index of the destination view
        """
        if i_source == 0:
            src = {'obj': pty._cu_obj, 'probe': pty._cu_probe, 'psi': pty._cu_psi, 'psi_v': pty._cu_psi_v}
        else:
            src = pty._cu_view[i_source]
        if i_dest == 0:
            pty._cu_obj = cua.empty_like(src['obj'])
            pty._cu_probe = cua.empty_like(src['probe'])
            pty._cu_psi = cua.empty_like(src['psi'])
            pty._cu_psi_v = {}
            dest = {'obj': pty._cu_obj, 'probe': pty._cu_probe, 'psi': pty._cu_psi, 'psi_v': pty._cu_psi_v}
        else:
            pty._cu_view[i_dest] = {'obj': cua.empty_like(src['obj']), 'probe': cua.empty_like(src['probe']),
                                    'psi': cua.empty_like(src['psi']), 'psi_v': {}}
            dest = pty._cu_view[i_dest]

        pairs = [(src[k], dest[k]) for k in ('obj', 'probe', 'psi')]
        # Allocate each Psi_v copy with the shape of its own source and keep the same keys,
        # so that the copy does not depend on the keys being 0..n-1 or on dict ordering.
        for k, v in src['psi_v'].items():
            dest['psi_v'][k] = cua.empty_like(v)
            pairs.append((v, dest['psi_v'][k]))

        for s, d in pairs:
            cu_drv.memcpy_dtod(dest=d.gpudata, src=s.gpudata, size=d.nbytes)

    def view_swap(self, pty, i1, i2):
        """
        Swap views i1 and i2 (0 being the main arrays). An unset (None) view is first replaced by a
        dummy view with None arrays, assuming a copy will be made into it later.
        """
        for i in (i1, i2):
            if i != 0 and pty._cu_view[i] is None:
                # Create dummy value, assume a copy will be made later
                pty._cu_view[i] = {'obj': None, 'probe': None, 'psi': None, 'psi_v': None}
        if i1 == 0 or i2 == 0:
            # Swap the main arrays with the other view
            v = pty._cu_view[i2 if i1 == 0 else i1]
            pty._cu_obj, v['obj'] = v['obj'], pty._cu_obj
            pty._cu_probe, v['probe'] = v['probe'], pty._cu_probe
            pty._cu_psi, v['psi'] = v['psi'], pty._cu_psi
            pty._cu_psi_v, v['psi_v'] = v['psi_v'], pty._cu_psi_v
        else:
            pty._cu_view[i1], pty._cu_view[i2] = pty._cu_view[i2], pty._cu_view[i1]
        self.timestamp_increment(pty)

    def view_sum(self, pty, i_source, i_dest):
        """
        Add view i_source to view i_dest (0 being the main arrays), using the processing unit's cu_sum.
        Psi_v arrays are matched by key; keys only present in one of the views are ignored.
        """
        if i_source == 0:
            src = {'obj': pty._cu_obj, 'probe': pty._cu_probe, 'psi': pty._cu_psi, 'psi_v': pty._cu_psi_v}
        else:
            src = pty._cu_view[i_source]
        if i_dest == 0:
            dest = {'obj': pty._cu_obj, 'probe': pty._cu_probe, 'psi': pty._cu_psi, 'psi_v': pty._cu_psi_v}
        else:
            dest = pty._cu_view[i_dest]
        pairs = [(src[k], dest[k]) for k in ('obj', 'probe', 'psi')]
        pairs += [(v, dest['psi_v'][k]) for k, v in src['psi_v'].items() if k in dest['psi_v']]
        for s, d in pairs:
            self.processing_unit.cu_sum(s, d)
        self.timestamp_increment(pty)

    def view_purge(self, pty, i):
        """
        Delete view i. If i is None, delete all views (the whole _cu_view dictionary).
        """
        if i is not None:
            del pty._cu_view[i]
        elif has_attr_not_none(pty, '_cu_view'):
            del pty._cu_view


# The only purpose of this class is to make sure it inherits from CUOperatorPtycho and has a processing unit
class CUOperatorPtychoSum(OperatorSum, CUOperatorPtycho):
    """
    Sum of two CUOperatorPtycho. Scalar terms are converted to Scale operators.
    The only purpose of this class is to make sure the sum inherits from CUOperatorPtycho and has a
    processing unit.
    """

    def __init__(self, op1, op2):
        op1 = Scale(op1) if np.isscalar(op1) else op1
        op2 = Scale(op2) if np.isscalar(op2) else op2
        if not (isinstance(op1, CUOperatorPtycho) and isinstance(op2, CUOperatorPtycho)):
            raise OperatorException(
                "ERROR: cannot add a CUOperatorPtycho with a non-CUOperatorPtycho: %s + %s" % (str(op1), str(op2)))
        # Both terms are CUOperatorPtycho, so op1 has a processing_unit attribute.
        CUOperatorPtycho.__init__(self, op1.processing_unit)
        OperatorSum.__init__(self, op1, op2)

        # OperatorSum is listed first in the bases, so re-bind the CUDA versions of these
        # doubly-inherited attributes and methods explicitly on the instance.
        self.Operator = CUOperatorPtycho
        self.OperatorSum = CUOperatorPtychoSum
        self.OperatorPower = CUOperatorPtychoPower
        for method_name in ('prepare_data', 'timestamp_increment', 'view_copy', 'view_swap', 'view_sum',
                            'view_purge'):
            setattr(self, method_name, types.MethodType(getattr(CUOperatorPtycho, method_name), self))


# The only purpose of this class is to make sure it inherits from CUOperatorPtycho and has a processing unit
class CUOperatorPtychoPower(OperatorPower, CUOperatorPtycho):
    """
    Power (repeated application) of a CUOperatorPtycho. The only purpose of this class is to make sure
    it inherits from CUOperatorPtycho and has a processing unit.
    """

    def __init__(self, op, n):
        CUOperatorPtycho.__init__(self, op.processing_unit)
        OperatorPower.__init__(self, op, n)

        # OperatorPower is listed first in the bases, so re-bind the CUDA versions of these
        # doubly-inherited attributes and methods explicitly on the instance.
        self.Operator = CUOperatorPtycho
        self.OperatorSum = CUOperatorPtychoSum
        self.OperatorPower = CUOperatorPtychoPower
        for method_name in ('prepare_data', 'timestamp_increment', 'view_copy', 'view_swap', 'view_sum',
                            'view_purge'):
            setattr(self, method_name, types.MethodType(getattr(CUOperatorPtycho, method_name), self))


class FreePU(CUOperatorPtycho):
    """
    Operator freeing CUDA memory. The data is first copied back from the GPU (p.from_pu()). Then every
    pycuda.gpuarray.GPUArray attribute of the supplied object is freed and set to None, as well as the arrays
    of the _cu_psi_v dictionary and of the observed data stacks (_cu_obs_v). Finally the memory blocks held
    by the processing unit's memory pool are released.
    """

    def __init__(self, verbose=False):
        """

        :param verbose: if True, will detail all the free'd memory and a summary
        """
        super(FreePU, self).__init__()
        self.verbose = verbose

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        self.processing_unit.finish()

        p.from_pu()
        if self.verbose:
            print("FreePU:")
        nbytes_freed = 0

        for o in dir(p):
            a = getattr(p, o)
            if isinstance(a, cua.GPUArray):
                if self.verbose:
                    print("  Freeing: %40s %10.3fMbytes" % (o, a.nbytes / 1e6))
                nbytes_freed += a.nbytes
                a.gpudata.free()
                setattr(p, o, None)
        if has_attr_not_none(p, "_cu_psi_v"):
            for a in p._cu_psi_v.values():
                if self.verbose:
                    print("  Freeing: %40s %10.3fMbytes" % ("_cu_psi_v", a.nbytes / 1e6))
                nbytes_freed += a.nbytes
                a.gpudata.free()
            p._cu_psi_v = {}
        # _cu_obs_v may already be None (e.g. if FreePU was applied before): nothing to free then
        if has_attr_not_none(p, "_cu_obs_v"):
            for v in p._cu_obs_v:
                for o in dir(v):
                    a = getattr(v, o)
                    if isinstance(a, cua.GPUArray):
                        if self.verbose:
                            print("  Freeing: %40s %10.3fMbytes" % ("_cu_obs_v:" + o, a.nbytes / 1e6))
                        nbytes_freed += a.nbytes
                        a.gpudata.free()

        p._cu_obs_v = None
        self.processing_unit.cu_mem_pool.free_held()
        gc.collect()
        if self.verbose:
            print('FreePU total: %10.3fMbytes freed' % (nbytes_freed / 1e6))
        return p

    def timestamp_increment(self, p):
        """
        Reset the GPU timestamp so that the next prepare_data() call re-uploads object and probe
        (assuming p._timestamp_counter > 0).
        """
        p._cu_timestamp_counter = 0


class Scale(CUOperatorPtycho):
    """
    Multiply the ptycho object by a scalar (real or complex).
    """

    def __init__(self, x, obj=True, probe=True, psi=True):
        """

        :param x: the scaling factor
        :param obj: if True, scale the object
        :param probe: if True, scale the probe
        :param psi: if True, scale all the psi arrays, _cu_psi as well as _cu_psi_v
        """
        super(Scale, self).__init__()
        self.x = x
        self.obj = obj
        self.probe = probe
        self.psi = psi

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        if self.x == 1:
            return p

        if np.isreal(self.x):
            # np.isreal() is also True for complex values with a zero imaginary part:
            # take the real part explicitly before converting to float32.
            scale_k = pu.cu_scale
            x = np.float32(np.real(self.x))
        else:
            scale_k = pu.cu_scale_complex
            x = np.complex64(self.x)

        if self.obj:
            scale_k(p._cu_obj, x)
        if self.probe:
            scale_k(p._cu_probe, x)
        if self.psi:
            scale_k(p._cu_psi, x)
            # Iterate over the stored arrays directly rather than assuming keys 0..n-1
            for v in p._cu_psi_v.values():
                scale_k(v, x)
        return p


class ObjProbe2Psi(CUOperatorPtycho):
    """
    Computes Psi = Obj(r) * Probe(r-r_j) for the current stack of N probe positions,
    for all object and probe modes.
    """

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        stack = p._cu_obs_v[p._cu_stack_i]
        probe_shape, obj_shape = p._probe.shape, p._obj.shape
        nb_probe, ny, nx = np.int32(probe_shape[0]), np.int32(probe_shape[-2]), np.int32(probe_shape[-1])
        nb_obj, nyo, nxo = np.int32(obj_shape[0]), np.int32(obj_shape[-2]), np.int32(obj_shape[-1])
        # Quadratic phase factor: only used in far field
        f = 0 if p.data.near_field else np.float32(np.pi / (p.data.wavelength * p.data.detector_distance))
        # p._cu_psi[0, 0, 0] is given as first argument because the kernel computes the projection for all
        # object and probe modes and the full stack of frames.
        pu.cu_object_probe_mult(p._cu_psi[0, 0, 0], p._cu_obj, p._cu_probe, stack.cu_x, stack.cu_y,
                                p.pixel_size_object, f, stack.npsi, pu.cu_stack_size,
                                nx, ny, nxo, nyo, nb_obj, nb_probe)
        return p


class FT(CUOperatorPtycho):
    """
    Forward Fourier-transform of the Psi array (a stack of N Obj*Probe views), using a batched cuFFT plan.
    """

    def __init__(self, scale=True):
        """

        :param scale: if True, the FFT will be normalized, i.e. multiplied by 1/sqrt(nx*ny).
        """
        super(FT, self).__init__()
        self.scale = scale

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        psi = p._cu_psi
        pu.cu_fft_set_plan(psi, batch=True)
        # In-place, un-normalized transform; normalization (if requested) is a separate scaling
        cu_fft.fft(psi, psi, pu.cufft_plan, scale=False)
        if self.scale:
            nb_pixel = psi[0, 0, 0].size
            pu.cu_scale(psi, np.float32(1 / np.sqrt(nb_pixel)))
        return p


class IFT(CUOperatorPtycho):
    """
    Backward Fourier-transform of the Psi array (a stack of N Obj*Probe views), using a batched cuFFT plan.
    """

    def __init__(self, scale=True):
        """

        :param scale: if True, the FFT will be normalized, i.e. multiplied by 1/sqrt(nx*ny).
        """
        super(IFT, self).__init__()
        self.scale = scale

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        psi = p._cu_psi
        pu.cu_fft_set_plan(psi, batch=True)
        # In-place, un-normalized transform; normalization (if requested) is a separate scaling
        cu_fft.ifft(psi, psi, pu.cufft_plan, scale=False)
        if self.scale:
            nb_pixel = psi[0, 0, 0].size
            pu.cu_scale(psi, np.float32(1 / np.sqrt(nb_pixel)))
        return p


class QuadraticPhase(CUOperatorPtycho):
    """
    Operator applying a quadratic phase factor (and optionally a scale factor) to the Psi array.
    """

    def __init__(self, factor, scale=1):
        """
        Application of a quadratic phase factor, and optionally a scale factor.

        The actual factor is:  :math:`scale * e^{i * factor * ((ix/nx)^2 + (iy/ny)^2)}`
        where ix and iy are the integer indices of the pixels

        :param factor: the factor for the phase calculation.
        :param scale: the data will be scaled by this factor. Useful to normalize before/after a Fourier transform,
                      without accessing twice the array data.
        """
        super(QuadraticPhase, self).__init__()
        self.factor = np.float32(factor)
        self.scale = np.float32(scale)

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        ny, nx = (np.int32(n) for n in p._probe.shape[-2:])
        self.processing_unit.cu_quad_phase(p._cu_psi, self.factor, self.scale, nx, ny)
        return p


class PropagateNearField(CUOperatorPtycho):
    """
    Near field propagator: FT, multiplication by a quadratic phase factor, then IFT.
    """

    def __init__(self, forward=True):
        """

        :param forward: if True, propagate forward, otherwise backward. The distance is taken from the ptycho data
                        this operator applies to.
        """
        super(PropagateNearField, self).__init__()
        self.forward = forward

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        d = p.data
        f = np.float32(-np.pi * d.wavelength * d.detector_distance / d.pixel_size_detector ** 2)
        if self.forward is False:
            f = -f
        nb_pixel = p._cu_psi[0, 0, 0].size
        # Un-normalized FT followed by un-normalized IFT multiplies the data by nb_pixel:
        # compensate for it in the quadratic phase step.
        p = IFT(scale=False) * QuadraticPhase(factor=f, scale=1.0 / nb_pixel) * FT(scale=False) * p
        return p


class Calc2Obs(CUOperatorPtycho):
    """
    Copy the calculated intensities to the observed ones, for the current stack of N views. Can be used
    for simulation. Assumes the current Psi is already in Fourier space.
    """

    def __init__(self):
        """
        No parameters: the default processing unit is used.
        """
        super(Calc2Obs, self).__init__()

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        stack = p._cu_obs_v[p._cu_stack_i]
        nxy = np.int32(p._probe.shape[-2] * p._probe.shape[-1])
        nxystack = np.int32(nxy * pu.cu_stack_size)
        nb_mode = np.int32(p._probe.shape[0] * p._obj.shape[0])
        nb_psi = stack.npsi
        # Only the nb_psi valid frames of the stack are written
        pu.cu_calc2obs(stack.cu_obs[:nb_psi], p._cu_psi, nb_mode, nxystack)
        return p


class ApplyAmplitude(CUOperatorPtycho):
    """
    Apply the magnitude from observed intensities, keep the phase. Applies to a stack of N views.
    """

    def __init__(self, calc_llk=False, update_background=False, scale_in=1, scale_out=1):
        """

        :param calc_llk: if True, the log-likelihood will be calculated for this stack.
        :param update_background: if True, update the background according to
                                  Marchesini et al,  Inverse Problems 29(11), 115009 (2013). This actually
                                  updates arrays from which the final background update can be evaluated.
                                  Note that the temporary arrays (_cu_vd, _cu_vd2, _cu_vz2, _cu_vdz2) must
                                  be created first in the Ptycho object beforehand (e.g. in the calling AP() operator),
                                  where the new background will be calculated once all stack of frames are processed.
        :param scale_in: a scale factor by which the input values should be multiplied, typically because of FFT
        :param scale_out: a scale factor by which the output values should be multiplied, typically because of FFT
        """
        super(ApplyAmplitude, self).__init__()
        self.calc_llk = calc_llk
        self.update_background = update_background
        self.scale_in = np.float32(scale_in)
        self.scale_out = np.float32(scale_out)

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object (in both the background-update and the plain projection cases)
        """
        pu = self.processing_unit
        # TODO: use a single-pass reduction kernel to apply the amplitude and compute the LLK
        if self.calc_llk:
            # Computed before the projection kernel updates _cu_psi
            p = LLK(scale=(self.scale_in != 1)) * p
        nxy = np.int32(p._probe.shape[-2] * p._probe.shape[-1])
        nxystack = np.int32(nxy * pu.cu_stack_size)
        nb_mode = np.int32(p._probe.shape[0] * p._obj.shape[0])
        i = p._cu_stack_i
        nb_psi = np.int32(p._cu_obs_v[i].npsi)
        if self.update_background:
            first_pass = np.int8(i == 0)
            pu.cu_projection_amplitude_update_background(p._cu_obs_v[i].cu_obs[0], p._cu_psi, p._cu_background,
                                                         p._cu_vd, p._cu_vd2, p._cu_vz2, p._cu_vdz2, nb_mode,
                                                         nxy, nxystack, nb_psi, first_pass,
                                                         self.scale_in, self.scale_out)
        else:
            pu.cu_projection_amplitude(p._cu_obs_v[i].cu_obs[0], p._cu_psi, p._cu_background,
                                       nb_mode, nxy, nxystack, nb_psi, self.scale_in, self.scale_out)
        # Return in both branches (previously only the non-background branch returned p)
        return p


class PropagateApplyAmplitude(CUOperatorPtycho):
    """
    Propagate to the detector plane (either in far or near field), perform the magnitude projection, and propagate
    back to the object plane. This applies to a stack of frames.
    """

    def __init__(self, calc_llk=False, update_background=False):
        """

        :param calc_llk: if True, calculate llk while in the detector plane.
        :param update_background: if True, update the background according to
                                  Marchesini et al,  Inverse Problems 29(11), 115009 (2013).
        """
        super(PropagateApplyAmplitude, self).__init__()
        self.calc_llk = calc_llk
        self.update_background = update_background

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        if p.data.near_field:
            backward = PropagateNearField(forward=False)
            projection = ApplyAmplitude(calc_llk=self.calc_llk, update_background=self.update_background)
            forward = PropagateNearField(forward=True)
            return backward * projection * forward * p

        s = 1 / np.sqrt(p._cu_psi[0, 0, 0].size)  # Compensates for FFT scaling
        # NOTE(review): with scale_in=1, ApplyAmplitude evaluates LLK(scale=False) on the un-normalized FFT
        # output, and the output is scaled by s**2. Confirm this matches the projection kernel's
        # normalization convention.
        backward = IFT(scale=False)
        projection = ApplyAmplitude(calc_llk=self.calc_llk, update_background=self.update_background,
                                    scale_in=1, scale_out=s ** 2)
        forward = FT(scale=False)
        return backward * projection * forward * p


class LLK(CUOperatorPtycho):
    """
    Log-likelihood reduction kernel. Can only be used when Psi is in diffraction space.
    This is a reduction operator - it will write llk as an argument in the Ptycho object, and return the object.
    If _cu_stack_i==0, the llk is re-initialized. Otherwise it is added to the current value.

    The LLK can be calculated directly from object and probe using: p = LoopStack(LLK() * FT() * ObjProbe2Psi()) * p
    """

    def __init__(self, scale=False):
        """

        :param scale: if True, will scale the calculated amplitude to calculate the log-likelihood. The amplitudes are
                      left unchanged.
        """
        super(LLK, self).__init__()
        self.scale = scale

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to. Its llk_poisson, llk_gaussian, llk_euclidian
                  and nb_photons_calc attributes are set (first stack) or incremented (other stacks).
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        i = p._cu_stack_i
        nb_mode = np.int32(p._probe.shape[0] * p._obj.shape[0])
        nb_psi = p._cu_obs_v[i].npsi
        nxy = np.int32(p._probe.shape[-2] * p._probe.shape[-1])
        nxystack = np.int32(pu.cu_stack_size * nxy)
        if self.scale:
            # Compensates for FFT scaling. Kept as float32 like the unscaled case.
            s = np.float32(1 / np.sqrt(p._cu_psi[0, 0, 0].size))
        else:
            s = np.float32(1)
        llk = pu.cu_llk(p._cu_obs_v[i].cu_obs[:nb_psi], p._cu_psi, p._cu_background,
                        nb_mode, nxy, nxystack, s).get()
        if p._cu_stack_i == 0:
            p.llk_poisson = llk['a']
            p.llk_gaussian = llk['b']
            p.llk_euclidian = llk['c']
            p.nb_photons_calc = llk['d']
        else:
            p.llk_poisson += llk['a']
            p.llk_gaussian += llk['b']
            p.llk_euclidian += llk['c']
            p.nb_photons_calc += llk['d']
        return p


class Psi2Obj(CUOperatorPtycho):
    """
    Computes updated Obj(r) contributions from Psi and Probe(r-r_j), for a stack of N probe positions.

    Contributions are accumulated over all stacks in GPU arrays stored in the Ptycho object:
    (_cu_obj_new, _cu_obj_norm) with the atomic kernel, or (_cu_obj_newN, _cu_obj_normN) otherwise.
    These arrays are (re)allocated or zeroed when processing the first stack (p._cu_stack_i == 0).
    """

    # If True, use the kernel with atomic operations for the object update (otherwise use
    # one accumulation array per frame of the stack).
    use_atomic = True

    def _zeroed_array(self, p, name, shape, dtype):
        """
        Make sure getattr(p, name) is a zero-filled GPU array of the given shape and dtype. An existing array
        is re-used (and zeroed) if it has the right number of elements, otherwise a new one is allocated.
        """
        size = 1
        for n in shape:
            size *= int(n)  # Python ints: no int32 overflow for large objects
        if has_attr_not_none(p, name) is False or getattr(p, name).size != size:
            setattr(p, name, cua.zeros(shape, dtype=dtype, allocator=self.processing_unit.cu_mem_pool.allocate))
        else:
            getattr(p, name).fill(dtype(0))

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        i = p._cu_stack_i
        ny = np.int32(p._probe.shape[-2])
        nx = np.int32(p._probe.shape[-1])
        nb_probe = np.int32(p._probe.shape[0])
        nb_obj = np.int32(p._obj.shape[0])
        nyo = np.int32(p._obj.shape[-2])
        nxo = np.int32(p._obj.shape[-1])
        npsi = np.int32(p._cu_obs_v[i].npsi)
        if p.data.near_field:
            f = 0
            padding = np.int32(p.data.padding)
        else:
            f = np.float32(-np.pi / (p.data.wavelength * p.data.detector_distance))
            padding = np.int32(0)

        if self.use_atomic:
            if i == 0:
                self._zeroed_array(p, '_cu_obj_new', (nb_obj, nyo, nxo), np.complex64)
                self._zeroed_array(p, '_cu_obj_norm', (nyo, nxo), np.float32)
            pu.cu_psi_to_obj_atomic(p._cu_psi[0, 0, 0], p._cu_obj_new, p._cu_probe, p._cu_obj_norm,
                                    p._cu_obs_v[i].cu_x, p._cu_obs_v[i].cu_y, p.pixel_size_object, f, pu.cu_stack_size,
                                    nx, ny, nxo, nyo, nb_obj, nb_probe, npsi, padding)
        else:
            if i == 0:
                self._zeroed_array(p, '_cu_obj_newN', (nb_obj, pu.cu_stack_size, nyo, nxo), np.complex64)
                self._zeroed_array(p, '_cu_obj_normN', (pu.cu_stack_size, nyo, nxo), np.float32)
            pu.cu_psi_to_objN(p._cu_psi[0, 0, 0], p._cu_obj_newN, p._cu_probe, p._cu_obj_normN,
                              p._cu_obs_v[i].cu_x, p._cu_obs_v[i].cu_y, p.pixel_size_object, f, pu.cu_stack_size,
                              nx, ny, nxo, nyo, nb_obj, nb_probe, npsi, padding)
        return p


class Psi2Probe(CUOperatorPtycho):
    """
    Accumulate the probe update contributions computed from Psi and the object, for the current stack
    of N probe positions.

    On the first stack (p._cu_stack_i == 0) the accumulation arrays p._cu_probe_new and p._cu_probe_norm
    are (re)allocated when missing or of the wrong size, and first_pass is set for the kernel.
    The final probe is obtained afterwards with Psi2ProbeMerge.
    """

    def op(self, p: Ptycho):
        """
        Run the psi->probe accumulation kernel on the current stack of frames.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        stack_idx = p._cu_stack_i
        first_pass = np.int8(stack_idx == 0)

        probe_shape = p._probe.shape
        obj_shape = p._obj.shape
        ny, nx = np.int32(probe_shape[-2]), np.int32(probe_shape[-1])
        nb_probe = np.int32(probe_shape[0])
        nb_obj = np.int32(obj_shape[0])
        nyo, nxo = np.int32(obj_shape[-2]), np.int32(obj_shape[-1])

        # Quadratic phase factor passed to the kernel (0 in near field). Note the negative sign.
        f = 0 if p.data.near_field else np.float32(-np.pi / (p.data.wavelength * p.data.detector_distance))

        if stack_idx == 0:
            # (Re)allocate the accumulation arrays only when missing or mis-sized
            if has_attr_not_none(p, '_cu_probe_new') is False or p._cu_probe_new.size != p._cu_probe.size:
                p._cu_probe_new = cua.empty((nb_probe, ny, nx), dtype=np.complex64, allocator=pu.cu_mem_pool.allocate)
            if has_attr_not_none(p, '_cu_probe_norm') is False or p._cu_probe_norm.size != ny * nx:
                p._cu_probe_norm = cua.empty((ny, nx), dtype=np.float32, allocator=pu.cu_mem_pool.allocate)

        obs = p._cu_obs_v[stack_idx]
        # Only the first view of the stack is passed: the kernel handles all object/probe modes
        # and the full stack of frames.
        pu.cu_psi_to_probe(p._cu_psi[0, 0, 0], p._cu_obj, p._cu_probe_new, p._cu_probe_norm,
                           obs.cu_x, obs.cu_y, p.pixel_size_object, f, first_pass,
                           obs.npsi, pu.cu_stack_size,
                           nx, ny, nxo, nyo, nb_obj, nb_probe)
        return p


class Psi2ObjMerge(CUOperatorPtycho):
    """
    Call this when all stack of probe positions have been processed, and the final update of the object can
    be calculated from the accumulated contributions (see Psi2Obj).

    The temporary accumulation arrays are deliberately NOT freed here, so they can be re-used in the next cycle.
    """

    def __init__(self, inertia=1e-2, smooth_sigma=0):
        """

        :param inertia: object inertia - the updated object retains this relative amount of the previous object.
        :param smooth_sigma: if > 0.1, the previous object array (used for inertia) will be convolved
                             by a gaussian with this sigma. Smaller values disable the smoothing.
        """
        super(Psi2ObjMerge, self).__init__()
        self.inertia = np.float32(inertia)
        self.smooth_sigma = np.float32(smooth_sigma)

    def op(self, p: Ptycho):
        """
        Normalise the accumulated object contributions and update p._cu_obj.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        nb_obj = np.int32(p._obj.shape[0])
        nxo = np.int32(p._obj.shape[-1])
        nyo = np.int32(p._obj.shape[-2])
        nxyo = np.int32(nxo * nyo)

        # Optional separable gaussian smoothing (x then y) of the previous object. The kernel variant
        # (64/32/16 threads per block) is chosen according to the width of the gaussian.
        if self.smooth_sigma > 8:
            pu.gauss_convol_complex_64x(p._cu_obj, self.smooth_sigma, nxo, nyo, nb_obj, block=(64, 1, 1),
                                        grid=(1, int(nyo), int(nb_obj)))
            pu.gauss_convol_complex_64y(p._cu_obj, self.smooth_sigma, nxo, nyo, nb_obj, block=(1, 64, 1),
                                        grid=(int(nxo), 1, int(nb_obj)))
        elif self.smooth_sigma > 4:
            pu.gauss_convol_complex_32x(p._cu_obj, self.smooth_sigma, nxo, nyo, nb_obj, block=(32, 1, 1),
                                        grid=(1, int(nyo), int(nb_obj)))
            pu.gauss_convol_complex_32y(p._cu_obj, self.smooth_sigma, nxo, nyo, nb_obj, block=(1, 32, 1),
                                        grid=(int(nxo), 1, int(nb_obj)))
        elif self.smooth_sigma > 0.1:
            pu.gauss_convol_complex_16x(p._cu_obj, self.smooth_sigma, nxo, nyo, nb_obj, block=(16, 1, 1),
                                        grid=(1, int(nyo), int(nb_obj)))
            pu.gauss_convol_complex_16y(p._cu_obj, self.smooth_sigma, nxo, nyo, nb_obj, block=(1, 16, 1),
                                        grid=(int(nxo), 1, int(nb_obj)))

        if True:  # Use atomic operations for object update
            # Max of the normalisation array, kept on the GPU (no device->host transfer)
            regmax = pu.cu_max_red(p._cu_obj_norm, allocator=pu.cu_mem_pool.allocate)
            if p._cu_obj_zero_phase_mask is None:
                pu.cu_obj_norm(p._cu_obj_norm, p._cu_obj_new, p._cu_obj, regmax, self.inertia, nxyo, nb_obj)
            else:
                # With atomic accumulation there is a single (non-stacked) copy of the arrays:
                # _cu_obj_norm is (nyo, nxo), unlike the stacked (stack_size, nyo, nxo) layout used below.
                # The stack-aware kernel must therefore be given a stack size of 1, otherwise it would index
                # beyond the atomic arrays (assumes the _n kernel uses the same layout as in the
                # non-atomic branch - TODO confirm against the kernel source).
                pu.cu_obj_norm_zero_phase_mask_n(p._cu_obj_norm, p._cu_obj_new, p._cu_obj, p._cu_obj_zero_phase_mask,
                                                 regmax, self.inertia, nxyo, nb_obj, np.int32(1))

        else:
            # Non-atomic version: first sum the per-frame normalisation arrays into _cu_obj_normN[0]
            pu.cu_sum_n_norm(p._cu_obj_normN[0], pu.cu_stack_size, nxyo)

            regmax = pu.cu_max_red(p._cu_obj_normN[0], allocator=pu.cu_mem_pool.allocate)

            if p._cu_obj_zero_phase_mask is None:
                pu.cu_obj_norm_n(p._cu_obj_normN[0], p._cu_obj_newN, p._cu_obj, regmax, self.inertia, nxyo, nb_obj,
                                 pu.cu_stack_size)
            else:
                pu.cu_obj_norm_zero_phase_mask_n(p._cu_obj_normN[0], p._cu_obj_newN, p._cu_obj,
                                                 p._cu_obj_zero_phase_mask, regmax, self.inertia, nxyo, nb_obj,
                                                 pu.cu_stack_size)

        # The temporary arrays (_cu_obj_new/_cu_obj_norm or their N versions) are kept for re-use.
        return p


class Psi2ProbeMerge(CUOperatorPtycho):
    """
    Call this when all stack of probe positions have been processed, and the final update of the probe can
    be calculated from the accumulated contributions (see Psi2Probe).

    The temporary arrays (p._cu_probe_new, p._cu_probe_norm) are deliberately NOT freed, as re-allocating
    them has a significant overhead.
    """

    def __init__(self, inertia=1e-3, smooth_sigma=0):
        """
        :param inertia: a regularisation factor to set the probe inertia, i.e. the updated probe retains this
                        relative amount of the previous probe.
        :param smooth_sigma: if > 0.1, the previous probe array (used for inertia) will be convolved
                             by a gaussian with this sigma. Smaller values disable the smoothing.
        """
        super(Psi2ProbeMerge, self).__init__()
        self.inertia = np.float32(inertia)
        self.smooth_sigma = np.float32(smooth_sigma)

    def op(self, p: Ptycho):
        """
        Normalise the accumulated probe contributions and update p._cu_probe.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        nb_probe = np.int32(p._probe.shape[0])
        ny = np.int32(p._probe.shape[-2])
        nx = np.int32(p._probe.shape[-1])
        nxy = np.int32(nx * ny)

        # Optional separable gaussian smoothing (x then y) of the previous probe. The kernel variant
        # (64/32/16 threads per block) is chosen according to the width of the gaussian.
        if self.smooth_sigma > 8:
            pu.gauss_convol_complex_64x(p._cu_probe, self.smooth_sigma, nx, ny, nb_probe, block=(64, 1, 1),
                                        grid=(1, int(ny), int(nb_probe)))
            pu.gauss_convol_complex_64y(p._cu_probe, self.smooth_sigma, nx, ny, nb_probe, block=(1, 64, 1),
                                        grid=(int(nx), 1, int(nb_probe)))
        elif self.smooth_sigma > 4:
            pu.gauss_convol_complex_32x(p._cu_probe, self.smooth_sigma, nx, ny, nb_probe, block=(32, 1, 1),
                                        grid=(1, int(ny), int(nb_probe)))
            pu.gauss_convol_complex_32y(p._cu_probe, self.smooth_sigma, nx, ny, nb_probe, block=(1, 32, 1),
                                        grid=(int(nx), 1, int(nb_probe)))
        elif self.smooth_sigma > 0.1:
            pu.gauss_convol_complex_16x(p._cu_probe, self.smooth_sigma, nx, ny, nb_probe, block=(16, 1, 1),
                                        grid=(1, int(ny), int(nb_probe)))
            pu.gauss_convol_complex_16y(p._cu_probe, self.smooth_sigma, nx, ny, nb_probe, block=(1, 16, 1),
                                        grid=(int(nx), 1, int(nb_probe)))

        # The max value is kept on the GPU and never get(), avoiding a device->host transfer (~80 us).
        # The dedicated reduction kernel is used rather than gpuarray.max(), which may regenerate its kernel.
        regmax = pu.cu_max_red(p._cu_probe_norm, allocator=pu.cu_mem_pool.allocate)

        pu.cu_obj_norm(p._cu_probe_norm, p._cu_probe_new, p._cu_probe, regmax, self.inertia, nxy, nb_probe)

        # No clean up of p._cu_probe_norm / p._cu_probe_new: re-allocation has a significant overhead.
        return p


class AP(CUOperatorPtycho):
    """
    Perform a complete Alternating Projection cycle:
    - forward all object*probe views to Fourier space and apply the observed amplitude
    - back-project to object space and project onto (probe, object)
    - update background optionally
    """

    def __init__(self, update_object=True, update_probe=False, update_background=False, floating_intensity=False,
                 nb_cycle=1, calc_llk=False, show_obj_probe=False, fig_num=-1, obj_smooth_sigma=0, obj_inertia=0.01,
                 probe_smooth_sigma=0, probe_inertia=0.001):
        """

        :param update_object: update object ?
        :param update_probe: update probe ?
        :param update_background: update background ?
        :param floating_intensity: optimise floating intensity scale factor [TODO for CUDA operators]
        :param nb_cycle: number of cycles to perform. Equivalent to AP(...)**nb_cycle
        :param calc_llk: if True, calculate llk while in Fourier space. If a positive integer is given, llk will be
                         calculated every calc_llk cycle
        :param show_obj_probe: if a positive integer number N, the object & probe will be displayed every N cycle.
                               By default 0 (no plot)
        :param fig_num: the number of the figure to plot the object and probe, as for ShowObjProbe()
        :param obj_smooth_sigma: if > 0, the previous object array (used for inertia) will convoluted by a gaussian
                                 array of this standard deviation.
        :param obj_inertia: the updated object retains this relative amount of the previous object.
        :param probe_smooth_sigma: if > 0, the previous probe array (used for inertia) will convoluted by a gaussian
                                   array of this standard deviation.
        :param probe_inertia: the updated probe retains this relative amount of the previous probe.
        """
        super(AP, self).__init__()
        self.update_object = update_object
        self.update_probe = update_probe
        self.update_background = update_background
        self.floating_intensity = floating_intensity  # TODO
        self.nb_cycle = nb_cycle
        self.calc_llk = calc_llk
        self.show_obj_probe = show_obj_probe
        self.fig_num = fig_num
        self.obj_smooth_sigma = obj_smooth_sigma
        self.obj_inertia = obj_inertia
        self.probe_smooth_sigma = probe_smooth_sigma
        self.probe_inertia = probe_inertia

    def __pow__(self, n):
        """

        :param n: a strictly positive integer
        :return: a new AP operator with the number of cycles multiplied by n
        """
        assert isinstance(n, int) or isinstance(n, np.integer)
        return AP(update_object=self.update_object, update_probe=self.update_probe,
                  update_background=self.update_background, floating_intensity=self.floating_intensity,
                  nb_cycle=self.nb_cycle * n, calc_llk=self.calc_llk, show_obj_probe=self.show_obj_probe,
                  fig_num=self.fig_num, obj_smooth_sigma=self.obj_smooth_sigma, obj_inertia=self.obj_inertia,
                  probe_smooth_sigma=self.probe_smooth_sigma, probe_inertia=self.probe_inertia)

    def op(self, p: Ptycho):
        """
        Run nb_cycle AP cycles.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        t0 = timeit.default_timer()
        ic_dt = 0
        if self.update_background:
            # Work arrays for the background update, with the same shape & type as the background.
            # (This previously referenced the non-existent p._cu_cu_background and always raised AttributeError.)
            p._cu_vd = cua.empty_like(p._cu_background)
            p._cu_vd2 = cua.empty_like(p._cu_background)
            p._cu_vz2 = cua.empty_like(p._cu_background)
            p._cu_vdz2 = cua.empty_like(p._cu_background)

        for ic in range(self.nb_cycle):
            # Compute the LLK every calc_llk cycles, and always on the last cycle
            calc_llk = False
            if self.calc_llk:
                if ic % self.calc_llk == 0 or ic == self.nb_cycle - 1:
                    calc_llk = True
            # TODO: update background while in Fourier space
            ops = PropagateApplyAmplitude(calc_llk=calc_llk,
                                          update_background=self.update_background) * ObjProbe2Psi()
            if self.update_object:
                ops = Psi2Obj() * ops
            if self.update_probe:
                ops = Psi2Probe() * ops

            # Accumulate object/probe contributions over all stacks of frames
            p = LoopStack(ops) * p

            # Final normalisation of the object & probe updates
            if self.update_object:
                p = Psi2ObjMerge(smooth_sigma=self.obj_smooth_sigma, inertia=self.obj_inertia) * p
            if self.update_probe:
                p = Psi2ProbeMerge(smooth_sigma=self.probe_smooth_sigma, inertia=self.probe_inertia) * p
            if self.update_background:
                pu.cu_background_update(p._cu_background, p._cu_vd, p._cu_vd2, p._cu_vz2, p._cu_vdz2,
                                        np.int32(p.data.iobs.shape[0]))

            if calc_llk:
                # Average time/cycle over the last N cycles
                dt = (timeit.default_timer() - t0) / (ic - ic_dt + 1)
                ic_dt = ic + 1
                t0 = timeit.default_timer()

                p.update_history(mode='llk', update_obj=self.update_object, update_probe=self.update_probe,
                                 update_background=self.update_background, update_pos=False, dt=dt, algorithm='AP',
                                 verbose=True)
            else:
                p.history.insert(p.cycle, update_obj=self.update_object, update_probe=self.update_probe,
                                 update_background=self.update_background, update_pos=False, algorithm='AP',
                                 verbose=False)
            if self.show_obj_probe:
                if ic % self.show_obj_probe == 0 or ic == self.nb_cycle - 1:
                    s = algo_string('AP', p, self.update_object, self.update_probe, self.update_background)
                    tit = "%s #%3d, LLKn(p)=%8.3f" % (s, ic, p.llk_poisson / p.nb_obs)
                    p = cpuop.ShowObjProbe(fig_num=self.fig_num, title=tit) * p
            p.cycle += 1
        if self.update_background:
            del p._cu_vd, p._cu_vd2, p._cu_vz2, p._cu_vdz2

        return p


class DM1(CUOperatorPtycho):
    """
    First step of a Difference Map iteration, applied to the current stack of frames.

    Equivalent to operator: 2 * ObjProbe2Psi() - 1
    """

    def op(self, p: Ptycho):
        """
        Apply 2 * ObjProbe2Psi() - 1 to the current Psi stack using a single GPU kernel.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        obs = p._cu_obs_v[p._cu_stack_i]
        nb_probe, ny, nx = (np.int32(p._probe.shape[k]) for k in (0, -2, -1))
        nb_obj, nyo, nxo = (np.int32(p._obj.shape[k]) for k in (0, -2, -1))
        # Quadratic phase factor for far field; 0 in near field
        f = 0 if p.data.near_field else np.float32(np.pi / (p.data.wavelength * p.data.detector_distance))
        pu.cu_2object_probe_psi_dm1(p._cu_psi[0, 0, 0], p._cu_obj, p._cu_probe, obs.cu_x, obs.cu_y,
                                    p.pixel_size_object, f, obs.npsi, pu.cu_stack_size,
                                    nx, ny, nxo, nyo, nb_obj, nb_probe)
        return p


class DM2(CUOperatorPtycho):
    """
    # Psi(n+1) = Psi(n) - P*O + Psi_fourier

    This operator assumes that Psi_fourier is the current Psi (p._cu_psi), and that Psi(n) is in p._cu_psi_copy.

    On output Psi(n+1) is the current Psi, and Psi_fourier has been swapped to p._cu_psi_copy
    """

    def op(self, p: Ptycho):
        """
        Compute Psi(n+1) for the current stack of frames.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        i = p._cu_stack_i
        ny = np.int32(p._probe.shape[-2])
        nx = np.int32(p._probe.shape[-1])
        nb_probe = np.int32(p._probe.shape[0])
        nb_obj = np.int32(p._obj.shape[0])
        nyo = np.int32(p._obj.shape[-2])
        nxo = np.int32(p._obj.shape[-1])
        if p.data.near_field:
            f = 0
        else:
            f = np.float32(np.pi / (p.data.wavelength * p.data.detector_distance))
        # Swap p._cu_psi_copy = Psi(n) with p._cu_psi = Psi_fourier, so that the result Psi(n+1)
        # is written to p._cu_psi while Psi_fourier is read from p._cu_psi_copy
        p._cu_psi_copy, p._cu_psi = p._cu_psi, p._cu_psi_copy
        pu.cu_2object_probe_psi_dm2(p._cu_psi[0, 0, 0], p._cu_psi_copy, p._cu_obj, p._cu_probe,
                                    p._cu_obs_v[i].cu_x, p._cu_obs_v[i].cu_y,
                                    p.pixel_size_object, f,
                                    p._cu_obs_v[i].npsi, pu.cu_stack_size,
                                    nx, ny, nxo, nyo, nb_obj, nb_probe)
        return p


class DM(CUOperatorPtycho):
    """
    Operator to perform a complete Difference Map cycle, updating the Psi views for all stack of frames,
    as well as updating the object and/or probe.
    """

    def __init__(self, update_object=True, update_probe=True, nb_cycle=1, calc_llk=False, show_obj_probe=False,
                 fig_num=-1, obj_smooth_sigma=0, obj_inertia=0.01, probe_smooth_sigma=0, probe_inertia=0.001,
                 center_probe_n=0, center_probe_max_shift=5, loop_obj_probe=1):
        """

        :param update_object: update object ?
        :param update_probe: update probe ?
        :param nb_cycle: number of cycles to perform. Equivalent to DM(...)**nb_cycle
        :param calc_llk: if True, calculate llk while in Fourier space. If a positive integer is given, llk will be
                         calculated every calc_llk cycle
        :param show_obj_probe: if a positive integer number N, the object & probe will be displayed every N cycle.
                               By default 0 (no plot)
        :param fig_num: the number of the figure to plot the object and probe, as for ShowObjProbe()
        :param obj_smooth_sigma: if > 0, the previous object array (used for inertia) will convoluted by a gaussian
                                 array of this standard deviation.
        :param obj_inertia: the updated object retains this relative amount of the previous object.
        :param probe_smooth_sigma: if > 0, the previous probe array (used for inertia) will convoluted by a gaussian
                                   array of this standard deviation.
        :param probe_inertia: the updated probe retains this relative amount of the previous probe.
        :param center_probe_n: test the probe every N cycle for deviation from the center. If deviation is larger
                               than center_probe_max_shift, probe and object are shifted to correct. If 0 (the default),
                               the probe centering is never calculated. Not used in near field.
        :param center_probe_max_shift: maximum deviation from the center (in pixels) to trigger a position correction
        :param loop_obj_probe: when both object and probe are updated, the (object, probe) update is repeated
                               this number of times per cycle. This can give a more stable optimisation, but is slower.
        """
        super(DM, self).__init__()
        self.nb_cycle = nb_cycle
        self.update_object = update_object
        self.update_probe = update_probe
        self.calc_llk = calc_llk
        self.show_obj_probe = show_obj_probe
        self.fig_num = fig_num
        self.obj_smooth_sigma = obj_smooth_sigma
        self.obj_inertia = obj_inertia
        self.probe_smooth_sigma = probe_smooth_sigma
        self.probe_inertia = probe_inertia
        self.center_probe_n = center_probe_n
        self.center_probe_max_shift = center_probe_max_shift
        self.loop_obj_probe = loop_obj_probe

    def __pow__(self, n):
        """

        :param n: a strictly positive integer
        :return: a new DM operator with the number of cycles multiplied by n
        """
        assert isinstance(n, int) or isinstance(n, np.integer)
        return DM(update_object=self.update_object, update_probe=self.update_probe, nb_cycle=self.nb_cycle * n,
                  calc_llk=self.calc_llk, show_obj_probe=self.show_obj_probe, fig_num=self.fig_num,
                  obj_smooth_sigma=self.obj_smooth_sigma, obj_inertia=self.obj_inertia,
                  probe_smooth_sigma=self.probe_smooth_sigma, probe_inertia=self.probe_inertia,
                  center_probe_n=self.center_probe_n, center_probe_max_shift=self.center_probe_max_shift,
                  loop_obj_probe=self.loop_obj_probe)

    def op(self, p: Ptycho):
        """
        Run nb_cycle Difference Map cycles.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        # First loop to get a starting Psi (note that all Psi are multiplied by the quadratic phase factor)
        p = LoopStack(ObjProbe2Psi(), keep_psi=True) * p

        # We could use instead of DM1 and DM2 operators:
        # op_dm1 = 2 * ObjProbe2Psi() - 1
        # op_dm2 = 1 - ObjProbe2Psi() + FourierApplyAmplitude() * op_dm1
        # But this would use 3 copies of the whole Psi stack - too much memory ?
        # TODO: check if memory usage would be that bad, or if it's possible the psi storage only applies
        # TODO: to the current psi array

        t0 = timeit.default_timer()
        ic_dt = 0
        for ic in range(self.nb_cycle):
            # Compute the LLK every calc_llk cycles, and always on the last cycle
            calc_llk = False
            if self.calc_llk:
                if ic % self.calc_llk == 0 or ic == self.nb_cycle - 1:
                    calc_llk = True

            if True:
                # DM Psi update for all stacks, then separate object/probe update loops
                ops = DM2() * PropagateApplyAmplitude() * DM1()
                p = LoopStack(ops, keep_psi=True, copy=True) * p
                if True:
                    # Loop the object and probe update if both are done at the same time. Slow, more stable ?
                    nb_loop_update_obj_probe = 1
                    if self.update_probe and self.update_object:
                        nb_loop_update_obj_probe = self.loop_obj_probe

                    for i in range(nb_loop_update_obj_probe):
                        if self.update_object:
                            p = Psi2ObjMerge(smooth_sigma=self.obj_smooth_sigma,
                                             inertia=self.obj_inertia) * LoopStack(Psi2Obj(), keep_psi=True) * p
                        if self.update_probe:
                            p = Psi2ProbeMerge(smooth_sigma=self.probe_smooth_sigma,
                                               inertia=self.probe_inertia) * LoopStack(Psi2Probe(), keep_psi=True) * p
                else:
                    # Disabled alternative: object & probe contributions accumulated in a single loop.
                    # TODO: updating probe and object at the same time does not work as in AP. Why ?
                    # Probably due to a scaling issue, as Psi is not a direct back-propagation but the result of DM2
                    ops = 1
                    if self.update_object:
                        ops = Psi2Obj() * ops
                    if self.update_probe:
                        ops = Psi2Probe() * ops

                    p = LoopStack(ops, keep_psi=True) * p

                    if self.update_object:
                        p = Psi2ObjMerge(smooth_sigma=self.obj_smooth_sigma, inertia=self.obj_inertia) * p
                    if self.update_probe:
                        p = Psi2ProbeMerge(smooth_sigma=self.probe_smooth_sigma, inertia=self.probe_inertia) * p
            else:
                # Disabled alternative:
                # Update obj and probe immediately after back-propagation, before DM2 ?
                # Does not seem to give very good results
                ops = PropagateApplyAmplitude() * DM1()
                if self.update_object:
                    ops = Psi2Obj() * ops
                if self.update_probe:
                    ops = Psi2Probe() * ops

                p = LoopStack(DM2() * ops, keep_psi=True, copy=True) * p
                if self.update_object:
                    p = Psi2ObjMerge(smooth_sigma=self.obj_smooth_sigma, inertia=self.obj_inertia) * p
                if self.update_probe:
                    p = Psi2ProbeMerge(smooth_sigma=self.probe_smooth_sigma, inertia=self.probe_inertia) * p

            # Optional probe re-centering (far field only), every center_probe_n cycles
            if self.center_probe_n > 0 and p.data.near_field is False:
                if (ic % self.center_probe_n) == 0:
                    p = CenterObjProbe(max_shift=self.center_probe_max_shift) * p
            if calc_llk:
                # Keep a copy of current Psi, as the LLK loop below overwrites p._cu_psi
                cu_psi0 = p._cu_psi.copy()
                # We need to perform a loop for LLK as the DM2 loop is on (2*PO-I), not the current PO estimate
                if p.data.near_field:
                    p = LoopStack(LLK() * PropagateNearField() * ObjProbe2Psi()) * p
                else:
                    p = LoopStack(LLK(scale=False) * FT(scale=False) * ObjProbe2Psi()) * p

                # Average time/cycle over the last N cycles
                dt = (timeit.default_timer() - t0) / (ic - ic_dt + 1)
                ic_dt = ic + 1
                t0 = timeit.default_timer()

                p.update_history(mode='llk', update_obj=self.update_object, update_probe=self.update_probe,
                                 update_background=False, update_pos=False, dt=dt, algorithm='DM',
                                 verbose=True)
                # TODO: find a better place to do this rescaling, only useful to avoid obj/probe divergence
                p = ScaleObjProbe() * p
                # Restore correct Psi
                p._cu_psi = cu_psi0
            else:
                p.history.insert(p.cycle, update_obj=self.update_object, update_probe=self.update_probe,
                                 update_background=False, update_pos=False, algorithm='DM',
                                 verbose=False)

            if self.show_obj_probe:
                if ic % self.show_obj_probe == 0 or ic == self.nb_cycle - 1:
                    s = algo_string('DM', p, self.update_object, self.update_probe)
                    tit = "%s #%3d, LLKn(p)=%8.3f" % (s, ic, p.llk_poisson / p.nb_obs)
                    p = cpuop.ShowObjProbe(fig_num=self.fig_num, title=tit) * p
            p.cycle += 1

        # Free some memory
        # NOTE(review): presumably this drops the per-stack Psi arrays kept by LoopStack(keep_psi=True) - confirm
        p._cu_psi_v = {}
        gc.collect()
        return p


class _Grad(CUOperatorPtycho):
    """
    Operator to compute the object and/or probe and/or background gradient corresponding to the current stack.

    Sequence: Obj*Probe -> detector plane (FT or near field propagation) -> optional LLK -> Poisson gradient
    in Fourier space -> back-propagation -> accumulation into p._cu_obj_grad and/or p._cu_probe_grad.
    The background gradient is not implemented yet.
    """

    def __init__(self, update_object=True, update_probe=False, update_background=False, calc_llk=False):
        """
        :param update_object: compute gradient for the object ?
        :param update_probe: compute gradient for the probe ?
        :param update_background: compute gradient for the background ? (currently ignored - TODO)
        :param calc_llk: calculate llk while in Fourier space
        """
        super(_Grad, self).__init__()
        self.update_object = update_object
        self.update_probe = update_probe
        self.update_background = update_background
        self.calc_llk = calc_llk

    def op(self, p: Ptycho):
        """
        Compute the gradient contributions for the current stack of frames (p._cu_stack_i).

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        i = p._cu_stack_i
        ny = np.int32(p._probe.shape[-2])
        nx = np.int32(p._probe.shape[-1])
        nb_probe = np.int32(p._probe.shape[0])
        nb_obj = np.int32(p._obj.shape[0])
        nb_mode = np.int32(nb_obj * nb_probe)
        nyo = np.int32(p._obj.shape[-2])
        nxo = np.int32(p._obj.shape[-1])
        # The probe gradient kernel is told whether this is the first stack
        first_pass = np.int8(i == 0)
        nb_psi = p._cu_obs_v[i].npsi
        nxy = np.int32(ny * nx)
        nxystack = np.int32(pu.cu_stack_size * nxy)
        # Hann filter flag passed to the gradient kernel: enabled in far field only
        hann_filter = np.int8(1)
        if p.data.near_field:
            f = 0
            hann_filter = np.int8(0)
        else:
            f = np.float32(np.pi / (p.data.wavelength * p.data.detector_distance))

        # Obj * Probe
        p = ObjProbe2Psi() * p

        s = np.float32(1)  # FFT scale, if needed
        # To detector plane
        if p.data.near_field:
            p = PropagateNearField() * p
        else:
            p = FT(scale=False) * p
            s = 1 / np.sqrt(p._cu_psi[0, 0, 0].size)  # Compensates for FFT scaling

        if self.calc_llk:
            p = LLK(scale=False) * p

        # Calculate Psi.conj() * (1-Iobs/I_calc) [for Poisson Gradient]
        # TODO: different noise models
        pu.cu_grad_poisson_fourier(p._cu_obs_v[i].cu_obs[:nb_psi], p._cu_psi, p._cu_background,
                                   nb_mode, nx, ny, nxy, nxystack, hann_filter, 1, s ** 2)
        # Back to sample plane
        if p.data.near_field:
            p = PropagateNearField(forward=False) * p
        else:
            p = IFT(scale=False) * p

        if self.update_object:
            if True:  # TODO: this is slower, but yields better results ?
                # One kernel call per frame, using the host-side x/y positions of each frame
                for ii in range(p._cu_obs_v[i].npsi):
                    pu.cu_psi_to_obj_grad(p._cu_psi[0, 0, ii], p._cu_obj_grad, p._cu_probe,
                                          p._cu_obs_v[i].x[ii], p._cu_obs_v[i].y[ii],
                                          p.pixel_size_object, f, pu.cu_stack_size,
                                          nx, ny, nxo, nyo, nb_obj, nb_probe)
            else:
                # Use atomic operations to avoid looping over frames !
                pu.cu_psi_to_obj_grad_atomic(p._cu_psi[0, 0, 0], p._cu_obj_grad, p._cu_probe, p._cu_obs_v[i].cu_x,
                                             p._cu_obs_v[i].cu_y, p.pixel_size_object, f, pu.cu_stack_size, nx, ny,
                                             nxo, nyo, nb_obj, nb_probe, nb_psi)
        if self.update_probe:
            pu.cu_psi_to_probe_grad(p._cu_psi[0, 0, 0], p._cu_probe_grad, p._cu_obj,
                                    p._cu_obs_v[i].cu_x, p._cu_obs_v[i].cu_y,
                                    p.pixel_size_object, f, first_pass,
                                    nb_psi, pu.cu_stack_size,
                                    nx, ny, nxo, nyo, nb_obj, nb_probe)
        if self.update_background:
            # TODO
            pass
        return p


class Grad(CUOperatorPtycho):
    """
    Operator computing the gradient of the object and/or probe and/or background, stored in the Ptycho object.

    The GPU gradient arrays (p._cu_obj_grad, p._cu_probe_grad, ...) must already exist; they are normally
    allocated by the calling ML operator.
    """

    def __init__(self, update_object=True, update_probe=False, update_background=False,
                 reg_fac_obj=0, reg_fac_probe=0, calc_llk=False):
        """

        :param update_object: compute gradient for the object ?
        :param update_probe: compute gradient for the probe ?
        :param update_background: compute gradient for the background ?
        :param reg_fac_obj: regularisation factor for the object (multiplied by p.reg_fac_scale_obj).
                            None or 0 disables the object regularisation gradient.
        :param reg_fac_probe: regularisation factor for the probe (multiplied by p.reg_fac_scale_probe).
                              None or 0 disables the probe regularisation gradient.
        :param calc_llk: calculate llk while in Fourier space
        """
        super(Grad, self).__init__()
        self.update_object = update_object
        self.update_probe = update_probe
        self.update_background = update_background
        self.calc_llk = calc_llk
        self.reg_fac_obj = reg_fac_obj
        self.reg_fac_probe = reg_fac_probe

    def op(self, p: Ptycho):
        """
        Compute the gradient(s) over all stacks, then add the regularisation contributions.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        if self.update_object:
            # The per-stack operator accumulates into the object gradient, so start from zero
            p._cu_obj_grad.fill(np.complex64(0))

        per_stack = _Grad(update_object=self.update_object, update_probe=self.update_probe,
                          update_background=self.update_background, calc_llk=self.calc_llk)
        p = LoopStack(per_stack) * p

        # Effective regularisation factors (0 when disabled)
        obj_reg = 0 if self.reg_fac_obj is None else np.float32(p.reg_fac_scale_obj * self.reg_fac_obj)
        probe_reg = 0 if self.reg_fac_probe is None else np.float32(p.reg_fac_scale_probe * self.reg_fac_probe)

        ny, nx = np.int32(p._probe.shape[-2]), np.int32(p._probe.shape[-1])
        nyo, nxo = np.int32(p._obj.shape[-2]), np.int32(p._obj.shape[-1])
        unit = self.processing_unit

        if self.update_object and obj_reg > 0:
            # Regularisation contribution to the object gradient
            unit.cu_reg_grad(p._cu_obj_grad, p._cu_obj, obj_reg, nxo, nyo)

        if self.update_probe and probe_reg > 0:
            # Regularisation contribution to the probe gradient
            unit.cu_reg_grad(p._cu_probe_grad, p._cu_probe, probe_reg, nx, ny)

        return p


class _CGGamma(CUOperatorPtycho):
    """
    Operator to compute the conjugate gradient gamma contribution to the current stack.

    The numerator and denominator contributions are accumulated into p._cu_cg_gamma_n and p._cu_cg_gamma_d,
    which must be initialised by the caller (normally the ML operator) before looping over the stacks.
    """

    def __init__(self, update_background=False):
        """
        :param update_background: if updating the background ?
        """
        super(_CGGamma, self).__init__()
        self.update_background = update_background
        # TODO: fix this scale dynamically ? Used to avoid overflows
        self.gamma_scale = np.float32(1e-5)

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        :raises OperatorException: if the gamma contribution for this stack is NaN or infinite
        """
        pu = self.processing_unit
        i = p._cu_stack_i
        ny = np.int32(p._probe.shape[-2])
        nx = np.int32(p._probe.shape[-1])
        nb_probe = np.int32(p._probe.shape[0])
        nb_obj = np.int32(p._obj.shape[0])
        nb_mode = np.int32(nb_obj * nb_probe)
        nyo = np.int32(p._obj.shape[-2])
        nxo = np.int32(p._obj.shape[-1])
        nb_psi = p._cu_obs_v[i].npsi
        nxy = np.int32(ny * nx)
        nxystack = np.int32(pu.cu_stack_size * nxy)
        # FT below is unscaled; scale_fft is passed to the gamma reduction kernel instead
        scale_fft = np.float32(1)
        if p.data.near_field:
            f = 0
        else:
            f = np.float32(np.pi / (p.data.wavelength * p.data.detector_distance))

        # Compute P*O, P*dO, dP*O and dP*dO, each propagated to the detector
        for cupsi, cuobj, cuprobe in zip([p._cu_PO, p._cu_PdO, p._cu_dPO, p._cu_dPdO],
                                         [p._cu_obj, p._cu_obj_dir, p._cu_obj, p._cu_obj_dir],
                                         [p._cu_probe, p._cu_probe, p._cu_probe_dir, p._cu_probe_dir]):

            pu.cu_object_probe_mult(cupsi[0, 0, 0], cuobj, cuprobe,
                                    p._cu_obs_v[i].cu_x, p._cu_obs_v[i].cu_y,
                                    p.pixel_size_object, f,
                                    nb_psi, pu.cu_stack_size,
                                    nx, ny, nxo, nyo, nb_obj, nb_probe)
            # The propagation operators act on p._cu_psi, so temporarily install cupsi there...
            cupsi, p._cu_psi = p._cu_psi, cupsi
            if p.data.near_field:
                p = PropagateNearField(forward=True) * p
            else:
                # Don't use scale here, but use scale_fft in cg_poisson_gamma_red kernel
                p = FT(scale=False) * p
            # ...and restore the original psi array so p._cu_psi does not stay aliased to p._cu_dPdO
            cupsi, p._cu_psi = p._cu_psi, cupsi

        # TODO: take into account background
        tmp = pu._cu_cg_poisson_gamma_red(p._cu_obs_v[i].cu_obs[:nb_psi], p._cu_PO, p._cu_PdO,
                                          p._cu_dPO, p._cu_dPdO, self.gamma_scale, scale_fft,
                                          nxy, nxystack, nb_mode).get()
        if np.isnan(tmp.real + tmp.imag) or np.isinf(tmp.real + tmp.imag):
            nP = pu.cu_norm_complex_n(p._cu_probe, 2).get()
            nO = pu.cu_norm_complex_n(p._cu_obj, 2).get()
            ndP = pu.cu_norm_complex_n(p._cu_probe_dir, 2).get()
            ndO = pu.cu_norm_complex_n(p._cu_obj_dir, 2).get()
            nPO = pu.cu_norm_complex_n(p._cu_PO, 2).get()
            ndPO = pu.cu_norm_complex_n(p._cu_dPO, 2).get()
            nPdO = pu.cu_norm_complex_n(p._cu_PdO, 2).get()
            ndPdO = pu.cu_norm_complex_n(p._cu_dPdO, 2).get()
            # Values are given in the same order as the labels
            print('_CGGamma norms: P %e O %e dP %e dO %e PO %e, PdO %e, dPO %e, dPdO %e' % (
                nP, nO, ndP, ndO, nPO, nPdO, ndPO, ndPdO))
            print('_CGGamma (stack #%d, NaN Gamma:)' % i, tmp.real, tmp.imag)
            raise OperatorException("NaN")
        p._cu_cg_gamma_d += tmp.imag
        p._cu_cg_gamma_n += tmp.real
        if False:
            # Disabled 4th order LLK(gamma) approximation (see ML operator)
            tmp = pu._cu_cg_poisson_gamma4_red(p._cu_obs_v[i].cu_obs[:nb_psi], p._cu_PO,
                                               p._cu_PdO,
                                               p._cu_dPO, p._cu_dPdO, self.gamma_scale,
                                               nxy, nxystack, nb_mode).get()
            p._cu_cg_gamma4 += np.array((tmp['d'], tmp['c'], tmp['b'], tmp['a'], 0))

        if self.update_background:
            # TODO: use a different kernel if there is a background gradient
            pass
        return p


class ML(CUOperatorPtycho):
    """
    Operator to perform a maximum-likelihood conjugate-gradient minimization.
    """

    def __init__(self, nb_cycle=1, update_object=True, update_probe=False, update_background=False,
                 floating_intensity=False, reg_fac_obj=0, reg_fac_probe=0, calc_llk=False, show_obj_probe=False,
                 fig_num=-1):
        """

        :param nb_cycle: number of conjugate-gradient cycles to perform
        :param update_object: update object ?
        :param update_probe: update probe ?
        :param update_background: update background ?
        :param floating_intensity: optimise floating intensity scale factor [TODO for CUDA operators]
        :param reg_fac_obj: use this regularization factor for the object (if 0, no regularization)
        :param reg_fac_probe: use this regularization factor for the probe (if 0, no regularization)
        :param calc_llk: if True, calculate llk while in Fourier space. If a positive integer is given, llk will be
                         calculated every calc_llk cycle
        :param show_obj_probe: if a positive integer number N, the object & probe will be displayed every N cycle.
                               By default 0 (no plot)
        :param fig_num: the number of the figure to plot the object and probe, as for ShowObjProbe()
        """
        super(ML, self).__init__()
        self.nb_cycle = nb_cycle
        self.update_object = update_object
        self.update_probe = update_probe
        self.update_background = update_background
        # Not used yet by the CUDA operators, but kept so that it is propagated by __pow__
        self.floating_intensity = floating_intensity
        self.reg_fac_obj = reg_fac_obj
        self.reg_fac_probe = reg_fac_probe
        self.calc_llk = calc_llk
        self.show_obj_probe = show_obj_probe
        self.fig_num = fig_num

    def __pow__(self, n):
        """

        :param n: a strictly positive integer
        :return: a new ML operator with the number of cycles multiplied by n
        """
        assert isinstance(n, (int, np.integer))
        return ML(nb_cycle=self.nb_cycle * n, update_object=self.update_object, update_probe=self.update_probe,
                  update_background=self.update_background, floating_intensity=self.floating_intensity,
                  reg_fac_obj=self.reg_fac_obj, reg_fac_probe=self.reg_fac_probe, calc_llk=self.calc_llk,
                  show_obj_probe=self.show_obj_probe, fig_num=self.fig_num)

    def op(self, p: Ptycho):
        """
        Perform nb_cycle conjugate-gradient cycles: gradient, Polak-Ribiere search direction, line minimization
        along that direction, then update of the object and/or probe and/or background.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        :raises OperatorException: if a NaN or infinite value appears in the beta or gamma coefficients
        """
        pu = self.processing_unit
        # First perform an AP cycle to make sure object and probe are properly scaled with respect to iobs
        p = AP(update_object=self.update_object, update_probe=self.update_probe,
               update_background=self.update_background) * p
        ny = np.int32(p._probe.shape[-2])
        nx = np.int32(p._probe.shape[-1])
        nb_probe = np.int32(p._probe.shape[0])
        nb_obj = np.int32(p._obj.shape[0])
        nyo = np.int32(p._obj.shape[-2])
        nxo = np.int32(p._obj.shape[-1])
        # Scale used by _CGGamma for its gamma contributions; the regularisation terms must use the same scale
        # TODO: remove need to create a _CGGamma() operator to get the scale factor
        gamma_scale = _CGGamma().gamma_scale

        # Create the necessary GPU arrays for ML
        p._cu_PO = cua.empty_like(p._cu_psi)
        p._cu_PdO = cua.empty_like(p._cu_psi)
        p._cu_dPO = cua.empty_like(p._cu_psi)
        p._cu_dPdO = cua.empty_like(p._cu_psi)
        p._cu_obj_dir = cua.zeros((nb_obj, nyo, nxo), np.complex64, allocator=pu.cu_mem_pool.allocate)
        p._cu_probe_dir = cua.zeros((nb_probe, ny, nx), np.complex64, allocator=pu.cu_mem_pool.allocate)
        if self.update_object:
            p._cu_obj_grad = cua.empty_like(p._cu_obj)
            p._cu_obj_grad_last = cua.empty_like(p._cu_obj)
        if self.update_probe:
            p._cu_probe_grad = cua.empty_like(p._cu_probe)
            p._cu_probe_grad_last = cua.empty_like(p._cu_probe)
        if self.update_background:
            p._cu_background_grad = cua.zeros((ny, nx), np.float32, allocator=pu.cu_mem_pool.allocate)
            p._cu_background_grad_last = cua.zeros((ny, nx), np.float32, allocator=pu.cu_mem_pool.allocate)
            p._cu_background_dir = cua.zeros((ny, nx), np.float32, allocator=pu.cu_mem_pool.allocate)

        t0 = timeit.default_timer()
        ic_dt = 0
        for ic in range(self.nb_cycle):
            calc_llk = False
            if self.calc_llk:
                if ic % self.calc_llk == 0 or ic == self.nb_cycle - 1:
                    calc_llk = True

            # Swap gradient arrays - for CG, we need the previous gradient
            if self.update_object:
                p._cu_obj_grad, p._cu_obj_grad_last = p._cu_obj_grad_last, p._cu_obj_grad
            if self.update_probe:
                p._cu_probe_grad, p._cu_probe_grad_last = p._cu_probe_grad_last, p._cu_probe_grad
            if self.update_background:
                p._cu_background_grad, p._cu_background_grad_last = p._cu_background_grad_last, p._cu_background_grad

            # 1) Compute the gradients
            p = Grad(update_object=self.update_object, update_probe=self.update_probe,
                     update_background=self.update_background,
                     reg_fac_obj=self.reg_fac_obj, reg_fac_probe=self.reg_fac_probe, calc_llk=calc_llk) * p

            # 2) Search direction
            beta = np.float32(0)
            if ic == 0:
                # first cycle: the search direction is the gradient
                if self.update_object:
                    cu_drv.memcpy_dtod(src=p._cu_obj_grad.gpudata, dest=p._cu_obj_dir.gpudata,
                                       size=p._cu_obj_dir.nbytes)
                if self.update_probe:
                    cu_drv.memcpy_dtod(src=p._cu_probe_grad.gpudata, dest=p._cu_probe_dir.gpudata,
                                       size=p._cu_probe_dir.nbytes)
                if self.update_background:
                    cu_drv.memcpy_dtod(src=p._cu_background_grad.gpudata, dest=p._cu_background_dir.gpudata,
                                       size=p._cu_background_dir.nbytes)
            else:
                beta_d, beta_n = 0, 0
                # Polak-Ribière CG coefficient
                cg_pr = pu.cu_cg_polak_ribiere_red
                if self.update_object:
                    tmp = cg_pr(p._cu_obj_grad, p._cu_obj_grad_last).get()
                    beta_n += tmp.real
                    beta_d += tmp.imag
                if self.update_probe:
                    tmp = cg_pr(p._cu_probe_grad, p._cu_probe_grad_last).get()
                    beta_n += tmp.real
                    beta_d += tmp.imag
                if self.update_background:
                    tmp = cg_pr(p._cu_background_grad, p._cu_background_grad_last).get()
                    beta_n += tmp.real
                    beta_d += tmp.imag
                # Reset direction if beta<0 => beta=0
                beta = np.float32(max(0, beta_n / max(1e-20, beta_d)))
                if np.isnan(beta_n + beta_d) or np.isinf(beta_n + beta_d):
                    raise OperatorException("NaN")
                if self.update_object:
                    pu.cu_linear_comb_fcfc(beta, p._cu_obj_dir, np.float32(-1), p._cu_obj_grad)
                if self.update_probe:
                    pu.cu_linear_comb_fcfc(beta, p._cu_probe_dir, np.float32(-1), p._cu_probe_grad)
                if self.update_background:
                    pu.cu_linear_comb_4f(beta, p._cu_background_dir, np.float32(-1), p._cu_background_grad)

            # 3) Line minimization
            p._cu_cg_gamma_d, p._cu_cg_gamma_n = 0, 0
            if False:
                # We could use a 4th order LLK(gamma) approximation, but it does not seem to improve
                p._cu_cg_gamma4 = np.zeros(5, dtype=np.float32)

            p = LoopStack(_CGGamma(update_background=self.update_background)) * p

            gamma_sum = p._cu_cg_gamma_d + p._cu_cg_gamma_n
            if np.isnan(gamma_sum) or np.isinf(gamma_sum):
                raise OperatorException("NaN")

            if self.update_object and self.reg_fac_obj is not None and self.reg_fac_obj != 0:
                reg_fac_obj = np.float32(p.reg_fac_scale_obj * self.reg_fac_obj)
                tmp = pu._cu_cg_gamma_reg_red(p._cu_obj, p._cu_obj_dir, nxo, nyo).get()
                p._cu_cg_gamma_d += tmp.imag * reg_fac_obj * gamma_scale
                p._cu_cg_gamma_n += tmp.real * reg_fac_obj * gamma_scale

            if self.update_probe and self.reg_fac_probe is not None and self.reg_fac_probe != 0:
                reg_fac_probe = np.float32(p.reg_fac_scale_probe * self.reg_fac_probe)
                tmp = pu._cu_cg_gamma_reg_red(p._cu_probe, p._cu_probe_dir, nx, ny).get()
                p._cu_cg_gamma_d += tmp.imag * reg_fac_probe * gamma_scale
                p._cu_cg_gamma_n += tmp.real * reg_fac_probe * gamma_scale

            if np.isnan(p._cu_cg_gamma_d + p._cu_cg_gamma_n) or np.isinf(p._cu_cg_gamma_d + p._cu_cg_gamma_n):
                print("Gamma = NaN ! :", p._cu_cg_gamma_d, p._cu_cg_gamma_n)
            gamma = np.float32(p._cu_cg_gamma_n / p._cu_cg_gamma_d)
            if False:
                # It seems the 2nd order gamma approximation is good enough.
                gr = np.roots(p._cu_cg_gamma4)
                print("CG Gamma4", p._cu_cg_gamma4, "\n", gr, np.polyval(p._cu_cg_gamma4, gr))
                print("CG Gamma2=", gamma, "=", p._cu_cg_gamma_n, "/", p._cu_cg_gamma_d)

            # 4) Object and/or probe and/or background update
            if self.update_object:
                pu.cu_linear_comb_fcfc(np.float32(1), p._cu_obj, gamma, p._cu_obj_dir)

            if self.update_probe:
                pu.cu_linear_comb_fcfc(np.float32(1), p._cu_probe, gamma, p._cu_probe_dir)

            if self.update_background:
                pu.cu_linear_comb_4f(np.float32(1), p._cu_background, gamma, p._cu_background_dir)

            if calc_llk:
                # Average time/cycle over the last N cycles
                dt = (timeit.default_timer() - t0) / (ic - ic_dt + 1)
                ic_dt = ic + 1
                t0 = timeit.default_timer()

                p.update_history(mode='llk', update_obj=self.update_object, update_probe=self.update_probe,
                                 update_background=self.update_background, update_pos=False, dt=dt, algorithm='ML',
                                 verbose=True)
            else:
                p.history.insert(p.cycle, update_obj=self.update_object, update_probe=self.update_probe,
                                 update_background=self.update_background, update_pos=False, algorithm='ML',
                                 verbose=False)

            if self.show_obj_probe:
                if ic % self.show_obj_probe == 0 or ic == self.nb_cycle - 1:
                    s = algo_string('ML', p, self.update_object, self.update_probe, self.update_background)
                    tit = "%s #%3d, LLKn(p)=%8.3f" % (s, ic, p.llk_poisson / p.nb_obs)
                    p = cpuop.ShowObjProbe(fig_num=self.fig_num, title=tit) * p
            p.cycle += 1

        # Clean up
        del p._cu_PO, p._cu_PdO, p._cu_dPO, p._cu_dPdO, p._cu_obj_dir, p._cu_probe_dir
        if self.update_object:
            del p._cu_obj_grad, p._cu_obj_grad_last
        if self.update_probe:
            del p._cu_probe_grad, p._cu_probe_grad_last
        if self.update_background:
            del p._cu_background_grad, p._cu_background_grad_last, p._cu_background_dir

        gc.collect()

        return p


class ScaleObjProbe(CUOperatorPtycho):
    """
    Operator to scale the object and probe so that they have the same magnitude, and that the product of object*probe
    matches the observed intensity (i.e. sum(abs(obj*probe)**2) = sum(iobs))
    """

    def __init__(self, verbose=False):
        """

        :param verbose: print deviation if verbose=True
        """
        super(ScaleObjProbe, self).__init__()
        self.verbose = verbose

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        pu = self.processing_unit
        # Compute the best scale factor, accumulating over all stacks the two terms returned (as real and
        # imaginary parts) by the cu_scale_intensity reduction kernel
        snum, sden = 0, 0
        nxystack = np.int32(p._probe.shape[-1] * p._probe.shape[-2] * pu.cu_stack_size)
        nb_mode = np.int32(p._probe.shape[0] * p._obj.shape[0])
        for i in range(p._cu_stack_nb):
            p = ObjProbe2Psi() * SelectStack(i) * p
            if p.data.near_field:
                p = PropagateNearField(forward=True) * p
            else:
                # Unscaled FT (FFT scaling is not compensated here)
                p = FT(scale=False) * p
            nb_psi = p._cu_obs_v[i].npsi
            r = pu.cu_scale_intensity(p._cu_obs_v[i].cu_obs[:nb_psi], p._cu_psi, nxystack, nb_mode).get()
            snum += r.real
            sden += r.imag
        s = np.sqrt(snum / sden)

        # Split the scale factor s between probe and object so that both end up with the same L1 norm
        obj_norm = pu.cu_norm_complex_n(p._cu_obj, np.int32(1)).get()
        probe_norm = pu.cu_norm_complex_n(p._cu_probe, np.int32(1)).get()
        probe_fac = np.sqrt(obj_norm / probe_norm * s)
        obj_fac = np.sqrt(probe_norm / obj_norm * s)
        pu.cu_scale(p._cu_probe, np.float32(probe_fac))
        pu.cu_scale(p._cu_obj, np.float32(obj_fac))
        if self.verbose:
            print("ScaleObjProbe:", probe_norm, obj_norm, s, probe_fac, obj_fac)
        return p


class CenterObjProbe(CUOperatorPtycho):
    """
    Operator to check the center of mass of the probe and shift both object and probe if necessary.
    """

    def __init__(self, max_shift=5, power=2, verbose=False):
        """

        :param max_shift: the maximum shift of the probe with respect to the center of the array, in pixels.
                          The probe and object are only translated if the shift is larger than this value.
        :param power: the center of mass is calculated on the amplitude of the array elevated at this power.
        :param verbose: print deviation if verbose=True
        """
        super(CenterObjProbe, self).__init__()
        self.max_shift = np.int32(max_shift)
        self.power = power
        self.verbose = verbose

    def op(self, p: Ptycho):
        """
        Compute the probe center of mass and, if it deviates from the array center by more than max_shift,
        circularly shift both probe and object (by the same integer amount in x and y) to re-center the probe.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        unit = self.processing_unit
        probe_shape, obj_shape = p._probe.shape, p._obj.shape
        nz, ny, nx = np.int32(probe_shape[0]), np.int32(probe_shape[1]), np.int32(probe_shape[2])
        nzo, nyo, nxo = np.int32(obj_shape[0]), np.int32(obj_shape[1]), np.int32(obj_shape[2])

        # Weighted sums returned in fields a (x), b (y), c (z) and the total weight in d
        com = unit.cu_center_mass_complex(p._cu_probe, nx, ny, nz, self.power).get()
        weight = com['d']
        dx = com['a'] / weight - nx / 2
        dy = com['b'] / weight - ny / 2
        if self.verbose:
            print("CenterObjProbe(): center of mass deviation: dx=%6.2f   dy=%6.2f" % (dx, dy))

        # Written as 'not >' (rather than '<=') so that a NaN deviation also leaves the arrays untouched
        if not np.sqrt(dx ** 2 + dy ** 2) > self.max_shift:
            return p

        shift_x = np.int32(round(-dx))
        shift_y = np.int32(round(-dy))
        shifted_obj = cua.empty_like(p._cu_obj)
        shifted_probe = cua.empty_like(p._cu_probe)
        unit.cu_circular_shift(p._cu_probe, shifted_probe, shift_x, shift_y, np.int32(0), nx, ny, nz)
        p._cu_probe = shifted_probe
        unit.cu_circular_shift(p._cu_obj, shifted_obj, shift_x, shift_y, np.int32(0), nxo, nyo, nzo)
        p._cu_obj = shifted_obj
        return p


class SelectStack(CUOperatorPtycho):
    """
    Operator to select a stack of observed frames to work on. Note that once this operation has been applied,
    the new Psi value may be undefined (empty array), if no previous array existed.
    """

    def __init__(self, stack_i, keep_psi=False):
        """
        Select a new stack of frames, swapping data to store the last calculated psi array in the
        corresponding entry of the ptycho object's _cu_psi_v dictionary.

        What happens is:
        * keep_psi=False: only the stack index in p is changed (p._cu_stack_i=stack_i)

        * keep_psi=True: the previous psi is stored in p._cu_psi_v[p._cu_stack_i], the new psi is swapped
                                   with p._cu_psi_v[stack_i] if it exists, otherwise initialized as an empty array.

        :param stack_i: the stack index.
        :param keep_psi: if True, when switching between stacks, store and restore psi in p._cu_psi_v.
        """
        super(SelectStack, self).__init__()
        self.stack_i = stack_i
        self.keep_psi = keep_psi

    def op(self, p: Ptycho):
        """

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        if self.stack_i == p._cu_stack_i:
            if self.keep_psi and self.stack_i in p._cu_psi_v:
                # This can happen if we use LoopStack(keep_psi=False) between LoopStack(keep_psi=True).
                # Remove the stored psi from the dictionary and make it the current psi.
                p._cu_psi = p._cu_psi_v.pop(self.stack_i)
            return p

        if self.keep_psi:
            # Store previous Psi. This can be dangerous when starting a loop as the state of Psi may be incorrect,
            # e.g. in detector or sample space when the desired operations work in a different space...
            p._cu_psi_v[p._cu_stack_i] = p._cu_psi
            if self.stack_i in p._cu_psi_v:
                p._cu_psi = p._cu_psi_v.pop(self.stack_i)
            else:
                p._cu_psi = cua.empty_like(p._cu_psi_v[p._cu_stack_i])

        p._cu_stack_i = self.stack_i
        return p


class PurgeStacks(CUOperatorPtycho):
    """
    Operator deleting the psi stacks stored in a Ptycho object's _cu_psi_v dictionary.

    Each main operator relying on LoopStack() should call this once it has finished processing, so that no other
    operator picks up the stored stacks, and so that their memory is released.
    """

    def op(self, p: Ptycho):
        """
        Wait for the processing unit to finish, then drop all stored psi stacks.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        unit = self.processing_unit
        # Execution is asynchronous: wait for pending work before releasing the stored arrays
        unit.finish()
        p._cu_psi_v = dict()
        return p


class LoopStack(CUOperatorPtycho):
    """
    Operator to apply a given operator sequentially to the complete stack of frames of a ptycho object.

    Make sure that the current selected stack is in a correct state (i.e. in sample or detector space,...) before
    starting such a loop with keep_psi=True.
    """

    def __init__(self, op, keep_psi=False, copy=False):
        """

        :param op: the operator to apply, which can be a multiplication of operators
        :param keep_psi: if True, when switching between stacks, store psi in p._cu_psi_v.
        :param copy: make a copy of the original p._cu_psi swapped in as p._cu_psi_copy, and
                     delete it after applying the operations. This is useful for operations requiring the previous
                     value.
        """
        super(LoopStack, self).__init__()
        self.stack_op = op
        self.keep_psi = keep_psi
        self.copy = copy

    def op(self, p: Ptycho):
        """
        Apply the stored operator to every stack. With a single stack, no stack selection (and no psi
        storage in p._cu_psi_v) is performed.

        :param p: the Ptycho object this operator applies to
        :return: the updated Ptycho object
        """
        single_stack = p._cu_stack_nb == 1

        if self.copy:
            # Buffer receiving a copy of psi before the operator is applied
            p._cu_psi_copy = cua.empty_like(p._cu_psi)

        if single_stack:
            if self.copy:
                cu_drv.memcpy_dtod(src=p._cu_psi.gpudata, dest=p._cu_psi_copy.gpudata, size=p._cu_psi.nbytes)
            p = self.stack_op * p
        else:
            for stack_index in range(p._cu_stack_nb):
                p = SelectStack(stack_index, keep_psi=self.keep_psi) * p
                if self.copy:
                    # The planned operations rely on keeping a copy of the previous Psi state...
                    cu_drv.memcpy_dtod(src=p._cu_psi.gpudata, dest=p._cu_psi_copy.gpudata, size=p._cu_psi.nbytes)
                p = self.stack_op * p

        if self.copy and has_attr_not_none(p, '_cu_psi_copy'):
            # Finished using the psi copy: free its GPU allocation and remove the attribute
            p._cu_psi_copy.gpudata.free()
            del p._cu_psi_copy

        if self.keep_psi and not single_stack:
            # Copy last stack to p._cu_psi_v
            stored = cua.empty_like(p._cu_psi)
            p._cu_psi_v[p._cu_stack_i] = stored
            cu_drv.memcpy_dtod(src=p._cu_psi.gpudata, dest=stored.gpudata, size=p._cu_psi.nbytes)
        return p
