# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

from __future__ import division

import time
import timeit
import gc

import gpyfft
import numpy as np
import pyopencl as cl
import pyopencl.array as cla
import pyopencl.reduction as clred
from pyopencl.elementwise import ElementwiseKernel as CL_ElK
from pyopencl.reduction import ReductionKernel as CL_RedK
from numpy import pi
from numpy.fft import fftshift
from scipy.signal import medfilt2d
from scipy.spatial import ConvexHull
from scipy.special import gammaln

from pynx.processing_unit.opencl_device import cl_device_fft_speed
from pynx.processing_unit.kernel_source import get_kernel_source as getks
from pynx.processing_unit.opencl_worker_thread import CLWorkerThread
from pynx.ptycho.shape import get_view_coord
from pynx.utils.history import History
from pynx.utils.plot_utils import show_obj_probe

if cl.VERSION == (2013, 2):
    # pyopencl 2013.2 bug workaround: make the low-level `_enqueue_fill_buffer`
    # function available under the `enqueue_fill_buffer` name in pyopencl._cl.
    cl._cl.enqueue_fill_buffer = cl._cl._enqueue_fill_buffer


class CLPtycho2DWorker(CLWorkerThread):
    """
    2D Ptychographic class doing most of the computational work, but which may be working on part of the reconstructed object.
    All the computation requiring looping on observed data is performed in this object, but not the final update of
    object and probe if several worker threads are used.
    Instances of this class are created and called by Ptycho2D objects.
    """
    def __init__(self, dev, verbose, iobs, positions, probe, obj, nframes_total, mask=None, background=None, pixel_size_object=None, lambdaz=None,
                 is_first_worker=False):
        """
        Create the worker. No OpenCL kernel or buffer is created here - see init_cl().

        :param dev: pyopencl.Device this worker will use
        :param verbose: if True, print messages
        :param iobs: observed intensity
        :param positions: positions of the probe for all the frames as a tuple of 1d vectors (posy, posx)
        :param probe: the starting estimate of the probe, as a 2D (or 3D with modes) numpy.complex64 array.
                      A copy is stored.
        :param obj: the starting estimate of the object, as a 2D (or 3D with modes) numpy.complex64 array.
                    A copy is stored.
        :param nframes_total: total number of frames (used for normalization)
        :param mask: the mask for the detector pixels (2D numpy, values different from zero are masked). Can be None.
        :param background: the background array (2D numpy data)
        :param pixel_size_object: object pixel size (meters)
        :param lambdaz: The wavelength*distance factor (SI units)
        :param is_first_worker: True if this is the first worker (performs more tasks)
        """
        super(CLPtycho2DWorker, self).__init__(dev, verbose)

        # Input data. No dtype conversion is done here: arrays must already have the proper type.
        self.iobs = iobs
        self.probe, self.obj = probe.copy(), obj.copy()
        self.dy, self.dx = positions[0], positions[1]
        self.mask, self.background = mask, background
        self.pixel_size_object, self.lambdaz = pixel_size_object, lambdaz
        self._probe_mask = None

        # Job parameters, filled in before a job is submitted
        self.update_object = self.update_probe = self.update_background = None
        self.update_positions = False
        self.calc_llk = None
        self.job_f = None  # function called by work()
        self.work_alone = False
        self.reg_fac = self.reg_fac_obj = self.reg_fac_probe = None

        # Total number of frames, used to normalize the background update
        self.nframes_total = np.int32(nframes_total)

        # The first worker performs the final merge of the updates, which saves some memory transfers
        self.is_first_worker = is_first_worker

        # Object/probe regularization (damping) factor, see Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
        # TODO: supply it as a parameter
        self.reg_obj_probe = np.float32(1e-2)

    def init_cl(self):
        """
        Perform the complete OpenCL initialization of this worker, in order: compile the kernels,
        allocate the buffers, upload the observed data stacks, then create the FFT plans while
        measuring and printing the FFT speed.

        Returns:
            Nothing
        """
        # TODO: check if each step of initialization is needed (already done or up-to-date ?)
        for init_step in (self._init_cl_kernels, self._init_cl_buffers, self._init_cl_vobs):
            init_step()
        self._init_cl_fft(timing=True)

    def check_init_cl(self):
        """
        Check OpenCL initialization, and perform it if needed:

        - if the OpenCL object array or observed data stacks do not exist yet, a full initialization is
          performed (kernels, buffers, observed data stacks, FFT plans - without FFT timing);
        - if the object or probe shape (or number of modes) changed, kernels, buffers and FFT plans are
          re-created, as well as the observed data stacks (their view coordinates depend on the object shape);
        - otherwise, if the number of observed frames changed, only the observed data stacks are re-created.

        :return: 
        """
        nobj = len(self.obj) if self.obj.ndim == 3 else 1
        nprobe = len(self.probe) if self.probe.ndim == 3 else 1

        if getattr(self, '_cl_obj', None) is None or getattr(self, '_cl_obs_v', None) is None:
            # Not initialized yet
            self._init_cl_kernels()
            self._init_cl_buffers()
            self._init_cl_vobs()
            self._init_cl_fft(timing=False)
            return

        if self.obj.shape != self._cl_obj.shape or self.probe.shape != self._cl_probe.shape \
                or self._nobj != nobj or self._nprobe != nprobe:
            self._init_cl_kernels()
            self._init_cl_buffers()  # this also refreshes self._nobs from self.iobs
            self._init_cl_fft(timing=False)
            # View coordinates of the frames depend on the object size, so the stacks must be rebuilt
            self._init_cl_vobs()
            return

        # Number of frames currently stored in OpenCL stacks: each entry is [vobs, vcxy, first_index, nb_frames],
        # so the last stack gives the total. (self._cl_obs_v holds stacks of frames, not individual frames.)
        if len(self._cl_obs_v) > 0:
            last = self._cl_obs_v[-1]
            nb_loaded = int(last[2]) + int(last[3])
        else:
            nb_loaded = 0
        if len(self.iobs) != nb_loaded:
            # Observed frames were added (or removed): update the frame count before re-creating the stacks
            self._nobs = len(self.iobs)
            self._init_cl_vobs()

    def _init_cl_kernels(self):
        """
        Compile the OpenCL programs and create all the kernels (program, element-wise and reduction kernels)
        used by this worker.

        This also sets self._nobj and self._nprobe (number of object and probe modes: 1 for a 2D array, else the
        size of the first axis), self._nmode (their product, as np.int32) and self._cl_workgroup_size (16, or the
        device maximum work group size if smaller).
        Frame size, object size, stack size and number of modes are passed as compile-time defines (options_mode),
        so the kernels need to be rebuilt when any of these change (as done in check_init_cl).

        :return: 
        """
        # Number of object and probe modes
        if len(self.obj.shape)==2:
            self._nobj = 1
        else:
            self._nobj = self.obj.shape[0]
        if len(self.probe.shape)==2:
            self._nprobe = 1
        else:
            self._nprobe = self.probe.shape[0]

        self._nmode = np.int32(self._nobj * self._nprobe)

        # Work group size, also used as the number of frames per stack (nz below, and self._nz in _init_cl_buffers)
        self._cl_workgroup_size = 16
        if self.cl_ctx.devices[0].max_work_group_size < self._cl_workgroup_size:
            self._cl_workgroup_size = self.cl_ctx.devices[0].max_work_group_size
            print("WARNING: OpenCL device workgroup size is smaller than 16 !!!")

        # Workaround OpenCL intel Iris Pro wrong calculation of lgamma()...
        # (passed to the kernels as the IRISPROBUG define)
        workaround = 0
        if self.cl_ctx.devices[0].name.find('Iris Pro') >= 0:
            workaround = 1
        nz = self._cl_workgroup_size
        ny, nx = self.iobs.shape[-2:]
        nyo, nxo = self.obj.shape[-2:]
        # kernel_params = {"block_size": self._cl_workgroup_size}
        # Generic compilation options
        options = "-cl-mad-enable -cl-fast-relaxed-math -DIRISPROBUG=%d "%(workaround)  # -cl-unsafe-math-optimizations
        # Array sizes and number of modes, as compile-time defines
        options_mode = " -DBLOCKSIZE=%d -DNBOBJ=%d -DNBPROBE=%d -DNBMODE=%d -DNXY=%d -DNXYZ=%d -DNX=%d -DNY=%d -DNZ=%d -DNXO=%d -DNYO=%d -DNXYO=%d"%(
                        self._cl_workgroup_size, self._nobj, self._nprobe, self._nmode, ny*nx, ny*nx*nz, nx, ny, nz, nxo, nyo, nyo * nxo)

        # Compile all OpenCL programs & keep the kernels

        # These are not used
        #cl_apply_amplitude_prg = cl.Program(self.cl_ctx, getks('opencl/apply_amplitude.cl'), ).build(options=options)
        #self._cl_apply_amplitude = cl_apply_amplitude_prg.ApplyAmplitude

        # Maximum likelihood (Poisson) kernels
        cl_ml_poisson = cl.Program(self.cl_ctx, getks('ptycho/old/opencl/ptycho_ml_poisson.cl'), ).build(options=options+options_mode)
        self._cl_ml_poisson_psi_corr = cl_ml_poisson.PsiCorr
        self._cl_ml_poisson_psi_corr_mask = cl_ml_poisson.PsiCorrMask
        self._cl_ml_poisson_psi_corr_background = cl_ml_poisson.PsiCorrBackground
        self._cl_ml_poisson_psi_corr_mask_background = cl_ml_poisson.PsiCorrMaskBackground
        self._cl_ml_poisson_psi_corr_background_gradient = cl_ml_poisson.PsiCorrBackgroundGradient
        self._cl_ml_poisson_psi_corr_background_gradient_mask = cl_ml_poisson.PsiCorrBackgroundGradientMask
        self._cl_ml_poisson_obj_grad = cl_ml_poisson.ObjGrad
        self._cl_ml_poisson_obj_grad_quad_phase = cl_ml_poisson.ObjGradQuadPhase
        self._cl_ml_poisson_probe_grad = cl_ml_poisson.ProbeGrad
        self._cl_ml_poisson_probe_grad_quad_phase = cl_ml_poisson.ProbeGradQuadPhase
        self._cl_ml_poisson_sum_obj_grad = cl_ml_poisson.SumObjGradN
        self._cl_ml_poisson_sum_obj_grad_zero = cl_ml_poisson.SumObjGradNZero
        self._cl_ml_poisson_cg_linear_complex = cl_ml_poisson.CG_linear_complex
        self._cl_ml_poisson_cg_linear_float = cl_ml_poisson.CG_linear_float
        self._cl_ml_poisson_reg_grad = cl_ml_poisson.RegGrad

        # Projection kernels (object*probe, amplitude projection, object/probe update and normalization)
        cl_ptycho_proj = cl.Program(self.cl_ctx, getks('ptycho/old/opencl/ptycho_projection.cl'), ).build(options=options+options_mode)
        self._cl_ptycho_proj_object_probe_mult = cl_ptycho_proj.ObjectProbeMult
        self._cl_ptycho_proj_object_probe_mult_quad_phase = cl_ptycho_proj.ObjectProbeMultQuadPhase
        self._cl_ptycho_proj_apply_amplitude = cl_ptycho_proj.ApplyAmplitude
        self._cl_ptycho_proj_apply_amplitude_mask = cl_ptycho_proj.ApplyAmplitudeMask
        self._cl_ptycho_proj_apply_amplitude_background = cl_ptycho_proj.ApplyAmplitudeBackground
        self._cl_ptycho_proj_apply_amplitude_mask_background = cl_ptycho_proj.ApplyAmplitudeMaskBackground
        self._cl_ptycho_proj_update_obj = cl_ptycho_proj.UpdateObj
        self._cl_ptycho_proj_update_obj_quad_phase = cl_ptycho_proj.UpdateObjQuadPhase
        self._cl_ptycho_proj_update_probe = cl_ptycho_proj.UpdateProbe
        self._cl_ptycho_proj_update_probe_quad_phase = cl_ptycho_proj.UpdateProbeQuadPhase
        self._cl_ptycho_proj_sum_n = cl_ptycho_proj.SumN
        self._cl_ptycho_proj_sum_n_zero = cl_ptycho_proj.SumNZero
        self._cl_ptycho_proj_obj_norm = cl_ptycho_proj.ObjNorm
        self._cl_ptycho_proj_probe_norm_mask = cl_ptycho_proj.ProbeNormMask
        self._cl_ptycho_proj_probe_norm = cl_ptycho_proj.ProbeNorm

        # Difference map kernels
        cl_ptycho_dm = cl.Program(self.cl_ctx, getks('ptycho/old/opencl/ptycho_dm.cl'), ).build(options=options+options_mode)
        self._cl_dm_2po_psi = cl_ptycho_dm.ObjectProbeMult_Psi
        self._cl_dm_2po_psi_quad_phase = cl_ptycho_dm.ObjectProbeMult_PsiQuadPhase
        self._cl_dm_update_psi = cl_ptycho_dm.DM_UpdatePsi
        self._cl_dm_update_psi_quad_phase = cl_ptycho_dm.DM_UpdatePsiQuadPhase

        # Background update kernels
        cl_ptycho_backgound = cl.Program(self.cl_ctx, getks('ptycho/old/opencl/ptycho_background.cl'), ).build(options=options+options_mode)
        self._cl_ptycho_background_loop = cl_ptycho_backgound.BackgroundLoop
        self._cl_ptycho_background_update = cl_ptycho_backgound.BackgroundUpdate
        self._cl_ptycho_background_update_mask = cl_ptycho_backgound.BackgroundUpdateMask

        # Reduction kernels
        # TODO: most reduction kernels have input arrays with different sizes, which may become unsupported  or not working in some future version
        # TODO: one clean way to proceed would be to supply the same number of elements for all input arrays using a slice (e.g. psi[0]),
        # TODO: and the called opencl function will take care of looping through the extra data (extra frames, extra modes)
        # TODO: side benefit: this would allow to avoid extra transfers of mask and background data
        # TODO: WARNING - this assumes that pyopencl does not do anything smart which would make unavailable the parts of the data out of the slice !
        # Conjugate gradient step (gamma) reductions for the ML Poisson algorithm, with/without mask and background.
        # NOTE(review): the background variants below declare `__global float2 *background`, while the LLK reductions
        # further down declare `__global float *background` and arrays_to_cl uploads the background as float32 - TODO confirm.
        self._cl_ml_poisson_cg_gamma_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)", reduce_expr="a+b",
            map_expr="Gamma(obs, PO, PdO, dPO, dPdO, i, scale, 0.0f, 0.0f, npsi)",
            preamble=getks('ptycho/old/opencl/ptycho_ml_poisson_cg_gamma_red.cl'), options=options+options_mode,
            arguments="__global float *obs, __global float2 *PO, __global float2 *PdO, __global float2 *dPO, __global float2 *dPdO, const float scale, const int npsi")

        self._cl_ml_poisson_cg_gamma_mask_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)", reduce_expr="a+b",
            map_expr="GammaMask(obs, PO, PdO, dPO, dPdO, mask, i, scale, npsi)",
            preamble=getks('ptycho/old/opencl/ptycho_ml_poisson_cg_gamma_red.cl'), options=options+options_mode,
            arguments="__global float *obs, __global float2 *PO, __global float2 *PdO, __global float2 *dPO, __global float2 *dPdO, __global char *mask, const float scale, const int npsi")

        self._cl_ml_poisson_cg_gamma_background_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)",
            reduce_expr="a+b", map_expr="GammaBackground(obs, PO, PdO, dPO, dPdO, background, i, scale, npsi)",
            preamble=getks('ptycho/old/opencl/ptycho_ml_poisson_cg_gamma_red.cl'), options=options+options_mode,
            arguments="__global float *obs, __global float2 *PO, __global float2 *PdO, __global float2 *dPO, __global float2 *dPdO, __global float2 *background, const float scale, const int npsi")

        self._cl_ml_poisson_cg_gamma_background_mask_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)",
            reduce_expr="a+b", map_expr="GammaBackgroundMask(obs, PO, PdO, dPO, dPdO, background, mask, i, scale, npsi)",
            preamble=getks('ptycho/old/opencl/ptycho_ml_poisson_cg_gamma_red.cl'), options=options+options_mode,
            arguments="__global float *obs, __global float2 *PO, __global float2 *PdO, __global float2 *dPO, __global float2 *dPdO, __global float2 *background, __global char *mask, const float scale, const int npsi")

        self._cl_ml_poisson_cg_gamma_background_grad_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)",
            reduce_expr="a+b", map_expr="GammaBackgroundGrad(obs, PO, PdO, dPO, dPdO, background, dbackground, i, scale, npsi)",
            preamble=getks('ptycho/old/opencl/ptycho_ml_poisson_cg_gamma_red.cl'), options=options+options_mode,
            arguments="__global float *obs, __global float2 *PO, __global float2 *PdO, __global float2 *dPO, __global float2 *dPdO, __global float2 *background, __global float2 *dbackground, const float scale, const int npsi")

        self._cl_ml_poisson_cg_gamma_background_grad_mask_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)",
            reduce_expr="a+b", map_expr="GammaBackgroundGradMask(obs, PO, PdO, dPO, dPdO, background, dbackground, mask, i, scale, npsi)",
            preamble=getks('ptycho/old/opencl/ptycho_ml_poisson_cg_gamma_red.cl'), options=options+options_mode,
            arguments="__global float *obs, __global float2 *PO, __global float2 *PdO, __global float2 *dPO, __global float2 *dPdO, __global float2 *background, __global float2 *dbackground, __global char *mask, const float scale, const int npsi")

        # Regularization log-likelihood and CG step reductions (built without the size defines)
        self._cl_llk_reg_red = clred.ReductionKernel(self.cl_ctx, np.float32, neutral="0", reduce_expr="a+b",
                                                     map_expr="LLKReg(v, i, nx, ny)",
                                                     preamble=getks('opencl/llk_reg_red.cl'),
                                                     arguments="__global float *v, const int nx, const int ny")

        self._cl_cg_gamma_reg_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)", reduce_expr="a+b",
                                                          map_expr="GammaReg(v, dv, i, nx, ny)",
                                                          preamble=getks('opencl/cg_gamma_reg_red.cl'),
                                                          arguments="__global float2 *v, __global float2 *dv, const int nx, const int ny")

        # Conjugate gradient beta coefficient reductions (Polak-Ribiere for complex and float arrays, Fletcher-Reeves)
        self._cl_cg_polak_ribiere_complex_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)", reduce_expr="a+b",
                                                                      map_expr="PolakRibiereComplex(grad[i], lastgrad[i])",
                                                                      preamble=getks('opencl/cg_polak_ribiere_red.cl'),
                                                                      arguments="__global float2 *grad, __global float2 *lastgrad")

        self._cl_cg_polak_ribiere_float_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)", reduce_expr="a+b",
                                                                    map_expr="PolakRibiereFloat(grad[i], lastgrad[i])",
                                                                    preamble=getks('opencl/cg_polak_ribiere_red.cl'),
                                                                    arguments="__global float *grad, __global float *lastgrad")

        self._cl_cg_fletcher_reeves_red = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2, neutral="(float2)(0,0)", reduce_expr="a+b",
                                                                map_expr="FletcherReeves(grad[i], lastgrad[i])",
                                                                preamble=getks('opencl/cg_fletcher_reeves_red.cl'),
                                                                arguments="__global float2 *grad, __global float2 *lastgrad")
        # Log-likelihood reductions (Poisson and Gaussian statistics), with/without mask and background
        # TODO: for all llk reductions with mask or background, use nx*ny size to avoid transferring NZ times background and mask arrays.
        self._cl_llk_poisson_red = clred.ReductionKernel(self.cl_ctx, np.float32, neutral="0", reduce_expr="a+b",
                                                         map_expr="LLKPoisson(obs, psi, 0, i)",
                                                         preamble=getks('opencl/llk_poisson_red.cl'), options=options+options_mode,
                                                         arguments="__global float *obs, __global float2 *psi")

        self._cl_llk_poisson_mask_red = clred.ReductionKernel(self.cl_ctx, np.float32, neutral="0", reduce_expr="a+b",
                                                              map_expr="LLKPoissonMask(obs, psi, mask, i)",
                                                              preamble=getks('opencl/llk_poisson_red.cl'), options=options+options_mode,
                                                              arguments="__global float *obs, __global float2 *psi, __global char *mask")

        self._cl_llk_poisson_background_red = clred.ReductionKernel(self.cl_ctx, np.float32, neutral="0", reduce_expr="a+b",
                                                                    map_expr="LLKPoissonBackground(obs, psi, background, i)",
                                                                    preamble=getks('opencl/llk_poisson_red.cl'), options=options+options_mode,
                                                                    arguments="__global float *obs, __global float2 *psi, __global float *background")

        self._cl_llk_poisson_mask_background_red = clred.ReductionKernel(self.cl_ctx, np.float32, neutral="0", reduce_expr="a+b",
                                                                         map_expr="LLKPoissonMaskBackground(obs, psi, mask, background, i)",
                                                                         preamble=getks('opencl/llk_poisson_red.cl'), options=options+options_mode,
                                                                         arguments="__global float *obs, __global float2 *psi, __global char *mask, __global float *background")

        self._cl_llk_gaussian_red = clred.ReductionKernel(self.cl_ctx, np.float32, neutral="0", reduce_expr="a+b",
                                                          map_expr="Chi2(obs, psi, i)",
                                                          preamble=getks('opencl/llk_gaussian_red.cl'), options=options+options_mode,
                                                          arguments="__global float *obs, __global float2 *psi")

        self._cl_llk_gaussian_mask_red = clred.ReductionKernel(self.cl_ctx, np.float32, neutral="0", reduce_expr="a+b",
                                                               map_expr="Chi2Mask(obs, psi, mask, i)",
                                                               preamble=getks('opencl/llk_gaussian_red.cl'), options=options+options_mode,
                                                               arguments="__global float *obs, __global float2 *psi, __global char *mask")

        # In-place multiplication of a complex (float2) array by a real scale factor
        self.cl_scale = CL_ElK(self.cl_ctx, name='cl_scale',
                               operation="d[i] = (float2)(d[i].x * scale, d[i].y * scale )",
                               options=options, arguments="__global float2 *d, const float scale")
        # Disabled (never compiled): arg-max reduction and cross power spectrum kernels for image registration
        if False:
            self._cl_arg_max = clred.ReductionKernel(self.cl_ctx, cl.array.vec.float2,
                                                     neutral="(float2)(0,0)", reduce_expr="ArgMax(a,b)",
                                                     map_expr="(float2)(length(d[i]), i)",
                                                     preamble=getks('opencl/argmax_red.cl'),
                                                     arguments="__global float *d")
            # Elementwise kernels
            self._cl_cross_power_spectrum = CL_ElK(self.cl_ctx, name='cl_cross_power_spectrum',
                                                   operation="CrossPowerSpectrum(i, psi_ref, psi)",
                                                   preamble=getks('ptycho/opencl/image_registration_elw.cl'),
                                                   options=options,
                                                   arguments="__global float2 *psi_ref, __global float2 *psi")


    def _init_cl_buffers(self):
        """
        Initialize OpenCL buffers for object, probe, and background (if any), using arrays_to_cl(), as well as
        all the work arrays: the Psi stack, new object/probe and their normalization arrays, background update
        terms, ML buffers, object/probe/background gradients and conjugate gradient search directions.

        Also sets self._nz (number of frames per stack), self._nobs, self._ny, self._nx (from self.iobs.shape)
        and self._nyo, self._nxo (from the object shape).

        :return: 
        """
        self.arrays_to_cl()

        # Psi = Obj x Probe 2D data used for each sample position, either in real or Fourier space
        self._nz = self._cl_workgroup_size  # Stack arrays for better efficiency (notably FFT)
        self._nobs, self._ny, self._nx = self.iobs.shape
        self._nyo, self._nxo = self.obj.shape[-2:]
        self._cl_psi = cl.array.zeros(self.cl_queue, (self._nz * self._nmode, self._ny, self._nx), np.complex64)

        # We need these if several threads are used
        self.obj_new_norm = np.empty((self._nyo, self._nxo), dtype=np.float32)
        self.obj_new = np.empty(self.obj.shape, dtype=np.complex64)
        self.probe_new_norm = np.empty((self._ny, self._nx), dtype=np.float32)
        self.probe_new = np.empty(self.probe.shape, dtype=np.complex64)
        self.background_d   = np.empty((self._ny, self._nx), dtype=np.float32)
        self.background_d2  = np.empty((self._ny, self._nx), dtype=np.float32)
        self.background_dz2 = np.empty((self._ny, self._nx), dtype=np.float32)
        self.background_z2  = np.empty((self._ny, self._nx), dtype=np.float32)

        # Temp arrays for ptycho background update
        self._cl_background_d   = cl.array.empty(self.cl_queue, (self._ny, self._nx), np.float32)
        self._cl_background_d2  = cl.array.empty(self.cl_queue, (self._ny, self._nx), np.float32)
        self._cl_background_dz2 = cl.array.empty(self.cl_queue, (self._ny, self._nx), np.float32)
        self._cl_background_z2  = cl.array.empty(self.cl_queue, (self._ny, self._nx), np.float32)

        # Newly computed object and probe - we try to keep only one obj*probe (in real or Fourier space) in memory at a time,
        # so we need separate arrays for new object and probe
        self._cl_obj_new = cl.array.empty(self.cl_queue, self._cl_obj.shape, np.complex64)
        # _cl_obj_newN must be filled with 0 originally, to avoid undefined values outside of the scanned area
        self._cl_obj_newN = cl.array.zeros(self.cl_queue, (self._nobj * self._nz, self._nyo, self._nxo), np.complex64)
        self._cl_probe_new = cl.array.empty(self.cl_queue, (self._nprobe, self._ny, self._nx), np.complex64)
        # Normalization array for object and probe update
        self._cl_obj_norm = cl.array.empty(self.cl_queue, (self._nyo, self._nxo), np.float32)
        # _cl_obj_normN must be filled with 0 originally, to avoid undefined values outside of the scanned area
        self._cl_obj_normN = cl.array.zeros(self.cl_queue, (self._nz, self._nyo, self._nxo), np.float32)
        self._cl_probe_norm = cl.array.empty(self.cl_queue, (self._ny, self._nx), np.float32)
        # Probe mask: can be used to avoid updating probe outside scan area.
        # to_device() is synchronous by default; the `async` keyword must not be used (reserved word in python>=3.7)
        if self._probe_mask is not None:
            self._cl_probe_mask = cl.array.to_device(self.cl_queue, self._probe_mask.astype(np.float32))

        # Buffers for ML
        self._cl_PO = cl.array.empty(self.cl_queue, self._cl_psi.shape, np.complex64)
        self._cl_PdO = cl.array.empty(self.cl_queue, self._cl_psi.shape, np.complex64)
        self._cl_dPO = cl.array.empty(self.cl_queue, self._cl_psi.shape, np.complex64)
        self._cl_dPdO = cl.array.empty(self.cl_queue, self._cl_psi.shape, np.complex64)

        # object, probe and background gradient and CG search direction
        self._cl_obj_grad = cl.array.zeros(self.cl_queue, self._cl_obj.shape, np.complex64)
        self._cl_obj_grad_last = cl.array.zeros(self.cl_queue, self._cl_obj.shape, np.complex64)
        self._cl_obj_gradN = cl.array.zeros(self.cl_queue, (self._nz * self._nmode, self._nyo, self._nxo), np.complex64)
        self._cl_obj_dir = cl.array.zeros(self.cl_queue, self._cl_obj.shape, np.complex64)

        self._cl_probe_grad = cl.array.zeros(self.cl_queue, self._cl_probe.shape, np.complex64)
        self._cl_probe_grad_last = cl.array.zeros(self.cl_queue, self._cl_probe.shape, np.complex64)
        self._cl_probe_dir = cl.array.zeros(self.cl_queue, self._cl_probe.shape, np.complex64)

        self._cl_background_grad = cl.array.zeros(self.cl_queue, (self._ny, self._nx), np.float32)
        self._cl_background_grad_last = cl.array.zeros(self.cl_queue, (self._ny, self._nx), np.float32)
        self._cl_background_dir = cl.array.zeros(self.cl_queue, (self._ny, self._nx), np.float32)

        # Buffer for image registration (only uses the first mode)
        self._cl_psi_registration = cl.array.empty(self.cl_queue, (self._nz, self._ny, self._nx), np.complex64)

    def _init_cl_vobs(self):
        """
        Initialize observed intensity and scan positions in OpenCL space.

        Frames are grouped in stacks of self._nz frames. For each stack, a list [cl_vobs, cl_vcxy, i, n] is
        appended to self._cl_obs_v, with:

        - cl_vobs: OpenCL float32 array of shape (nz, ny, nx) with the fftshift-ed observed intensities.
          The last stack is padded with zero-valued frames if the number of frames is not a multiple of nz.
        - cl_vcxy: OpenCL int32 array of shape (2*nz,) with the rounded view coordinates returned by
          get_view_coord(): the nz x values first, then the nz y values. Padding frames reuse the
          coordinates of the first frame of the stack.
        - i: index of the first frame of the stack
        - n: number of actual (non-padding) frames in the stack, as np.int32

        Uses self._nobs, self._nz, self._ny, self._nx, self._nyo, self._nxo (set by _init_cl_buffers).

        :return: 
        """
        nz = self._nz
        shape_obj = (self._nyo, self._nxo)
        shape_frame = (self._ny, self._nx)
        self._cl_obs_v = []
        for i in range(0, self._nobs, nz):
            # Number of real frames in this stack - this must use the stack size nz, which can be < 16
            # if the device maximum work group size is smaller.
            nb = min(nz, self._nobs - i)
            vcxy = np.zeros((nz * 2), dtype=np.int32)
            vobs = np.zeros((nz, self._ny, self._nx), dtype=np.float32)  # padding frames stay at zero
            if nb < nz:
                print("Number of frames is not a multiple of %d, adding %d null frames" % (nz, nz - nb))
            for j in range(nb):
                ij = i + j
                cy, cx = get_view_coord(shape_obj, shape_frame, self.dy[ij], self.dx[ij])
                vcxy[j] = np.int32(round(cx))
                vcxy[j + nz] = np.int32(round(cy))
                vobs[j] = fftshift(self.iobs[ij])
            # Padding frames use the coordinates of the first frame of the stack
            vcxy[nb:nz] = vcxy[0]
            vcxy[nz + nb:] = vcxy[nz]
            # to_device() is synchronous by default; the `async` keyword must not be used (reserved word in python>=3.7)
            cl_vcxy = cl.array.to_device(self.cl_queue, vcxy)
            cl_vobs = cl.array.to_device(self.cl_queue, vobs)
            self._cl_obs_v.append([cl_vobs, cl_vcxy, i, np.int32(nb)])

    def _init_cl_fft(self, timing=False):
        """
        Initialize the gpyfft plans: one for the full stack of frames and modes (self._cl_psi), and one for
        a single mode, used for registration (self._cl_psi_registration).

        :param timing: if True, measure the FFT performance (forward + backward transforms of the whole
                       self._cl_psi stack, averaged over 200 runs) and print a message.
        :return: 
        """
        # TODO: check which axis order is best for the FFT.
        self._gpyfft_plan = gpyfft.FFT(self.cl_ctx, self.cl_queue, self._cl_psi, None, axes=(-1, -2))
        if timing:
            nbobs, ny, nx = self.iobs.shape
            nz = self._cl_workgroup_size  # Stack arrays for better efficiency (notably FFT)
            nb = 200

            def fft_cycle():
                # One forward and one backward FFT of the whole stack (nz * nmode frames), waiting for completion
                for ev in self._gpyfft_plan.enqueue(forward=True):
                    ev.wait()
                for ev in self._gpyfft_plan.enqueue(forward=False):
                    ev.wait()

            fft_cycle()  # warm-up
            t00 = timeit.default_timer()
            for _ in range(nb):
                fft_cycle()
            dt = (timeit.default_timer() - t00) / nb
            # 2 FFTs (forward + backward), each 5 * N * log2(N) flop with N = nx * ny, for nz * nmode frames.
            # (this is identical to the previous square-frame formula when nx == ny)
            flop = 2 * 5 * nx * ny * np.log2(nx * ny) * nz * self._nmode
            # dt covers nz frames with all their modes, so a full cycle over the nbobs frames takes dt * nbobs / nz
            print("Pure FFT performance[nz=%dx%d, ny=%d, nx=%d]: dt=%8.6fs (2fft => %8.6fs/cycle for %dx%d frames), %7.2f Gflop/s"
                  % (nz, self._nmode, ny, nx, dt, dt * nbobs / nz, nbobs, self._nmode, flop / dt / 1e9))
        # Special plan for a single mode, used for registration
        self._gpyfft_plan_1mode = gpyfft.FFT(self.cl_ctx, self.cl_queue, self._cl_psi_registration, None, axes=(-1, -2))

    def arrays_to_cl(self):
        """
        Transfer data (obj, probe, and if not None mask and background) to their respective OpenCL buffers.
        This is called by _init_cl_buffers and should not normally be directly called unless a manual
        change of these arrays has been done.

        The mask and background are fftshift-ed before the transfer. The mask is converted to an int8 array
        with 1 where the original value is non-zero (masked pixel) and 0 elsewhere, so that non-zero values
        which would become 0 in a plain int8 cast (e.g. 0.5 or 256) remain masked.

        Returns:
            Nothing
        """
        # to_device() is synchronous by default; the `async` keyword must not be used (reserved word in python>=3.7)
        self._cl_obj = cl.array.to_device(self.cl_queue, self.obj.astype(np.complex64))
        self._cl_probe = cl.array.to_device(self.cl_queue, self.probe.astype(np.complex64))
        if self.mask is not None:
            mask = (np.asarray(self.mask) != 0).astype(np.int8)
            self._cl_mask = cl.array.to_device(self.cl_queue, fftshift(mask))
        if self.background is not None:
            self._cl_background = cl.array.to_device(self.cl_queue, fftshift(self.background.astype(np.float32)))

    def arrays_from_cl(self):
        """
        Get updated data (obj, probe, background) from their respective OpenCL buffers.
        This is called automatically at the end of main algorithm functions, and should not be directly called.

        The background is stored fftshift-ed in OpenCL space (arrays_to_cl applies fftshift before the transfer),
        so the inverse shift is applied here. self.background then has the same layout as the array originally
        supplied, and a later arrays_to_cl() call uploads it consistently.

        Returns:
            Nothing. self.obj, self.probe, self.background are updated from OpenCL space
        """
        self.obj = self._cl_obj.get()
        self.probe = self._cl_probe.get()
        if self.background is not None:
            # Undo the fftshift applied on upload
            self.background = np.fft.ifftshift(self._cl_background.get())

    def work(self):
        """
        Execute the job assigned to this worker by calling the function stored in self.job_f.

        Returns:
            Nothing
        """
        job = self.job_f
        job()

    def _calc_forward(self, ev=None, cl_psi=None):
        """
        Forward model for the current stack of frames (self._v): object * probe multiplication (including
        the quadratic phase factor when self.lambdaz is not None), followed by a forward FFT.

        Args:
            ev: list of pyopencl events to wait for before the multiplication kernel starts (or None)
            cl_psi: the opencl array in which Psi will be stored. If None, self._cl_psi is used.

        Returns:
            the pyopencl events of the enqueued forward FFT, which may need to complete before proceeding.
            Psi (in Fourier space) is then stored in cl_psi.
        """
        psi = self._cl_psi if cl_psi is None else cl_psi
        queue = self.cl_queue
        global_size, local_size = (self._ny * self._nx,), (self._nz,)
        if self.lambdaz is not None:
            pixel_size = np.float32(self.pixel_size_object)
            kernel_ev = self._cl_ptycho_proj_object_probe_mult_quad_phase(queue, global_size, local_size,
                                                                          self._cl_obj.data, self._cl_probe.data, psi.data,
                                                                          self._v[1].data, pixel_size, pixel_size,
                                                                          np.float32(pi / self.lambdaz), self._v[3],
                                                                          wait_for=ev)
        else:
            kernel_ev = self._cl_ptycho_proj_object_probe_mult(queue, global_size, local_size,
                                                               self._cl_obj.data, self._cl_probe.data, psi.data,
                                                               self._v[1].data, self._v[3], wait_for=ev)
        # The multiplication must be finished before the FFT plan is pointed at psi and enqueued
        kernel_ev.wait()
        self._gpyfft_plan.data = psi
        # Note: the FFT output is not rescaled here
        return self._gpyfft_plan.enqueue(forward=True)

    def _update_obj_probe(self, ev=None, firstpass=np.int8(0), cl_psi=None):
        """
        Update object and probe from a back-FT series of projections.
        
        Args:
            ev: the opencl events that should be waited for
            firstpass: is passed to the kernels. If 1, the filled in values are added to the existing ones - otherwise, they are set to the new ones.
            cl_psi: the psi which should be used to update the object and/or probe. If None, self._cl_psi will be used.
        
        Returns:
            Nothing. self._cl_obj_new_n and self._cl_probe_new and the corresponding norms are updated.
        """
        if cl_psi is None:
            cl_psi = self._cl_psi
        if self.update_object:
            if self.lambdaz is None:
                ev = [self._cl_ptycho_proj_update_obj(self.cl_queue, (self._ny * self._nx,), (self._nz,), self._cl_obj_newN.data,
                                                    self._cl_probe.data, cl_psi.data, self._cl_obj_normN.data,
                                                    self._v[1].data, self._v[3], np.int8(firstpass), wait_for=ev)]
            else:
                ev = [self._cl_ptycho_proj_update_obj_quad_phase(self.cl_queue, (self._ny * self._nx,), (self._nz,), self._cl_obj_newN.data,
                                                               self._cl_probe.data, cl_psi.data, self._cl_obj_normN.data,
                                                               self._v[1].data, np.float32(self.pixel_size_object),
                                                               np.float32(self.pixel_size_object),
                                                               np.float32(-pi / self.lambdaz), self._v[3], np.int8(firstpass), wait_for=ev)]
        if self.update_probe:
            if self.lambdaz is None:
                ev = [self._cl_ptycho_proj_update_probe(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                                      self._cl_obj.data, self._cl_probe_new.data, cl_psi.data, self._cl_probe_norm.data,
                                                      self._v[1].data, self._v[3], np.int8(firstpass), wait_for=ev)]
            else:
                ev = [self._cl_ptycho_proj_update_probe_quad_phase(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                                                 self._cl_obj.data, self._cl_probe_new.data, cl_psi.data,
                                                                 self._cl_probe_norm.data,
                                                                 self._v[1].data, np.float32(self.pixel_size_object),
                                                                 np.float32(self.pixel_size_object),
                                                                 np.float32(-pi / self.lambdaz), self._v[3], np.int8(firstpass), wait_for=ev)]
        for e in ev: e.wait()

    def _update_obj(self, ev=None, firstpass=np.int8(0), cl_psi=None):
        """
        Update object from a back-FT series of projections.
        
        Args:
            ev: the opencl events that should be waited for
            firstpass: is passed to the kernels. If 0, the filled in values are added to the existing ones - otherwise, they are set to the new ones.
            cl_psi: the psi which should be used to update the object. If None, self._cl_psi will be used.
            
        Returns:
            Nothing. self._cl_obj_new_n and the corresponding norm are updated.
        """
        if cl_psi is None:
            cl_psi = self._cl_psi
        if self.lambdaz is None:
            self._cl_ptycho_proj_update_obj(self.cl_queue, (self._ny * self._nx,), (self._nz,), self._cl_obj_newN.data,
                                            self._cl_probe.data, cl_psi.data, self._cl_obj_normN.data,
                                            self._v[1].data, self._v[3], np.int8(firstpass), wait_for=ev).wait()
        else:
            self._cl_ptycho_proj_update_obj_quad_phase(self.cl_queue, (self._ny * self._nx,), (self._nz,), self._cl_obj_newN.data,
                                                       self._cl_probe.data, cl_psi.data, self._cl_obj_normN.data,
                                                       self._v[1].data, np.float32(self.pixel_size_object),
                                                       np.float32(self.pixel_size_object),
                                                       np.float32(-pi / self.lambdaz), self._v[3], np.int8(firstpass), wait_for=ev).wait()

    def _update_probe(self, ev=None, firstpass=np.int8(0), cl_psi=None):
        """
        Update probe from a back-FT series of projections.
        
        Args:
            ev: the opencl events that should be waited for
            firstpass: is passed to the kernels. If 0, the filled in values are added to the existing ones - otherwise, they are set to the new ones.
            cl_psi: the psi which should be used to update the probe. If None, self._cl_psi will be used.
        
        Returns:
            Nothing. self._cl_probe_new and the corresponding norms are updated.
        """
        if cl_psi is None:
            cl_psi = self._cl_psi
        if self.lambdaz is None:
            self._cl_ptycho_proj_update_probe(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                              self._cl_obj.data, self._cl_probe_new.data, cl_psi.data, self._cl_probe_norm.data,
                                              self._v[1].data, self._v[3], np.int8(firstpass), wait_for=ev).wait()
        else:
            self._cl_ptycho_proj_update_probe_quad_phase(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                                         self._cl_obj.data, self._cl_probe_new.data, cl_psi.data,
                                                         self._cl_probe_norm.data,
                                                         self._v[1].data, np.float32(self.pixel_size_object),
                                                         np.float32(self.pixel_size_object),
                                                         np.float32(-pi / self.lambdaz), self._v[3], np.int8(firstpass), wait_for=ev).wait()

    def _loop_fourier_projection(self, ev=None, ignore_mask=False):
        """
        Apply the observed amplitude for the current looped stack of frames.
        
        Args:
            ev: opencl events to wait for
            ignore_mask: if True, no mask will be used even is self.mask is not None (this is useful during the first cycle, so that masked
                         pixels end up with zero intensity instead of being free to take unreasonably large values, due to improper scaling)

        Returns:
            Nothing. Self._cl_psi is updated
        """
        # TODO: change kernel size to nx*ny so that background is only transferred once per stack of NZ frames
        if self.mask is not None and ignore_mask is False:
            if self.background is not None:
                self._cl_ptycho_proj_apply_amplitude_mask_background(self.cl_queue, (self._nx * self._ny * self._nz,),
                                                                     (self._cl_workgroup_size,), self._v[0].data, self._cl_psi.data,
                                                                     self._cl_mask.data, self._cl_background.data, wait_for=ev)
            else:
                self._cl_ptycho_proj_apply_amplitude_mask(self.cl_queue, (self._nx * self._ny * self._nz,), (self._cl_workgroup_size,),
                                                          self._v[0].data, self._cl_psi.data, self._cl_mask.data, wait_for=ev)
        else:
            if self.background is not None:
                self._cl_ptycho_proj_apply_amplitude_background(self.cl_queue, (self._nx * self._ny * self._nz,), (self._cl_workgroup_size,),
                                                                self._v[0].data, self._cl_psi.data, self._cl_background.data, wait_for=ev)
            else:
                self._cl_ptycho_proj_apply_amplitude(self.cl_queue, (self._nx * self._ny * self._nz,), (self._cl_workgroup_size,),
                                                     self._v[0].data, self._cl_psi.data, wait_for=ev).wait()

    def _loop_llk(self, ev=None):
        """
        Compute the log-likelihood from the computed scattering from the current looped stack of frames.
        Args:
            ev: opencl events to wait for

        Returns:
            Nothing. The computed llk is added to self.llk
        """
        npsi = self._v[3]
        if self.mask is not None:
            if self.background is not None:
                self.llk += self._cl_llk_poisson_mask_background_red(self._v[0][:npsi], self._cl_psi, self._cl_mask,
                                                                     self._cl_background, wait_for=ev).get()
            else:
                self.llk += self._cl_llk_poisson_mask_red(self._v[0][:npsi], self._cl_psi, self._cl_mask, wait_for=ev).get()
        else:
            if self.background is not None:
                self.llk += self._cl_llk_poisson_background_red(self._v[0][:npsi], self._cl_psi, self._cl_background, wait_for=ev).get()
            else:
                self.llk += self._cl_llk_poisson_red(self._v[0][:npsi], self._cl_psi, wait_for=ev).get()

    def projection_cycle(self):
        """
        Ptychographic projection cycle, where:
            1. calculated frames are computed in Fourier space as the FT of the multiplication of object and probe at
               different positions

               * optionally the LLK is calculated (if self.calc_llk)
               * the Fourier constraint (amplitude) is applied

            2. updated object and probe (depending on self.update_object and self.update_probe flags) are calculated
               from the Fourier-constrained complex frames.

        The background accumulation arrays are also updated in the same loop if self.update_background is True.

        If the calculation is done by multiple workers, the object and probe are only partially computed (the final
        normalization step is missing) and transferred back to host memory.

        Returns:
            nothing. Updated arrays for probe and/or object.
        """
        t0 = timeit.default_timer()
        self.llk = 0
        ev = []
        if not self.work_alone and not self.is_first_worker:
            # Object, probe, background have been updated by parent. This copy is not needed for the first worker
            # where the merge was done. All copy events are kept so that the first kernel waits for them.
            if self.update_object:
                ev.append(cl.enqueue_copy(self.cl_queue, src=self.obj, dest=self._cl_obj.data))
            if self.update_probe:
                ev.append(cl.enqueue_copy(self.cl_queue, src=self.probe, dest=self._cl_probe.data))
            if self.update_background:
                ev.append(cl.enqueue_copy(self.cl_queue, src=self.background, dest=self._cl_background.data))

        firstpass = np.int8(1)  # Flag to zero updated object/probe arrays on first pass in self._update_obj_probe()
        for self._v in self._cl_obs_v:
            ev = self._calc_forward(ev)

            if self.calc_llk:
                # The LLK reduction result is read back on the host, so no event is left pending afterwards.
                self._loop_llk(ev)
                ev = []

            if self.update_background:
                self._cl_ptycho_background_loop(self.cl_queue, (self._ny * self._nx,), (self._nz,), self._cl_psi.data,
                                                self._v[0].data, self._cl_background.data, self._cl_background_d.data,
                                                self._cl_background_d2.data, self._cl_background_z2.data,
                                                self._cl_background_dz2.data, self._v[3], firstpass)

            if self.update_positions:
                # Copy original psi for image cross-correlation, only the first mode even if there are several
                cl.enqueue_copy(self.cl_queue, src=self._cl_psi[:self._nz].data, dest=self._cl_psi_registration.data)

            self._loop_fourier_projection(ev)

            if self.update_positions:
                # Calculate the cross-correlation
                self._cl_cross_power_spectrum(self._cl_psi_registration, self._cl_psi)
                if False:  # DEBUG: print the shift of the cross-correlation maximum for each frame
                    for e in self._gpyfft_plan_1mode.enqueue(forward=False):
                        e.wait()
                    for iz in range(self._nz):
                        vargmax = self._cl_arg_max(self._cl_psi_registration).get()
                        vmax = vargmax['x']
                        imax = int(vargmax['y'])
                        dx = imax % self._nx
                        dy = int(imax // self._nx)
                        dx -= (dx > (self._nx // 2)) * self._nx
                        dy -= (dy > (self._ny // 2)) * self._ny
                        print("iz=%d, dx=%3d, dy=%3d, vmax=%e" % (iz, dx, dy, vmax))

            # Back-propagate the Fourier-constrained frames (no additional rescaling is applied)
            ev = self._gpyfft_plan.enqueue(forward=False)

            self._update_obj_probe(ev, firstpass)
            firstpass = np.int8(0)

        if self.update_background:
            if self.work_alone:
                self._cl_ptycho_background_update(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                                  self._cl_background.data, self._cl_background_d.data,
                                                  self._cl_background_d2.data, self._cl_background_z2.data,
                                                  self._cl_background_dz2.data, np.int32(len(self.iobs)))
            else:
                # Copy back arrays so that they can be merged once all workers have finished
                cl.enqueue_copy(self.cl_queue, src=self._cl_background_d.data, dest=self.background_d).wait()
                cl.enqueue_copy(self.cl_queue, src=self._cl_background_d2.data, dest=self.background_d2).wait()
                cl.enqueue_copy(self.cl_queue, src=self._cl_background_dz2.data, dest=self.background_dz2).wait()
                cl.enqueue_copy(self.cl_queue, src=self._cl_background_z2.data, dest=self.background_z2).wait()

        if self.update_object:
            self._cl_ptycho_proj_sum_n_zero(self.cl_queue, (self._nyo * self._nxo,), (self._nz,),
                                            self._cl_obj_newN.data, self._cl_obj_normN.data,
                                            self._cl_obj_new.data, self._cl_obj_norm.data, wait_for=ev).wait()
            if self.work_alone:
                # Get max of obj_norm for regularization.
                reg = np.float32(float(cl.array.max(self._cl_obj_norm).get()) * self.reg_obj_probe)
                # Normalize object. The event is kept so that it is waited for below, even if the probe is
                # also updated.
                ev = [self._cl_ptycho_proj_obj_norm(self.cl_queue, (self._nxo * self._nyo,), (self._cl_workgroup_size,),
                                                    self._cl_obj_new.data, self._cl_obj_norm.data, self._cl_obj.data,
                                                    reg)]
            else:
                # We only have part of the object yet
                cl.enqueue_copy(self.cl_queue, src=self._cl_obj_norm.data, dest=self.obj_new_norm).wait()
                cl.enqueue_copy(self.cl_queue, src=self._cl_obj_new.data, dest=self.obj_new).wait()

        if self.update_probe:
            if self.work_alone:
                # Get max of probe_norm for regularization.
                reg = np.float32(float(cl.array.max(self._cl_probe_norm).get()) * self.reg_obj_probe)
                # Normalize probe, filter if necessary
                if self._probe_mask is not None:
                    probe_ev = self._cl_ptycho_proj_probe_norm_mask(self.cl_queue, (self._nx * self._ny,),
                                                                    (self._cl_workgroup_size,),
                                                                    self._cl_probe_new.data, self._cl_probe_norm.data,
                                                                    self._cl_probe.data, reg,
                                                                    self._cl_probe_mask.data)
                else:
                    probe_ev = self._cl_ptycho_proj_probe_norm(self.cl_queue, (self._nx * self._ny,),
                                                               (self._cl_workgroup_size,),
                                                               self._cl_probe_new.data, self._cl_probe_norm.data,
                                                               self._cl_probe.data, reg)
                # Keep any pending object normalization event alongside the probe one
                ev = list(ev) + [probe_ev]
            else:
                # We only have part of the probe yet
                cl.enqueue_copy(self.cl_queue, src=self._cl_probe_norm.data, dest=self.probe_new_norm).wait()
                cl.enqueue_copy(self.cl_queue, src=self._cl_probe_new.data, dest=self.probe_new).wait()
        for e in ev:
            e.wait()
        if self.verbose:
            print("%s: finished work, dt=%6.4fs" % (self.name, timeit.default_timer() - t0))

    def projection_merge_update(self):
        """
        Merged final update when using multiple worker threads. Get the full object and probe from parent, then perform
        normalization and extract objects to be used in other threads. Also update background if necessary.

        Device -> host transfers use ``cl.enqueue_copy(...).wait()``. The previous ``Array.get(..., async=False)``
        is a SyntaxError on Python >= 3.7, where ``async`` is a reserved keyword.

        Returns:
            Nothing. self.obj, self.probe and self.background are updated, as their opencl counterparts.
        """
        if self.update_object:
            # Object has been updated by parent
            cl.enqueue_copy(self.cl_queue, src=self.obj_new, dest=self._cl_obj_new.data).wait()
            cl.enqueue_copy(self.cl_queue, src=self.obj_new_norm, dest=self._cl_obj_norm.data).wait()
            reg = np.float32(float(cl.array.max(self._cl_obj_norm).get()) * self.reg_obj_probe)
            self._cl_ptycho_proj_obj_norm(self.cl_queue, (self._nxo * self._nyo,), (self._cl_workgroup_size,),
                                          self._cl_obj_new.data, self._cl_obj_norm.data, self._cl_obj.data, reg).wait()
            cl.enqueue_copy(self.cl_queue, src=self._cl_obj.data, dest=self.obj).wait()

        if self.update_probe:
            # Probe has been updated by parent. It is normalized in place on the device below (used both as the
            # un-normalized input and the output of the normalization kernel).
            cl.enqueue_copy(self.cl_queue, src=self.probe, dest=self._cl_probe.data).wait()
            cl.enqueue_copy(self.cl_queue, src=self.probe_new_norm, dest=self._cl_probe_norm.data).wait()
            reg = np.float32(float(cl.array.max(self._cl_probe_norm).get()) * self.reg_obj_probe)
            if self._probe_mask is not None:
                # Apply the probe mask, consistently with the single-worker path of projection_cycle()
                self._cl_ptycho_proj_probe_norm_mask(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                     self._cl_probe.data, self._cl_probe_norm.data, self._cl_probe.data,
                                                     reg, self._cl_probe_mask.data).wait()
            else:
                self._cl_ptycho_proj_probe_norm(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                self._cl_probe.data, self._cl_probe_norm.data, self._cl_probe.data,
                                                reg).wait()
            cl.enqueue_copy(self.cl_queue, src=self._cl_probe.data, dest=self.probe).wait()

        if self.update_background:
            cl.enqueue_copy(self.cl_queue, src=self.background_d, dest=self._cl_background_d.data).wait()
            cl.enqueue_copy(self.cl_queue, src=self.background_d2, dest=self._cl_background_d2.data).wait()
            cl.enqueue_copy(self.cl_queue, src=self.background_dz2, dest=self._cl_background_dz2.data).wait()
            cl.enqueue_copy(self.cl_queue, src=self.background_z2, dest=self._cl_background_z2.data).wait()
            # The frame count is passed as int32, as in the single-worker path of projection_cycle()
            self._cl_ptycho_background_update(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                              self._cl_background.data, self._cl_background_d.data,
                                              self._cl_background_d2.data, self._cl_background_z2.data,
                                              self._cl_background_dz2.data, np.int32(self.nframes_total))
            cl.enqueue_copy(self.cl_queue, src=self._cl_background.data, dest=self.background).wait()

    def _loop_ml_poisson_psi_corr(self, ev=None, firstpass=np.int8(0)):
        if self.update_background:
            # Calculate LLK background gradient at the same time
            if self.mask is not None:
                self._cl_ml_poisson_psi_corr_background_gradient_mask(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                                      self._cl_psi.data, self._v[0].data, self._cl_background.data,
                                                                      self._cl_background_grad.data, self._cl_mask.data, self._v[3], firstpass,
                                                                      wait_for=ev).wait()
            else:
                self._cl_ml_poisson_psi_corr_background_gradient(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                                 self._cl_psi.data, self._v[0].data, self._cl_background.data,
                                                                 self._cl_background_grad.data, self._v[3], firstpass, wait_for=ev).wait()
        else:
            if self.mask is not None:
                if self.background is not None:
                    self._cl_ml_poisson_psi_corr_mask_background(self.cl_queue, (self._nx * self._ny * self._nz,), (self._cl_workgroup_size,),
                                                                 self._cl_psi.data, self._v[0].data, self._cl_mask.data, self._cl_background.data,
                                                                 wait_for=ev).wait()
                else:
                    self._cl_ml_poisson_psi_corr_mask(self.cl_queue, (self._nx * self._ny * self._nz,), (self._cl_workgroup_size,),
                                                      self._cl_psi.data, self._v[0].data, self._cl_mask.data, wait_for=ev).wait()
            else:
                if self.background is not None:
                    self._cl_ml_poisson_psi_corr_background(self.cl_queue, (self._nx * self._ny * self._nz,), (self._cl_workgroup_size,),
                                                            self._cl_psi.data, self._v[0].data, self._cl_background.data, wait_for=ev).wait()
                else:
                    self._cl_ml_poisson_psi_corr(self.cl_queue, (self._nx * self._ny * self._nz,), (self._cl_workgroup_size,),
                                                 self._cl_psi.data, self._v[0].data, wait_for=ev).wait()

    def ml_poisson_cycle_single_worker(self):
        t0 = timeit.default_timer()
        self.llk = 0
        ev = []
        # 1) Gradient
        if self.update_object:
            self._cl_obj_grad, self._cl_obj_grad_last = self._cl_obj_grad_last, self._cl_obj_grad
        if self.update_probe:
            self._cl_probe_grad, self._cl_probe_grad_last = self._cl_probe_grad_last, self._cl_probe_grad
        if self.update_background:
            self._cl_background_grad, self._cl_background_grad_last = self._cl_background_grad_last, self._cl_background_grad

        firstpass = np.int8(1) # Flag to zero probe gradient array on first pass in self._cl_ml_poisson_probe_grad()
        for self._v in self._cl_obs_v:
            ev = self._calc_forward(ev)

            if self.calc_llk:
                self._loop_llk(ev)
                ev=[]

            # Calculate Psi * (1 - Iobs/Icalc) and the gradient vs background values if update_background
            self._loop_ml_poisson_psi_corr(ev=ev, firstpass=firstpass)

            ev = self._gpyfft_plan.enqueue(forward=False)
            if False:
                ev = [self.cl_scale(self._cl_psi, np.float32(np.sqrt(self._cl_psi[0].size)), wait_for=ev)]

            if self.update_object:
                if self.lambdaz is None:
                    ev=[self._cl_ml_poisson_obj_grad(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                     self._cl_obj_gradN.data, self._cl_probe.data, self._cl_psi.data,
                                                     self._v[1].data, self._v[3], firstpass, wait_for=ev)]
                else:
                    ev=[self._cl_ml_poisson_obj_grad_quad_phase(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                                self._cl_obj_gradN.data, self._cl_probe.data, self._cl_psi.data,
                                                                self._v[1].data,
                                                                np.float32(self.pixel_size_object), np.float32(self.pixel_size_object),
                                                                np.float32(-pi / self.lambdaz), self._v[3], firstpass, wait_for=ev)]

            if self.update_probe:
                if self.lambdaz is None:
                    ev = [self._cl_ml_poisson_probe_grad(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                         self._cl_probe_grad.data, self._cl_obj.data, self._cl_psi.data,
                                                         self._v[1].data, self._v[3], firstpass, wait_for=ev)]
                else:
                    ev = [self._cl_ml_poisson_probe_grad_quad_phase(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                                    self._cl_probe_grad.data, self._cl_obj.data, self._cl_psi.data,
                                                                    self._v[1].data,
                                                                    np.float32(self.pixel_size_object), np.float32(self.pixel_size_object),
                                                                    np.float32(-pi / self.lambdaz), self._v[3], firstpass, wait_for=ev)]
            if False:
                self.psi1 = self._cl_obj_gradN.get()
                self.obj1 = self._cl_obj.get()
                self.probe1 = self._cl_probe.get()

            firstpass = np.int8(0)

        if self.update_object:
            ev=[self._cl_ml_poisson_sum_obj_grad_zero(self.cl_queue, (self._nxo * self._nyo,), (self._cl_workgroup_size,),
                                                      self._cl_obj_grad.data, self._cl_obj_gradN.data, wait_for=ev)]

        #TODO: Add a regularization factor on the background ?
        if self.reg_fac is not None:
            # Regularization contribution to the object and probe gradient
            if self.update_object:
                ev = [self._cl_ml_poisson_reg_grad(self.cl_queue, (self._cl_obj.size,), (self._cl_workgroup_size,), self._cl_obj_grad.data,
                                                   self._cl_obj.data, np.float32(self.reg_fac_obj), wait_for=ev)]
            if self.update_probe and self.reg_fac_probe != 0:
                ev = [self._cl_ml_poisson_reg_grad(self.cl_queue, (self._cl_probe.size,), (self._cl_workgroup_size,), self._cl_probe_grad.data,
                                                   self._cl_probe.data, np.float32(self.reg_fac_probe), wait_for=ev)]
            # Regularization contribution to the log-likelihood
            if self.calc_llk:
                llkregobj = self._cl_llk_reg_red(self._cl_obj, self._nxo, self._nyo, wait_for=ev).get()
                if self.reg_fac_probe != 0:
                    llkregprobe = self._cl_llk_reg_red(self._cl_probe, self._nx, self._ny, wait_for=ev).get()
                else:
                    llkregprobe = 0
                ev = []
                llkregobj *= self.reg_fac_obj
                llkregprobe *= self.reg_fac_probe
                if False:  # DEBUG
                    print("LLK = %12.2f [ptycho] %12.2f [reg-obj] %12.2f [reg-probe]" % (self.llk, llkregobj, llkregprobe))
                self.llk += llkregobj + llkregprobe

        # 2) Search direction
        if self.beta is None:
            # First cycle
            self.beta = np.float32(0)
            if self.update_object:
                ev = [cl.enqueue_copy(self.cl_queue, self._cl_obj_dir.data,   self._cl_obj_grad.data,   wait_for=ev)]
            if self.update_probe:
                ev = [cl.enqueue_copy(self.cl_queue, self._cl_probe_dir.data, self._cl_probe_grad.data, wait_for=ev)]
            if self.update_background:
                ev = [cl.enqueue_copy(self.cl_queue, self._cl_background_dir.data, self._cl_background_grad.data, wait_for=ev)]
        else:
            beta_d, beta_n = 0, 0
            # Polak-Ribière CG coefficient
            if self.update_object:
                tmp = self._cl_cg_polak_ribiere_complex_red(self._cl_obj_grad, self._cl_obj_grad_last, wait_for=ev).get()
                ev = []
                beta_n += tmp['x']
                beta_d += tmp['y']
                if False:  # DEBUG
                    g1 = self._cl_obj_grad.get()
                    g0 = self._cl_obj_grad_last.get()
                    # A,B = (g1.real*(g1.real-g0.real)+g1.imag*(g1.imag-g0.imag)).sum(), (g0.real*g0.real+g0.imag*g0.imag).sum()
                    A, B = (g1 * (g1 - g0)).sum(), (abs(g0) ** 2).sum()
                    cpubeta = A / B
                    print("Object - betaPR: (GPU)=%8.4e  , (CPU)=%8.4e%+8.4ei [%8.4e/%8.4e], dot(g0.g1)=%8e [%8e]" %
                          (tmp['x'] / tmp['y'], cpubeta.real, cpubeta.imag, A, B, (g0 * g1).sum().real, (abs(g0) ** 2).sum().real))

            if self.update_probe:
                tmp = self._cl_cg_polak_ribiere_complex_red(self._cl_probe_grad, self._cl_probe_grad_last, wait_for=ev).get()
                ev = []
                beta_n += tmp['x']
                beta_d += tmp['y']
                if False:  # DEBUG
                    g1 = self._cl_probe_grad.get()
                    g0 = self._cl_probe_grad_last.get()
                    # A,B = (g1.real*(g1.real-g0.real)+g1.imag*(g1.imag-g0.imag)).sum(), (g0.real*g0.real+g0.imag*g0.imag).sum()
                    A, B = (g1 * (g1 - g0)).sum(), (abs(g0) ** 2).sum()
                    cpubeta = A / B
                    print("Object - betaPR: (GPU)=%8.4e  , (CPU)=%8.4e%+8.4ei [%8.4e/%8.4e], dot(g0.g1)=%8e [%8e]" %
                          (tmp['x'] / tmp['y'], cpubeta.real, cpubeta.imag, A, B, (g0 * g1).sum().real, (abs(g0) ** 2).sum().real))

            if self.update_background:
                tmp = self._cl_cg_polak_ribiere_float_red(self._cl_background_grad, self._cl_background_grad_last, wait_for=ev).get()
                ev = []
                beta_n += tmp['x']
                beta_d += tmp['y']
                if False:
                    print("Beta(background)= %8g = %8g / %8g"%(tmp['x'] / tmp['y'],tmp['x'], tmp['y']))

            if np.isnan(beta_d) or np.isnan(beta_n):
                print("Beta = NaN ! :", beta_d, beta_n)
            # Reset direction if beta<0 => beta=0
            self.beta = np.float32(max(0, beta_n / max(1e-20, beta_d)))
            if self.update_object:
                ev = [self._cl_ml_poisson_cg_linear_complex(self.cl_queue, (self._cl_obj.size,),
                                                            (self._cl_workgroup_size,), self.beta,
                                                            self._cl_obj_dir.data, np.float32(-1),
                                                            self._cl_obj_grad.data, wait_for=ev)]
            if self.update_probe:
                ev = [self._cl_ml_poisson_cg_linear_complex(self.cl_queue, (self._cl_probe.size,),
                                                            (self._cl_workgroup_size,), self.beta,
                                                            self._cl_probe_dir.data, np.float32(-1),
                                                            self._cl_probe_grad.data, wait_for=ev)]
            if self.update_background:
                ev = [self._cl_ml_poisson_cg_linear_float(self.cl_queue, (self._cl_background.size,),
                                                          (self._cl_workgroup_size,), self.beta,
                                                          self._cl_background_dir.data, np.float32(-1),
                                                          self._cl_background_grad.data, wait_for=ev)]
        # 3) Line minimization
        gamma_d, gamma_n = 0, 0
        gamma_scale = 1e-6  # TODO: properly define this scale dynamically ! (used to avoid overflows)
        for self._v in self._cl_obs_v:
            for clpsi, clobj, clprobe  in zip([self._cl_PO,    self._cl_PdO,     self._cl_dPO,       self._cl_dPdO],
                                              [self._cl_obj,   self._cl_obj_dir, self._cl_obj,       self._cl_obj_dir],
                                              [self._cl_probe, self._cl_probe,   self._cl_probe_dir, self._cl_probe_dir]):
                if self.lambdaz is None:
                    ev = [self._cl_ptycho_proj_object_probe_mult(self.cl_queue, (self._ny * self._nx,), (self._cl_workgroup_size,),
                                                               clobj.data, clprobe.data, clpsi.data,
                                                               self._v[1].data, self._v[3], wait_for=ev)]
                else:
                    ev = [self._cl_ptycho_proj_object_probe_mult_quad_phase(self.cl_queue, (self._ny * self._nx,), (self._cl_workgroup_size,),
                                                                          clobj.data, clprobe.data, clpsi.data,
                                                                          self._v[1].data,
                                                                          np.float32(self.pixel_size_object),
                                                                          np.float32(self.pixel_size_object),
                                                                          np.float32(pi / self.lambdaz), self._v[3], wait_for=ev)]
                self._gpyfft_plan.data = clpsi
                cl.wait_for_events(ev)
                ev = self._gpyfft_plan.enqueue(forward=True)

            if self.update_background:
                if self.mask is not None:
                    tmp = self._cl_ml_poisson_cg_gamma_background_grad_mask_red(self._v[0], self._cl_PO, self._cl_PdO, self._cl_dPO, self._cl_dPdO,
                                                                                 self._cl_background, self._cl_background_dir, self._cl_mask,
                                                                                 gamma_scale, self._v[3], wait_for=ev).get()
                else:
                    tmp = self._cl_ml_poisson_cg_gamma_background_grad_red(self._v[0], self._cl_PO, self._cl_PdO, self._cl_dPO, self._cl_dPdO,
                                                                           self._cl_background, self._cl_background_dir, gamma_scale, self._v[3],
                                                                           wait_for=ev).get()
            else:
                if self.mask is not None:
                    if self.background is not None:
                        tmp = self._cl_ml_poisson_cg_gamma_background_mask_red(self._v[0], self._cl_PO, self._cl_PdO, self._cl_dPO, self._cl_dPdO,
                                                                               self._cl_background, self._cl_mask, gamma_scale, self._v[3], wait_for=ev).get()
                    else:
                        tmp = self._cl_ml_poisson_cg_gamma_mask_red(self._v[0], self._cl_PO, self._cl_PdO, self._cl_dPO, self._cl_dPdO, self._cl_mask,
                                                                    gamma_scale, self._v[3], wait_for=ev).get()
                else:
                    if self.background is not None:
                        tmp = self._cl_ml_poisson_cg_gamma_background_red(self._v[0], self._cl_PO, self._cl_PdO, self._cl_dPO, self._cl_dPdO,
                                                                          self._cl_background, gamma_scale, self._v[3], wait_for=ev).get()
                    else:
                        tmp = self._cl_ml_poisson_cg_gamma_red(self._v[0], self._cl_PO, self._cl_PdO, self._cl_dPO, self._cl_dPdO,
                                                               gamma_scale, self._v[3], wait_for=ev).get()

            tmp_n, tmp_d = tmp['x'], tmp['y']
            if False:  # DEBUG
                # Comparison only correct for 1 mode and multiple of 16 frames, without background - OK

                iobs = (self._v[0].get() * gamma_scale) ** 2
                PO = self._cl_PO.get() * gamma_scale
                PdO = self._cl_PdO.get() * gamma_scale
                dPO = self._cl_dPO.get() * gamma_scale
                dPdO = self._cl_dPdO.get() * gamma_scale
                R_PO_OdP_Pdo = (PO.conjugate() * (dPO + PdO)).real
                OdP_PdO_2R = abs(dPO) ** 2 + abs(PdO) ** 2 + 2 * (PO.conjugate() * dPdO + dPO * PdO.conjugate()).real
                A = ((OdP_PdO_2R) * (1 - iobs / abs(PO) ** 2)) + 2 * iobs * R_PO_OdP_Pdo ** 2 / abs(PO) ** 4
                B = -2 * R_PO_OdP_Pdo * (1 - iobs / abs(PO) ** 2)
                print("Gamma: %8e/%8e [%8e/%8e]" % (tmp_n, tmp_d, B.sum() / 2, A.sum()),
                      "A1=%e, A2=%e" % (((OdP_PdO_2R) * (1 - iobs / abs(PO) ** 2)).sum(),
                                        (4 * iobs * R_PO_OdP_Pdo ** 2 / abs(PO) ** 4).sum()),
                      "abs(obj grad).sum()=%8e, abs(probe grad).sum()=%8e" % (abs(self._cl_obj_grad.get()).sum(), abs(self._cl_probe_grad.get()).sum()),
                      "PO: %8e, dPO: %8e, PdO:%8e, dPdO:%8e" % (abs(self._cl_PO.get()).sum(), abs(self._cl_dPO.get()).sum(),
                                                                abs(self._cl_PdO.get()).sum(), abs(self._cl_dPdO.get()).sum(),))
                self.R_PO_OdP_Pdo = R_PO_OdP_Pdo
                self.OdP_PdO_2R = OdP_PdO_2R
                self.iobs = iobs
                self.PO = PO
                self.PdO = PdO
                self.dPO = dPO
                self.dPdO = dPdO

            gamma_d += tmp_d
            gamma_n += tmp_n

        if self.reg_fac is not None:
            if self.update_object:
                tmp = self._cl_cg_gamma_reg_red(self._cl_obj, self._cl_obj_dir, self._nxo, self._nyo, wait_for=ev).get()
                ev = []
                tmp_n, tmp_d = tmp['x'] * self.reg_fac_obj * gamma_scale, tmp['y'] * self.reg_fac_obj * gamma_scale
                gamma_d += tmp_d
                gamma_n += tmp_n
            if self.update_probe and self.reg_fac_probe != 0:
                tmp = self._cl_cg_gamma_reg_red(self._cl_probe, self._cl_probe_dir, self._nx, self._ny, wait_for=ev).get()
                ev = []
                tmp_n, tmp_d = tmp['x'] * self.reg_fac_probe * gamma_scale, tmp['y'] * self.reg_fac_probe * gamma_scale
                gamma_d += tmp_d
                gamma_n += tmp_n


        if np.isnan(gamma_d) or np.isnan(gamma_n):
            print("Gamma = NaN ! :", gamma_d, gamma_n)
        if abs(gamma_d) < 1e-10: gamma_d = 1e-10
        gamma = np.float32(gamma_n / gamma_d)
        if False:  # DEBUG
            print("ML, beta=%12.10f, gamma=%12.10f" % (self.beta, gamma))

        # 4) Object and/or probe and/or background update
        if self.update_object:
            ev = [self._cl_ml_poisson_cg_linear_complex(self.cl_queue, (self._cl_obj.size,), (self._cl_workgroup_size,),
                                                        np.float32(1), self._cl_obj.data, gamma, self._cl_obj_dir.data,
                                                        wait_for=ev)]

        if self.update_probe:
            ev = [
                self._cl_ml_poisson_cg_linear_complex(self.cl_queue, (self._cl_probe.size,), (self._cl_workgroup_size,),
                                                      np.float32(1), self._cl_probe.data, gamma,
                                                      self._cl_probe_dir.data, wait_for=ev)]

        if self.update_background:
            ev = [self._cl_ml_poisson_cg_linear_float(self.cl_queue, (self._cl_background.size,),
                                                      (self._cl_workgroup_size,),
                                                      np.float32(1), self._cl_background.data, gamma,
                                                      self._cl_background_dir.data, wait_for=ev)]

        for e in ev: e.wait()
        if self.verbose:
            print("%s: finished work, dt=%6.4fs" % (self.name, timeit.default_timer() - t0))

    def ml_poisson_gradient(self):
        """
        First step of the multi-worker maximum-likelihood (Poisson noise) algorithm - work in progress.

        If this worker does not work alone, the object and/or probe (updated by the parent) are first copied
        from the host arrays self.obj / self.probe to the OpenCL buffers. Then, for each stack of observed
        frames, the forward model is computed, the log-likelihood is evaluated if self.calc_llk is set,
        Psi is corrected (self._loop_ml_poisson_psi_corr) and an inverse FFT is enqueued.

        Returns:
            Nothing
        """
        self.llk = 0
        ev = []
        if self.update_object and not self.work_alone:
            # Object has been updated by parent TODO: this copy is not needed for the first worker (where the object update was done)
            # Keep the copy event (as for the probe below) so that the forward calculation explicitly depends on it.
            ev.append(cl.enqueue_copy(self.cl_queue, src=self.obj, dest=self._cl_obj.data))

        if self.update_probe and not self.work_alone:
            # Probe has been updated by parent TODO: this copy is not needed for the first worker (where the probe update was done)
            ev.append(cl.enqueue_copy(self.cl_queue, src=self.probe, dest=self._cl_probe.data))

        for self._v in self._cl_obs_v:
            ev = self._calc_forward(ev)

            if self.calc_llk:
                self._loop_llk(ev)

            self._loop_ml_poisson_psi_corr()

            ev = self._gpyfft_plan.enqueue(forward=False)

        # TODO: unfinished algorithm. NOTE(review): this extra forward calculation presumably uses the
        # self._v left over from the loop above (i.e. the last frame stack only) - confirm the intent.
        ev = self._calc_forward(ev)

    def dm_cycle_single_worker(self):
        """
        Difference map algorithm for a single worker.

        One call performs one DM cycle over all the frame stacks (self._cl_obs_v) of this worker:
          - on the first cycle (self.cycle == 1), a persistent Psi buffer is appended once to each
            stack entry (as self._v[4]) and initialised from the current object and probe;
          - 1) for each stack, 2*P*O-Psi is computed into self._cl_psi, Fourier-transformed, projected on
            the observed amplitudes, back-transformed, and the stored Psi (self._v[4]) is updated from it;
          - 2) the object, then the probe, are updated from the stored Psi of all stacks and normalised,
            with a regularisation term equal to self.reg_obj_probe times the maximum of the normalisation array;
          - 3) if self.calc_llk is set, the log-likelihood is computed again from the new object and probe.

        Returns:
            Nothing (self.llk is reset to 0 at the start of the cycle)
        """
        t0 = timeit.default_timer()
        self.llk = 0
        ev = []
        if self.cycle == 1:
            # Compute starting view of Psi from the current object and probe. And keep Psi in memory.
            for self._v in self._cl_obs_v:
                # Stack entries hold 4 items until the persistent Psi buffer (index 4) is appended - done only once
                if len(self._v) == 4:
                    self._v.append(cl.array.empty_like(self._cl_psi))
                ev = self._calc_forward(ev, cl_psi=self._v[4])

                # TODO: this line does nothing, as it applies to self._cl_psi and not cl_psi=self._v[4] !
                # But this code is obsolete anyway... Here we only need to init Psi to PO
                # self._loop_fourier_projection(ev, ignore_mask=True)

                # NOTE(review): inverse FFT of the array currently attached to the FFT plan - presumably
                # self._v[4] as set by _calc_forward(cl_psi=...); confirm.
                ev = self._gpyfft_plan.enqueue(forward=False)
                if False:
                    ev = [self.cl_scale(self._cl_psi, np.float32(np.sqrt(self._cl_psi[0].size)), wait_for=ev)]

        # Make sure the Psi initialisation has completed before starting the DM cycle itself
        for e in ev: e.wait()
        ev=[]
        # Regular cycle ##########################################################
        # 1) update psi using DM algorithm #######################################
        for self._v in self._cl_obs_v:
            # Calculate 2PO-Psi, optionally multiplied by the quadratic phase (current Psi in _v[4] already is)
            # Each kernel launch below is waited for immediately (.wait()), so steps are strictly sequential.
            if self.lambdaz is None:
                self._cl_dm_2po_psi(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                    self._cl_obj.data, self._cl_probe.data, self._cl_psi.data, self._v[4].data,
                                    self._v[1].data, self._v[3], wait_for=ev).wait()
            else:
                self._cl_dm_2po_psi_quad_phase(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                               self._cl_obj.data, self._cl_probe.data, self._cl_psi.data, self._v[4].data,
                                               self._v[1].data, np.float32(self.pixel_size_object),
                                               np.float32(self.pixel_size_object),
                                               np.float32(pi / self.lambdaz), self._v[3], wait_for=ev).wait()
            # FT
            self._gpyfft_plan.data = self._cl_psi
            ev = self._gpyfft_plan.enqueue(forward=True)
            if False:
                ev = [self.cl_scale(self._cl_psi, np.float32(1 / np.sqrt(self._cl_psi[0].size)), wait_for=ev)]
            # Apply amplitude in Fourier space
            self._loop_fourier_projection(ev)
            # FT-1
            ev = self._gpyfft_plan.enqueue(forward=False)
            if False:
                ev = [self.cl_scale(self._cl_psi, np.float32(np.sqrt(self._cl_psi[0].size)), wait_for=ev)]
            # Update Psi
            # (presumably the stored Psi in self._v[4] is updated from self._cl_psi - judging from the kernel name)
            if self.lambdaz is None:
                self._cl_dm_update_psi(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                       self._cl_obj.data, self._cl_probe.data, self._cl_psi.data, self._v[4].data,
                                       self._v[1].data, self._v[3], wait_for=ev).wait()
            else:
                self._cl_dm_update_psi_quad_phase(self.cl_queue, (self._ny * self._nx,), (self._nz,),
                                                  self._cl_obj.data, self._cl_probe.data, self._cl_psi.data, self._v[4].data,
                                                  self._v[1].data, np.float32(self.pixel_size_object),
                                                  np.float32(self.pixel_size_object),
                                                  np.float32(pi / self.lambdaz), self._v[3], wait_for=ev).wait()
            ev=[]

        # 2) update object and probe from new Psi ############################
        # WHY is it necessary to first completely update the object, and then the probe, instead of updating them
        # simultaneously like in the AP algorithm ???
        # Single pass - presumably kept as a loop so that the object/probe update could be iterated.
        for i in range(1):
            if self.update_object:
                # firstpass=1 for the first stack only - presumably so that the accumulators are (re)initialised
                # rather than added to.
                firstpass = np.int8(1)
                for self._v in self._cl_obs_v:
                    self._update_obj(ev=None, firstpass=firstpass, cl_psi=self._v[4])
                    firstpass = np.int8(0)

                # NOTE(review): combines the accumulated self._cl_obj_newN / self._cl_obj_normN into
                # self._cl_obj_new / self._cl_obj_norm (judging from the kernel name and arguments) - confirm.
                self._cl_ptycho_proj_sum_n_zero(self.cl_queue, (self._nyo * self._nxo,), (self._nz,), self._cl_obj_newN.data,
                                                self._cl_obj_normN.data,
                                                self._cl_obj_new.data, self._cl_obj_norm.data, wait_for=ev).wait()
                # Get max of obj_norm for regularization.
                reg = np.float32(float(cl.array.max(self._cl_obj_norm).get()) * self.reg_obj_probe)
                # Normalize object
                self._cl_ptycho_proj_obj_norm(self.cl_queue, (self._nxo * self._nyo,), (self._cl_workgroup_size,),
                                              self._cl_obj_new.data, self._cl_obj_norm.data, self._cl_obj.data, reg).wait()

            if self.update_probe:
                firstpass = np.int8(1)
                for self._v in self._cl_obs_v:
                    self._update_probe(ev=None, firstpass=firstpass, cl_psi=self._v[4])
                    firstpass = np.int8(0)
                # Get max of probe_norm for regularization.
                reg = np.float32(float(cl.array.max(self._cl_probe_norm).get()) * self.reg_obj_probe)
                ev = []
                # Normalize probe, filter if necessary
                if self._probe_mask is not None:
                    ev = [self._cl_ptycho_proj_probe_norm_mask(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                               self._cl_probe_new.data, self._cl_probe_norm.data, self._cl_probe.data,
                                                               reg, self._cl_probe_mask.data, wait_for=ev)]
                else:
                    ev = [self._cl_ptycho_proj_probe_norm(self.cl_queue, (self._nx * self._ny,), (self._cl_workgroup_size,),
                                                          self._cl_probe_new.data, self._cl_probe_norm.data, self._cl_probe.data,
                                                          reg, wait_for=ev)]

        # Wait for the probe normalisation (ev is empty if the probe was not updated)
        for e in ev: e.wait()

        if self.calc_llk:
            # We need to calculate the LLK from the current object and probe
            for self._v in self._cl_obs_v:
                self._loop_llk(self._calc_forward())

        if self.verbose:
            print("%s: finished work, dt=%6.4fs" % (self.name, timeit.default_timer() - t0))


class Ptycho2D(object):
    """
    2D Ptychographic class, working exclusively in OpenCL, with tasks more clearly separated in small(er) functions.

    Algorithms implemented include:
      - Alternating Projections
      - Maximum-Likelihood (Poisson noise) conjugate gradient
      - Difference Map
      - Positions optimization using registration [TODO]

    Actual computing is delegated to opencl workers in multiple threads.

    """
    def __init__(self, iobs, positions, probe, obj, mask=None, background=None, pixel_size_object=None, lambdaz=None,
                 opencl_device=None, opencl_platform=None, history=None):
        """
        Initialize Ptycho2D object. Amplitudes, initial probe0 and obj0 are assumed
        to be centered on the array. Amplitudes will internally be shifted to have the origin at (0,0).
        All calculations are made in single precision (float32, complex64).
        The OpenCL worker threads are created and started immediately (see init_cl_workers).

        Args:
            iobs: 3D array of shape [nbobs, ny, nx], assumed to follow Poisson statistics, and centered.
            positions: positions in pixels, with positions[1]= 1D vector of x positions, positions[0]= 1D vector of y positions
            probe: initial probe (either 2D or a 3D stack of 2D probe modes with shape [nbprobe, ny, nx])
            obj: initial object (either 2D or a 3D stack of 2D object modes with shape [nbobj, nyo, nxo])
            mask: optional detector mask [optional]. Pixels with a nonzero (True) value are masked, i.e. only the
                  pixels where mask == 0 are used (e.g. for the log-likelihood normalisation below). A mask without
                  any nonzero value is discarded (same as mask=None). It is stored as int8.
            background: 2D array of background intensity (default= None) [optional], stored as float32.
            pixel_size_object: size of pixels in object space, in meters
            lambdaz: wavelength*distance to detector, in m^2.
               Both lambdaz and pixel_size_object are necessary to take into account quadratic phase factors for far field propagation.
               If either is missing, quadratic phase correction is not applied (can induce some stitching errors)
            opencl_device: device to use for OpenCL computing. This can be supplied either as a string or a pyopencl.Device object.
                           If multiple devices with different names should be used (to combine GPUs), a list can be supplied.
            opencl_platform: string for the OpenCL platform to use (optional)
            history: a History object with the optimization history. It is used to store the parameters like the algorithm,
                     negative log-likelihood (llk), chi^2, resolution, cputime, walltime as a function of the cycle number. If not supplied,
                     a new history will be started. Not all values need be stored for all cycles.
        """
        self.iobs = iobs.astype(np.float32)
        # astype() already returns a new array, no need for an extra copy()
        self.probe = probe.astype(np.complex64)
        self.obj = obj.astype(np.complex64)
        self.dy = positions[0]
        self.dx = positions[1]
        self.mask = mask
        if self.mask is not None:
            if self.mask.sum():
                self.mask = self.mask.astype(np.int8)
            else:
                # No masked pixel at all: masking can be skipped entirely
                self.mask = None

        # Norm so that (llk -llk_offset) / self._llk_norm = 1 once a perfect model is reached
        self._llk_norm = self.iobs.size * 0.5
        # Correction factor for the log-likelihood sum. The asymptotic value is
        #            0.5 * nb_obs+ (gammaln(iobs+1) - (iobs * np.log(iobs) - iobs))[iobs>0].sum()
        if self.mask is not None:
            # Only non-masked pixels (mask == 0) take part in the log-likelihood
            idx = ((self.mask == 0) * iobs) > 0
            self._llk_norm *= (self.mask == 0).sum() / self.mask.size
        else:
            idx = iobs > 0
        # The original (non-float32) iobs is used here for the offset computation
        self._llk_offset = (gammaln(iobs[idx]+1) - (iobs[idx] * np.log(iobs[idx]) - iobs[idx])).sum()
        # TODO: reduce self._llk_norm by the number of 'parameters', taking into account the real number of used points in the object...

        self.background = background
        if self.background is not None:
            self.background = self.background.astype(np.float32)
        self.pixel_size_object = None if pixel_size_object is None else np.float32(pixel_size_object)
        self.lambdaz = None if lambdaz is None else np.float32(lambdaz)

        # The quadratic phase factor needs both values: drop both if only one was given
        if (self.lambdaz is None) != (self.pixel_size_object is None):
            print("WARNING: both lambdaz and pixel_size_object must be given to take into account the quadratic phase factor.")
            self.lambdaz = None
            self.pixel_size_object = None

        # Probe mask can be used to set probe values to 0 beyond the area which is scanned
        self._probe_mask = None

        self.init_cl_workers(opencl_device, opencl_platform)
        # Continue the supplied optimization history if any (it was previously ignored)
        self.history = History() if history is None else history
        self.scan_area_probe = None
        self.scan_area_obj = None
        self.scan_area_points = None

        # log-likelihood
        self.llk = 0
        self.llk_normalized = 0

    def __del__(self):
        print("Ptycho2D: destructing worker threads")
        while len(self._cl_workers):
            w = self._cl_workers.pop()
            w.join()
            del w
            gc.collect()

    def calc_scan_area(self):
        """
        Compute the scan area for the object and probe, using scipy ConvexHull. The scan area for the object is
        augmented by twice the average distance between scan positions for a more realistic estimation.
        
        :return: Nothing. self.scan_area_probe and self.scan_area_obj are updated, as 2D arrays with the same shape as the object and probe,
                 with False outside the scan area and True inside.
        """
        # TODO: expand scan area by the average distance between neighbours
        points = np.array([(x, y) for x, y in zip(self.dx - self.dx.mean(), self.dy - self.dy.mean())])
        c = ConvexHull(points)
        vx = points[:, 0]
        vy = points[:, 1]
        # Scan center & average distance between points
        xc = vx.mean()
        yc = vy.mean()
        # Estimated average distance between points with an hexagonal model
        try:
            w = 4 / 3 / np.sqrt(3) * np.sqrt(c.volume / self.dx.size)
        except:
            # c.volume only supported in scipy >=0.17 (2016/02)
            w = 0
        # print("calc_scan_area: scan area = %8g pixels^2, center @(%6.1f, %6.1f), <d>=%6.2f)"%(c.volume, xc, yc, w))
        # Object
        ny, nx = self.obj.shape[-2:]
        xx = np.fft.fftshift(np.fft.fftfreq(nx, d=1. / nx))
        yy = np.fft.fftshift(np.fft.fftfreq(ny, d=1. / ny))[:, np.newaxis]
        self.scan_area_obj = np.ones((ny, nx), dtype=np.float32)
        vd = []
        for i in range(len(c.vertices)):
            # c.vertices are in counterclockwise order
            xx0, yy0 = vx[c.vertices[i - 1]], vy[c.vertices[i - 1]]
            xx1, yy1 = vx[c.vertices[i]], vy[c.vertices[i]]
            if w > 0:
                # Increase distance-to-center for a more realistic scan area
                d0 = np.sqrt((xx0 - xc)**2 + (yy0 - yc)**2)
                xx0 = xc + (xx0 - xc) * (d0 + 2 * w) / d0
                yy0 = yc + (yy0 - yc) * (d0 + 2 * w) / d0
                d1 = np.sqrt((xx1 - xc)**2 + (yy1 - yc)**2)
                xx1 = xc + (xx1 - xc) * (d1 + 2 * w) / d1
                yy1 = yc + (yy1 - yc) * (d1 + 2 * w) / d1

            dx, dy = xx1 - xx0, yy1 - yy0
            vd.append(np.sqrt(dx ** 2 + dy ** 2))
            self.scan_area_obj *= ((xx - xx0) * dy - (yy - yy0) * dx) <= 0
        self.scan_area_obj = self.scan_area_obj > 0.5
        # Probe
        ny, nx = self.probe.shape[-2:]
        xx = np.fft.fftshift(np.fft.fftfreq(nx, d=1. / nx))
        yy = np.fft.fftshift(np.fft.fftfreq(ny, d=1. / ny))[:, np.newaxis]
        self.scan_area_probe = np.ones((ny, nx), dtype=np.float32)
        vd = []
        for i in range(len(c.vertices)):
            # c.vertices are in counterclockwise order
            xx0, yy0 = vx[c.vertices[i - 1]], vy[c.vertices[i - 1]]
            dx, dy = vx[c.vertices[i]] - xx0, vy[c.vertices[i]] - yy0
            vd.append(np.sqrt(dx ** 2 + dy ** 2))
            self.scan_area_probe *= ((xx - xx0) * dy - (yy - yy0) * dx) <= 0
        self.scan_area_probe = self.scan_area_probe > 0.5
        x, y =[points[i, 1] for i in c.vertices], [points[i, 0] for i in c.vertices]
        x.append(points[c.vertices[0], 1])
        y.append(points[c.vertices[0], 0])
        self.scan_area_points = np.array(x), np.array(y)

    def init_cl_workers(self, opencl_device = None, opencl_platform = None):
        """
        Init the OpenCL workers.

        Matching OpenCL devices are searched on all platforms and ranked by their measured FFT speed
        (on a (16, ny, nx) stack). One CLPtycho2DWorker thread is then started per range of observed frames.
        For now only the fastest device is used.

        Args:
            opencl_device: the device to use, as a (case-insensitive) sub-string of the device name or as a
                           pyopencl.Device, or a list of those. None (the default) accepts any device.
            opencl_platform: optional (case-insensitive) sub-string of the OpenCL platform name, to restrict the search.

        Raises:
            RuntimeError: if no OpenCL device matches the request.
        """
        self._cl_workers = []
        found_dev = []
        if not isinstance(opencl_device, list):
            opencl_device = [opencl_device]
        for p in cl.get_platforms():
            # str.lower() must be called: passing the bound method to find() raised a TypeError
            if opencl_platform is not None and p.name.lower().find(opencl_platform.lower()) < 0:
                continue
            for d in p.get_devices():
                for wanted_dev in opencl_device:
                    if wanted_dev is None:
                        # No specific device requested: every device is a candidate, the fastest is used below
                        found_dev.append(d)
                    elif isinstance(wanted_dev, str):
                        if d.name.lower().find(wanted_dev.lower()) >= 0:
                            found_dev.append(d)
                    elif d == wanted_dev:
                        found_dev.append(d)
        if len(found_dev) == 0:
            raise RuntimeError("Ptycho2D: no OpenCL device found for opencl_device=%s, opencl_platform=%s"
                               % (str(opencl_device), str(opencl_platform)))
        # Sort devices according to measured fft speed (fastest first)
        ny, nx = self.iobs.shape[-2:]
        found_dev_speed = [cl_device_fft_speed(d, fft_shape=(16, ny, nx)) for d in found_dev]
        idx = (-np.array(found_dev_speed)).argsort()
        found_dev = [found_dev[i] for i in idx]
        found_dev_speed = [found_dev_speed[i] for i in idx]

        # TODO: for now, deactivate using multiple GPUs as the speed increase is not always there.
        found_dev = [found_dev[0]]
        found_dev_speed = [found_dev_speed[0]]

        ndev = len(found_dev)
        nobs = len(self.iobs)
        # Number of frames per worker: a multiple of 16, larger than nobs / ndev
        n16 = ((nobs // ndev + 1) // 16 + 1) * 16
        vn = list(range(0, nobs, n16))
        if vn[-1] != nobs:
            vn.append(nobs)
        print("Using OpenCL device(s):")
        # TODO: load-balancing between GPUs ?
        # TODO: smarter partitioning of ptycho points (per object area)
        for i in range(len(vn) - 1):
            d = found_dev[i]
            g = found_dev_speed[i]
            print("   Device: %s [platform=%s] [frame range: %d-%d] [FFT speed: %8.2fGflop/s]"%(d.name, d.platform.name, vn[i], vn[i+1], g))
            self._cl_workers.append(CLPtycho2DWorker(d, False, self.iobs[vn[i]:vn[i+1]],
                                                     (self.dy[vn[i]:vn[i+1]], self.dx[vn[i]:vn[i+1]]), self.probe, self.obj, nobs, mask=self.mask,
                                                     background=self.background, pixel_size_object=self.pixel_size_object, lambdaz=self.lambdaz))
            self._cl_workers[-1].setDaemon(True)
            self._cl_workers[-1].start()
        # Wait until every worker has finished its OpenCL initialisation
        for w in self._cl_workers:
            w.is_init.wait()

    def _rescale_obj_probe(self):
        """
        Rescale object and probe so that they have the same absolute sum
        
        Returns:

        """
        ps = abs(self.probe).sum()
        os = abs(self.obj).sum()
        self.probe *= np.sqrt(os / ps)
        self.obj   *= np.sqrt(ps / os)

    def run_alternating_projection(self, ncycle, update_object=True, update_probe=True, update_background=False, verbose=False, doplot=False,
                                   timing=False):
        """
        Run alternating projections algorithm - object and/or probe are updating each cycle using the amplitude constraint in Fourier space.
        
        Args:
            ncycle: number of iterations
            update_object: if True, will update the object at each cycle
            update_probe: if True, will update the probe at each cycle
            update_background: if True, will update the background at each cycle
            verbose: report progress in the console
            doplot: plot the object and probe
            timing: report timing

        Returns:
            Nothing. object and probe are updated at the end of the optimization
        """
        s = "\n## Beginning Alternating Projections optimisation of: "
        algo_string = "AP"
        if update_object:
            s += " object"
            if self.obj.ndim == 3:
                algo_string += '/%do'%(self.obj.shape[0])
            else:
                algo_string += '/o'
        if update_probe:
            s += " probe"
            if self.probe.ndim == 3:
                algo_string += '/%dp'%(self.probe.shape[0])
            else:
                algo_string += '/p'
        if update_background:
            s += " background"
            algo_string += '/b'
        print(s + "\n")
        if doplot and self.scan_area_obj is None:
            self.calc_scan_area()
        self._rescale_obj_probe()
        single_worker = len(self._cl_workers) == 1
        if update_background and self.background is None:
            self.background = np.zeros(self.iobs.shape[1:], dtype=np.float32)
            print("Adding null background for optimisation.")
        for w in self._cl_workers:
            w.update_object = update_object
            w.update_probe = update_probe
            w.update_background = update_background
            w.work_alone = single_worker
            w.obj = self.obj
            w.probe = self.probe
            w.mask = self.mask
            w.background = self.background
            w.arrays_to_cl()
            # Check if size of objects (number of modes) has not changed
            w.check_init_cl()
        t00 = timeit.default_timer()
        cputime0 = self.history['cputime'].last_value()
        if cputime0 is None:
            cputime0 = 0
        cycle0 = self.history.last_cycle
        for i in range(1, ncycle+1):
            t0 = timeit.default_timer()

            verb = False
            if verbose:
                if i % verbose == 0 or i ==ncycle or i == 1:
                    verb = True
                    llk = 0
            for w in self._cl_workers:
                w.calc_llk = verb
                w.cycle = i
                # Alternate updating probe and background if both are set to be updated ?
                #if update_probe and update_background:
                #    if i%2 == 0:
                #        w.update_probe = True
                #        w.update_background = False
                #    else:
                #        w.update_probe = False
                #        w.update_background = True
                w.job_f = w.projection_cycle
                w.event_finished.clear()
                w.event_start.set()
            for w in self._cl_workers:
                w.event_finished.wait()
                if verb:
                    llk += w.llk
                if not(single_worker):
                    if w == self._cl_workers[0]:
                        if update_object:
                            self.obj_new = w.obj_new
                            obj_norm = w.obj_new_norm
                        if update_probe:
                            self.probe = w.probe_new
                            probe_norm = w.probe_new_norm
                        if update_background:
                            self.background_d   = w.background_d
                            self.background_d2  = w.background_d2
                            self.background_dz2 = w.background_dz2
                            self.background_z2  = w.background_z2
                    else:
                        if update_object:
                            self.obj_new += w.obj_new
                            obj_norm += w.obj_new_norm
                        if update_probe:
                            self.probe += w.probe_new
                            probe_norm += w.probe_new_norm
                        if update_background:
                            self.background_d   += w.background_d
                            self.background_d2  += w.background_d2
                            self.background_dz2 += w.background_dz2
                            self.background_z2  += w.background_z2
                elif verb:
                    w.arrays_from_cl()
                    self.obj = w.obj
                    self.probe = w.probe
            if not(single_worker):
                # Use first worker to finish probe and object update
                w0 = self._cl_workers[0]
                if update_object:
                    w0.obj_new = self.obj_new
                    w0.obj_new_norm = obj_norm
                if update_probe:
                    w0.probe = self.probe
                    w0.probe_new_norm = probe_norm
                w0.projection_merge_update() # Is that thread-safe ?? Apparently
                # Careful here: do not copy data, but make sure no worker overwrites the obj and probe arrays !
                if update_object:
                    self.obj = w0.obj
                if update_probe:
                    self.probe = w0.probe
                if update_background:
                    self.background = w0.background
                for w in self._cl_workers[1:]:
                    if update_object:
                        w.obj = self.obj
                    if update_probe:
                        w.probe = self.probe
                    if update_background:  # temporary background has origin at 0, instead of center of detector
                        w.background = self.background
                if verb:
                    self.obj = w0.obj.copy()
                    self.probe = w0.probe.copy()
                    if update_background:
                        self.background = w0.background.copy()

            if verb:
                self.llk = llk
                self.llk_normalized = (self.llk - self._llk_offset) / self._llk_norm
                dt= timeit.default_timer() - t0
                dt0= timeit.default_timer() - t00
                print("%s, cycle=%5d/%d: LLK= %6.4e [dt/cycle = %6.3fs]"%(algo_string,i, ncycle, self.llk_normalized, dt))
                self.history.add(cycle0+i, llk=self.llk_normalized, cputime=cputime0+dt0, walltime=time.time(), algorithm=algo_string)
                if doplot:
                    show_obj_probe(self.obj, self.probe, stit="AP cycle %3d: <LLK>= %8.4e" % (i, self.llk_normalized),
                                   pixel_size_object=self.pixel_size_object, scan_area_obj=self.scan_area_obj,
                                   scan_area_probe=self.scan_area_probe, scan_pos=self.scan_area_points)

        if single_worker:
            self._cl_workers[0].arrays_from_cl()
            self.obj = self._cl_workers[0].obj
            self.probe = self._cl_workers[0].probe
            if update_background:  # FInal background has origin at center (like observed data)
                self.background = fftshift(self._cl_workers[0].background)

    def run_ml_poisson(self, ncycle, reg_fac=None, update_object=True, update_probe=True, update_background=False, verbose=False, doplot=False,
                       timing=False):
        """
        Maximum likelihood optimization, using Poisson noise and a conjugate gradient method.
        
        Args:
            ncycle: number of cycles of optimization
            reg_fac: regularization factor, which will smooth the object and/or probe during minimization. Typical value to use = 1e-2
            updateObject: if True, will update the object
            updateProbe: if True, will update the probe
            verbose: will report the current Chi^2 or log(likelihood) every int(verbose) cycle
            doplot: if True, will update the plot of the object and probe every int(verbose) cycles

        Returns: nothing - self.obj and self.probe are updated with new values

        """
        s = "\n## Beginning Maximum Likelihood Conjugate Gradient optimisation of:"
        algo_string = "ML"
        if update_object:
            s += " object"
            if self.obj.ndim == 3:
                algo_string += '/%do'%(self.obj.shape[0])
            else:
                algo_string += '/o'
        if reg_fac is not None:
            s += " (regularization=%g)"%reg_fac
            algo_string += '/reg'
        if update_probe:
            s += " probe"
            if self.probe.ndim == 3:
                algo_string += '/%dp'%(self.probe.shape[0])
            else:
                algo_string += '/p'
        if update_background:
            s += " background"
            algo_string += '/b'
        print(s + "\n")
        if doplot and self.scan_area_obj is None:
            self.calc_scan_area()
        self._rescale_obj_probe()
        nobs = len(self.iobs)
        ny, nx = self.probe.shape[-2:]
        nyo, nxo = self.obj.shape[-2:]
        if update_background and self.background is None:
            self.background = np.zeros((ny,nx), dtype=np.float32)
            print("Adding null background for optimisation.")
        # Regularization scaling factor (Thibault & Guizar-Sicairos 2012)
        reg_fac_obj, reg_fac_probe = None, None
        if reg_fac is not None:
            Nm = nx*ny*nobs
            Nobj0 = self.obj.shape[-2]*self.obj.shape[-1]
            Nprobe0 = self.probe.shape[-2]*self.probe.shape[-1]
            if False:
                # Try to estimate the actual number of pixels covered by the object and probe ?
                if len(self.obj.shape)==3:
                    tmp = abs(self.obj).sum(axis=0)
                else:
                    tmp = abs(self.obj)
                tmp = medfilt2d(tmp, 3)
                Nobj = (tmp>(tmp.max()*0.3)).astype(np.int32).sum()
                if len(self.probe.shape)==3:
                    tmp = abs(self.probe).sum(axis=0)
                else:
                    tmp = abs(self.probe)
                tmp = medfilt2d(tmp, 3)
                Nprobe = (tmp>(tmp.max()*0.3)).astype(np.int32).sum()
            else:
                Nobj = Nobj0
                Nprobe = Nprobe0
            # Total number of photons
            Nph = self.iobs.sum()
            reg_fac_obj   = reg_fac / (8 * Nobj ** 2   / (Nm * Nph))
            # TODO : test probe regularization
            #reg_fac_probe = reg_fac / (8 * Nprobe ** 2 / (Nm * Nph))
            reg_fac_probe = 0
            if True:
                print("Regularization factors: reg_fac/K = %8.4f [object, %d/%d pixels], %8.4f [probe, %d/%d pixels]"%
                      (reg_fac_obj, Nobj, Nobj0,reg_fac_probe, Nprobe, Nprobe0))

        t00 = timeit.default_timer()
        cputime0 = self.history['cputime'].last_value()
        if cputime0 is None:
            cputime0 = 0
        cycle0 = self.history.last_cycle

        single_worker = True #len(self._cl_workers) == 1
        for w in self._cl_workers:
            w.update_object = update_object
            w.update_probe = update_probe
            w.update_background = update_background
            w.work_alone = single_worker
            w.obj = self.obj
            w.probe = self.probe
            w.mask = self.mask
            w.background = self.background
            w.reg_fac = reg_fac
            w.reg_fac_obj = reg_fac_obj
            w.reg_fac_probe = reg_fac_probe
            w.beta = None # Beta is None during the first cycle
            w.arrays_to_cl()
            # Check if size of objects (number of modes) has not changed
            w.check_init_cl()
        if single_worker:
            w = self._cl_workers[0]
            for i in range(1, ncycle+1):
                t0 = timeit.default_timer()

                verb = False
                if verbose:
                    if i % verbose == 0 or i == ncycle or i == 1:
                        verb = True
                        llk = 0
                w.calc_llk = verb
                w.cycle = i
                w.job_f = w.ml_poisson_cycle_single_worker
                w.event_finished.clear()
                w.event_start.set()
                w.event_finished.wait()
                if verb:
                    self.llk = w.llk
                    self.llk_normalized = (self.llk - self._llk_offset) / self._llk_norm
                    dt = timeit.default_timer() - t0
                    dt0 = timeit.default_timer() - t00
                    print("%s, cycle=%5d/%d: LLK= %6.4e [dt/cycle = %6.3fs]" % (algo_string, i, ncycle, self.llk_normalized, dt))
                    self.history.add(cycle0 + i, llk=self.llk_normalized, cputime=cputime0 + dt0, walltime=time.time(), algorithm=algo_string)
                    if doplot:
                        w.arrays_from_cl()
                        self.obj = w.obj.copy()
                        self.probe = w.probe.copy()
                        show_obj_probe(self.obj, self.probe, stit="ML cycle %3d: <LLK>= %8.4e" % (i, self.llk_normalized), pixel_size_object=self.pixel_size_object,
                                       scan_area_obj=self.scan_area_obj, scan_area_probe=self.scan_area_probe, scan_pos=self.scan_area_points)
            w.arrays_from_cl()
            self.obj = w.obj
            self.probe = w.probe
        else:
            # TODO
            for i in range(ncycle):
                t0 = timeit.default_timer()

                verb = False
                if verbose:
                    if i % verbose == 0 or i == ncycle or i == 1:
                        verb = True
                        llk = 0
                for w in self._cl_workers:
                    w.calc_llk = verb
                    w.cycle = i
                    w.job_f = w.ml_poisson_gradient
                    w.event_finished.clear()
                    w.event_start.set()

    def run_difference_map(self, ncycle, update_object=True, update_probe=True, verbose=False, doplot=False, timing=False):
        """
        Difference map optimization.

        Only the first OpenCL worker runs the cycles (multi-worker DM is not implemented yet), but every worker
        receives the current object, probe, mask and background. The background is not optimised by this algorithm.

        Args:
            ncycle: number of cycles of optimization
            update_object: if True, will update the object
            update_probe: if True, will update the probe
            verbose: if non-zero, report the log-likelihood on the first and last cycles, and every int(verbose) cycles
            doplot: if True, will update the plot of the object and probe each time the log-likelihood is reported
            timing: currently unused

        Returns: nothing - self.obj and self.probe are updated with new values
        """
        s = "\n## Beginning Difference Map optimisation of:"
        algo_string = "DM"
        if update_object:
            s += " object"
            if self.obj.ndim == 3:
                algo_string += '/%do' % (self.obj.shape[0])
            else:
                algo_string += '/o'
        if update_probe:
            s += " probe"
            if self.probe.ndim == 3:
                algo_string += '/%dp' % (self.probe.shape[0])
            else:
                algo_string += '/p'
        # Background optimisation is not available with DM (no update_background option)
        print(s + "\n")
        if doplot and self.scan_area_obj is None:
            self.calc_scan_area()
        self._rescale_obj_probe()

        t00 = timeit.default_timer()
        cputime0 = self.history['cputime'].last_value()
        if cputime0 is None:
            cputime0 = 0
        cycle0 = self.history.last_cycle

        single_worker = True  # len(self._cl_workers) == 1
        for w in self._cl_workers:
            w.update_object = update_object
            w.update_probe = update_probe
            # DM never optimises the background: reset the flag, which a previous AP/ML run may have left set to True
            w.update_background = False
            w.work_alone = single_worker
            w.obj = self.obj
            w.probe = self.probe
            w.mask = self.mask
            w.background = self.background
            w.arrays_to_cl()
            # Check if size of objects (number of modes) has not changed
            w.check_init_cl()

        if not single_worker:
            # TODO: multi-worker DM optimisation is not implemented. Fail explicitly rather than starting
            # worker jobs that are never waited for.
            raise NotImplementedError("Multi-worker DM optimisation is not implemented")

        w = self._cl_workers[0]
        for i in range(1, ncycle + 1):
            t0 = timeit.default_timer()
            # Compute & report the log-likelihood on the first and last cycles, and every int(verbose) cycles
            verb = bool(verbose) and (i % verbose == 0 or i == ncycle or i == 1)
            w.calc_llk = verb
            w.cycle = i
            w.job_f = w.dm_cycle_single_worker
            # Hand one cycle to the worker and block until it signals completion
            w.event_finished.clear()
            w.event_start.set()
            w.event_finished.wait()
            if verb:
                self.llk = w.llk
                self.llk_normalized = (self.llk - self._llk_offset) / self._llk_norm
                dt = timeit.default_timer() - t0
                dt0 = timeit.default_timer() - t00
                print("%s, cycle=%5d/%d: LLK= %6.4e [dt/cycle = %6.3fs]" % (algo_string, i, ncycle, self.llk_normalized, dt))
                self.history.add(cycle0 + i, llk=self.llk_normalized, cputime=cputime0 + dt0, walltime=time.time(),
                                 algorithm=algo_string)
                if doplot:
                    w.arrays_from_cl()
                    self.obj = w.obj.copy()
                    self.probe = w.probe.copy()
                    show_obj_probe(self.obj, self.probe, stit="DM cycle %3d: <LLK>= %8.4e" % (i, self.llk_normalized),
                                   pixel_size_object=self.pixel_size_object, scan_area_obj=self.scan_area_obj,
                                   scan_area_probe=self.scan_area_probe, scan_pos=self.scan_area_points)
        # Retrieve the final object and probe from the worker
        w.arrays_from_cl()
        self.obj = w.obj
        self.probe = w.probe

    def optimize_positions_register(self, upsampling=1, verbose=False, doplot=False, quiver_scale=None, timing=False):
        """
        Optimize the probe positions using image registration.

        Not implemented yet: this method is a placeholder which does nothing and returns None,
        so the probe positions are left unchanged.

        Args:
            upsampling: if 1, only integer pixels shifts are calculated. Otherwise, subpixel resolution = 1/upsampling is used
            verbose: report progress
            doplot: plot the change in positions
            quiver_scale: scale to plot the change in positions. If None, an automatic scaling will be used
            timing: report how much time is taken

        Returns:
            None. Once implemented, self.pos_x and self.pos_y are intended to be updated.
        """
        # TODO: implement position refinement by image registration
        return None

    def _llk_poisson(self):
        """
        Compute the Poisson log-likelihood by comparing the calculated intensities and the observed ones,
        for the current group of images. Intended to assume self._cl_psi currently holds the images
        calculated by self._calc_forward(i).

        Not implemented yet: this method is a placeholder and currently returns None.

        Returns:
            None for now (intended: the log-likelihood value)
        """
        # TODO: implement the Poisson log-likelihood for the current group of images
        return None

    def llk_poisson(self):
        """
        Compute the Poisson log-likelihood by comparing calculated and observed intensities. This is intended
        to compute _all_ scattered images to make this calculation (should not be called when running an algorithm).

        Not implemented yet: this method is a placeholder and currently returns None.

        Returns:
            None for now (intended: the Poisson log-likelihood from calculated and observed intensities)
        """
        # TODO: implement the full Poisson log-likelihood over all frames
        return None
    def llk_regularization(self):
        """
        Compute the (negative) log-likelihood value arising from regularization of the object and/or probe.

        Not implemented yet: this method is a placeholder and currently returns None.

        Returns:
            None for now (intended: the regularization llk)
        """
        # TODO: implement the regularization log-likelihood term
        return None
    def _llk_iobs_gaussian(self):
        """
        Compute a log-likelihood assuming Gaussian noise on the observed intensities (as the method name
        suggests - TODO confirm the intended noise model), by comparing the calculated intensities and the
        observed ones for the current group of images. Intended to assume self._cl_psi currently holds the
        images calculated by self._calc_forward(i).

        Not implemented yet: this method is a placeholder and currently returns None.

        Returns:
            None for now (intended: the log-likelihood value)
        """
        # TODO: implement the Gaussian-noise log-likelihood for the current group of images
        return None


