#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

from __future__ import division

import warnings
from sys import stdout
import os
import sys
import time
import copy
import traceback
from PIL import Image
import numpy as np
from scipy.ndimage import zoom
from scipy.ndimage.interpolation import rotate
from scipy.signal import fftconvolve, medfilt2d
from scipy.io import loadmat
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib import pyplot as plt

try:
    from scipy.spatial import ConvexHull
except ImportError:
    ConvexHull = None
try:
    import hdf5plugin
except:
    pass
import h5py

import fabio
from pynx.utils import plot_utils, phase
from pynx import wavefront
from pynx.wavefront import PropagateNearField as PropagateNearField_Wavefront
from pynx.version import __version__
from pynx.utils.math import smaller_primes
from pynx.ptycho import *
from pynx.ptycho import simulation, shape

from pynx.utils.array import rebin, center_array_2d
from pynx.ptycho import analysis

# Generic help text, to be completed with beamline/instrument-specific help text

helptext_generic = """
generic (not beamline-specific) command-line arguments: (all keywords are case-insensitive)
    scan=56: scan number (e.g. in specfile) [mandatory, unless cxifile is used as input].
             Alternatively a list or range of scans can be given:
                scan=12,23,45 or scan="range(12,25)" (note the necessary quotes when using range)

    maxframe=128: limit the total number of frames used to the first N
                  [default=None, all frames are used]

    maxsize=256: the frames are automatically cropped to the largest possible size while keeping
                 the center of gravity of diffraction in the frame center. Use this to limit the
                 maximum frame size, for tests or limiting memory use [default=512]

    moduloframe=n1,n2: instead of using all sequential frames of the scan, only take one in n1.
                       if n2<n1 is also given, then the frame numbers taken will be the n for which
                       n % n1 == n2. This can be used to perform two reconstructions with half of
                       the frames, and then analyse the resolution using Fourier ring correlation.
                       If both moduloframe and maxframe are given, the total number of frames taken
                       is still maxframe.
                       [default: take all frames]

    algorithm=ML**50,DM**100,probe=1: algorithm for the optimization: [default='ML**50,DM**100']

        The algorithm used is:
        - divided in 'steps' separated by commas, e.g: ML**50,DM**100
        - interpreted from right to left, as a mathematical operator to an object on the
          right-hand side
        - should not contain any space, unless it is given between quotes ('')

        - the first type of commands can change basic parameters or perform some analysis:
          - probe=1 or 0: activate or deactivate the probe optimisation (by default only the
            object is optimised)
          - object=1 or 0: activate or deactivate the object optimisation
          - background=1 or 0: activate or deactivate the background optimisation
            (only works with AP)
          - nbprobe=3: change the number of modes for the probe (can go up or down)
          - regularization=1e-4: setting the regularization parameter for the object, to penalize
                                 local variations in ML runs and smooth the solution
          - ortho: will perform orthogonalisation of the probe modes. The modes are sorted by
            decreasing intensity.
          - analysis: perform an analysis of the probe (propagation, modes). Useful combined
            with 'saveplot' to save the analysis plots

        - the second type are operators which will be applied to the Ptycho object:
          - AP: alternate projections. Slow but converging algorithm
          - DM: difference map. Fast early convergence, oscillating after.
          - ML: maximum likelihood conjugate gradient (Poisson-noise). Robust, converging,
            for final optimization.
          These operators can be combined mathematically, e.g.:
          - DM**100: corresponds to 100 cycles of difference map
          - ML**40*DM**100: 100 cycles of DM followed by 40 cycles of ML (note the order)

        Example algorithm chains:
          - algorithm=ML**40,DM**100,probe=1: activate probe optimisation,
            then 100 DM and 40 ML (quick)
          - algorithm=ML**100,DM**200,nbprobe=3,ML**40,DM**100,probe=1,DM**100: first 100 DM with
            object update only, then 100 DM and 40 ML also updating the probe, then use 3 probe
            modes and do 200 DM followed by 100 ML
          - algorithm=ML**100*AP**200*DM**200,probe=1: 200 DM then 200 AP then 100 ML (one step)
          - algorithm='(ML**10*DM**20)**5,probe=1':
            repeat 5 times [20 cycles of DM followed by 10 cycles of ML]
            (note the quotes necessary for the parenthesis)

    nbrun=10: number of optimizations to perform [default=1]

    run0=10: number for the first run (can be used to overwrite previous run results)
             [default: after previous results or 1]

    liveplot: liveplot during optimization [default: no display]

    saveplot: will save plot at the end of the optimization (png file) [default= not saved].
              Optionally this can also specify if only the object phase should be plotted, e.g.:
              saveplot=object_phase: will display the object phase
              saveplot=object_rgba: will use RGBA to display both amplitude and phase.

    saveprefix=ResultsScan%04d/Run%04d: prefix to save the optimized object and probe
              (as a .cxi or .npz file) and optionally image (png) [default='ResultsScan%04d/Run%04d']

    output_format='cxi': choose the output format for the final object and probe.
                         Possible choices: 'cxi', 'npz'
                         [Default='cxi']

    save=all: either 'final' or 'all'. Using 'all' will activate saving after each optimization
              step (ptycho, ML) of the algorithm in any given run [default=final]

    load=Results0057/Run0001.cxi (or .npz): load object and probe from previous optimization. Note that the
               object and probe will be scaled if the number of pixels is different for the probe.
              [default: start from a random object, simulate probe]

    loadprobe=Results0057/Run0001.npz (or .cxi): load only probe from previous optimization
                                       [default: simulate probe]

    loadpixelsize=8.6e-9: specify the pixel size (in meters) from a loaded probe
                          (and possibly object). If the pixel size is different,
                          the loaded arrays will be scaled to match the new pixel size.
                          [default: when loading previous files, object/probe pixel size is
                          calculated from the size of the probe array, assuming same detector
                          distance and pixel size]

    probe=focus,60e-6x200e-6,0.09: define the starting probe, either using:
                                  focus,60e-6x200e-6,0.09: slits size (horizontal x vertical),
                                                           focal distance (all in meters)
                                  focus,200e-6,0.09: radius of the circular aperture,
                                                     focal distance (all in meters)
                                  gaussian,100e-9x200e-9: gaussian type with horizontal x vertical
                                                          FWHM, both given in meters
                                  disc,100e-9: disc-shape, with diameter given in meters
                                  [mandatory, ignored if 'load' or 'loadprobe' is used]

    defocus=1e-6: defocused position (+: towards detector). The probe is propagated by this distance
                  before being used. This is true both for calculated probes (using 'probe=...')
                  and for probes loaded from a previous file.

    rotate=30: rotate the probe (either simulated or loaded) by X degrees [default: no rotation]

    object=random,0.9,1,0,6: specify the original object values. The object will be initialised
                             over the entire area using random values: random,0-1,0-6.28 : random
                             amplitudes between 0 and 1, random phases between 0 and 6.28.
                             For high energy small-angle ptycho (i.e. high transmission),
                             recommended value is: random,0.9,1,0,0
                             [default: random,0,1,0,6.28]

    verbose=20: print evolution of llk (and display plot if 'liveplot' is set) every N cycles
                [default=50]

    data2cxi: if set, the raw data will be saved in CXI format (http://cxidb.org/cxi.html),
              with all the required information for a ptychography experiment (energy, detector
              distance, scan number, translation axis are all required). If 'data2cxi=crop'
              is used, the data will be saved after centering and cropping (default is to save
              the raw data). If this keyword is present, the processing stops after exporting the data.

    mask= or loadmask=mask.npy: the mask to be used for detector data, which should have the same 2D
                       shape as the raw detector data.
                       This should be a boolean or integer array with good pixels=0 and bad ones>0
                       (values are expected to follow the CXI convention)
                       Acceptable formats:
                       - mask.npy, mask.npz (the first data array will be used)
                       - mask.edf or mask.edf.gz (a single 2D array is expected)
                       - mask.tif or mask.tiff (pixels with a value >0 are masked)
                       - "mask.h5:/entry_1/path/to/mask" hdf5 format with the full path to the
                         2D array. 'hdf5' is also accepted as extension.
                       - 'maxipix': if this special name is entered, the masked pixels will be rows
                         and columns multiples of 258+/-3

    roi=xmin,xmax,ymin,ymax: the region-of-interest to be used for actual inversion. The area is taken
                             with python conventions, i.e. pixels with indices xmin<= x < xmax and
                             ymin<= y < ymax.
                             Additionally, the shape of the area must be square, and
                             n=xmax-xmin=ymax-ymin must also be a suitable integer number
                             for OpenCL or CUDA FFT, i.e. it must be a multiple of 2 and the largest number in
                             its prime factor decomposition must be less or equal to the largest value
                             acceptable by clFFT (<=13 as of November 2016) or cuFFT (<=7).
                             If n does not fulfill these constraints,
                             it will be reduced using the largest possible integer smaller than n.
                             This option supersedes 'maxsize' unless roi='auto'.
                             Other possible values:
                             - 'auto': automatically selects the roi from the center of mass
                                       and the maximum possible size. [default]
                             - 'all': use the full, uncentered frames. Only useful for pre-processed
                                      data. Cropping may still be performed to get a square and
                                      FFT-friendly size.

    rebin=2: the experimental images can be rebinned (i.e. a group of n x n pixels is replaced by a
             single one whose intensity is equal to the sum of all the pixels). This 'rebin' is
             performed last: the ROI, mask, background, pixel size should all correspond to
             full (non-rebinned) frames.
             [default: no rebin]

    autocenter=0: by default, the object and probe are re-centered automatically after each
                  optimisation step, to avoid drifts. This can be used to deactivate this behaviour
                  [default=True]

    detector_orientation=1,0,0: three flags which, if True, will do in this order:
                                array transpose (x/y exchange), flipud, fliplr [default: no change]
                                The changes also apply to the mask, flatfield and dark.

    xy=y,x: order and expression to be used for the XY positions (e.g. '-x,y',...). Mathematical
            operations can also be used, e.g.: xy=0.5*x+0.732*y,0.732*x-0.5*y
            [default: None- may be superseded in some scripts e.g. for ptypy]

    flatfield=flat.npy: the flatfield correction to be applied to the detector data. The array must
                        have the same shape as the frames, which will be multiplied by this
                        correction.
                        Acceptable formats:
                        - flat.npy, flat.npz (the first data array will be used)
                        - flat.edf or flat.edf.gz (a single 2D array is expected)
                        - "flat.h5:/entry_1/path/to/flat" hdf5 format with the full path to the
                          2D array. 'hdf5' is also accepted as extension.
                        - flat.mat: from a matlab file. The first array with more than 1000
                          elements is loaded
                       [default: no flatfield correction is applied]

    dark=dark.npy: the dark correction to be subtracted from the detector data. The array must have
                       the same shape as the frames, from which this correction will be subtracted.
                       Acceptable formats:
                       - dark.npy, dark.npz (the first data array will be used)
                       - dark.edf or dark.edf.gz (a single 2D array is expected)
                       - "dark.h5:/entry_1/path/to/dark" hdf5 format with the full path to the
                         2D array. 'hdf5' is also accepted as extension.
                       - dark.mat: from a matlab file. The first array with more than 1000
                         elements is loaded
                       [default: no dark correction is applied]

    orientation_round_robin: will test all possible combinations of xy and detector_orientation
                             to find the correct detector configuration.
"""

# This must be defined in beamline/instrument-specific scripts
# helptext_beamline = ""

params_generic = {'scan': None, 'algorithm': '100DM,50ML', 'nbrun': 1, 'run0': None, 'liveplot': False,
                  'saveplot': False, 'saveprefix': None, 'livescan': False, 'load': None, 'loadprobe': None,
                  'probe': None, 'defocus': None, 'gpu': None, 'regularization': None, 'save': 'final',
                  'loadpixelsize': None, 'rotate': None, 'maxframe': None, 'maxsize': 512, 'output_format': 'cxi',
                  'object': 'random,0,1,0,6.28', 'verbose': 50, 'moduloframe': None, 'data2cxi': False, 'nrj': None,
                  'pixelsize': None, 'instrument': None, 'epoch': time.time(), 'cxifile': None,
                  'detector_orientation': None, 'xy': None, 'livescan': False, 'loadmask': None, 'roi': 'auto',
                  'rebin': None, 'autocenter': True, 'data': None, 'flatfield': None, 'dark': None,
                  'orientation_round_robin': False, 'fig_num': 100, 'profiling': False}


class PtychoRunnerException(Exception):
    """Error raised by the ptycho runner, e.g. for inconsistent input arrays or GPU initialisation failures."""


class PtychoRunnerScan(object):
    """
    Abstract class to handle ptychographic data. Must be derived to be used.
    """

    def __init__(self, params, scan):
        self.params = params
        self.scan = scan
        self.defocus_done = False
        self.raw_data_monitor = None
        self.raw_mask = None  # Original mask (uncropped, etc..)
        self.mask = None  # mask for running algorithm
        self.rebinf = 1
        self.p = None  # Ptycho object
        self.data = None  # PtychoData object
        self.flatfield = None
        self.dark = None
        self.iobs = None
        self.raw_x, self.raw_y, self.x, self.y = None, None, None, None

        # Default parameters for optimization
        self.update_object = True
        self.update_probe = False
        self.update_background = False

    def load_data(self):
        """
        Loads data, using beamline-specific parameters. Abstract function, must be derived
        
        Returns:

        """
        print(
            "Why are you calling this function ? It should be superseded in a child class for each instrument/beamline")

    def prepare_processing_unit(self):
        """
        Select the GPU processing unit (CUDA or OpenCL) used for the optimisation.

        The GPU is chosen using params['gpu'] (a GPU name, or None to let the processing unit decide).

        Returns:
            Nothing. Sets self.processing_unit to default_processing_unit.
        Raises:
            PtychoRunnerException: if the GPU initialisation fails, or if only the CPU is available.
        """
        # TODO: check processing unit is actually prepared, and max_prime_fft is available
        gpu_name = self.params['gpu']
        s = "Ptycho runner: preparing processing unit"
        if gpu_name is not None:
            s += " [given GPU name: %s]" % gpu_name
        print(s)
        try:
            default_processing_unit.select_gpu(gpu_name=gpu_name)
        except Exception as ex:
            # Keep the original error text in the message, to help diagnose the installation
            s0 = "\n  original error: " + str(ex)
            if gpu_name is not None:
                s = "Failed initialising GPU. Please check GPU name [%s] or CUDA/OpenCL installation" % gpu_name
            else:
                s = "Failed initialising GPU. Please check GPU name or CUDA/OpenCL installation"
            raise PtychoRunnerException(s + s0)
        if default_processing_unit.pu_language == 'cpu':
            raise PtychoRunnerException("CUDA or OpenCL or GPU not available - you need a GPU to use pynx.ptycho !")
        self.processing_unit = default_processing_unit

    def load_data_post_process(self):
        """
        Applies some post-processing to the input data, according to parameters. Also loads the mask.
        User-supplied mask is loaded if necessary.
        
        This must be called at the end of load_data()
        :return: 
        """
        self._init_mask(self.raw_data[0].shape)
        if self.raw_mask.shape != self.raw_data[0].shape:
            raise PtychoRunnerException("Mask and raw data shape are not identical !")

        self._load_flat_field()
        if self.flatfield is not None:
            if self.flatfield.shape != self.raw_data[0].shape:
                raise PtychoRunnerException("flatfield and raw data shapes are not identical !")

        self._load_dark()
        if self.dark is not None:
            if self.dark.shape != self.raw_data[0].shape:
                raise PtychoRunnerException("dark and raw data shapes are not identical !")

        # Store original x,y in case we use self.params['xy']
        self.raw_x, self.raw_y = self.x, self.y

    def _init_mask(self, shape):
        """
        Load mask if the corresponding parameter has been set, or just initialize an array of 0.
        This is called after raw data has been loaded by center_crop_data()
        Note that a mask may already exist if pixels were flagged by the detector
        
        Args:
            shape: the 2D shape of the raw data
        Returns:
            Nothing
        """
        mask_user = None
        if self.params['loadmask'] is not None:
            if self.params['loadmask'].find('.h5:') > 0 or self.params['loadmask'].find('.hdf5:') > 0:
                # hdf5 file with path to mask
                s = self.params['loadmask'].split(':')
                h5f = h5py.File(s[0], 'r')
                if s[1] not in h5f:
                    raise PtychoRunnerException(
                        "Error extracting mask from hdf5file: path %s not found in %s" % (s[1], s[0]))
                mask_user = h5f[s[1]].value
                h5f.close()
            elif self.params['loadmask'] == 'maxipix':
                mask_user = np.zeros(shape, dtype=np.int8)
                ny, nx = shape
                for i in range(258, ny, 258):
                    mask_user[i - 3:i + 3] = 1
                for i in range(258, nx, 258):
                    mask_user[:, i - 3:i + 3] = 1
            else:
                filename = self.params['loadmask']
                ext = os.path.splitext(filename)[-1]
                if ext == '.edf' or ext == '.gz':
                    mask_user = fabio.open(filename).data
                elif ext == '.npy':
                    mask_user = np.load(filename)
                elif ext == '.npz':
                    for v in np.load(filename).items():
                        mask_user = v[1]
                        break
                elif ext == '.tif' or ext == '.tiff':
                    mask_user = np.array(Image.open(filename)) > 0
                else:
                    print(ext)
                    print("What is this mask extension: %s ??" % (ext))
            print("Loaded MASK from: %s with % d pixels masked (%5.3f%%)"
                  % (self.params['loadmask'], mask_user.sum(), mask_user.sum() * 100 / mask_user.size))
        if self.raw_mask is None:
            if mask_user is None:
                self.raw_mask = np.zeros(shape, dtype=np.int8)
            else:
                self.raw_mask = mask_user.astype(np.int8)
        elif mask_user is not None:
            self.raw_mask += mask_user.astype(np.int8)
        s = self.raw_mask.sum()
        if s:
            print("Initialized mask with %d (%6.3f%%) bad pixels" % (s, s * 100 / self.raw_mask.size))

    def _load_flat_field(self):
        """
        Load flat field if the corresponding parameter has been set.

        Returns:
            Nothing
        """
        flatfield = None
        if self.params['flatfield'] is not None:
            if self.params['flatfield'].find('.h5:') > 0 or self.params['flatfield'].find('.hdf5:') > 0:
                # hdf5 file with path to flatfield
                s = self.params['flatfield'].split(':')
                h5f = h5py.File(s[0], 'r')
                if s[1] not in h5f:
                    raise PtychoRunnerException(
                        "Error extracting flatfield from hdf5file: path %s not found in %s" % (s[1], s[0]))
                flatfield = h5f[s[1]].value
                h5f.close()
            else:
                filename = self.params['flatfield']
                ext = os.path.splitext(filename)[-1]
                if ext == '.edf' or ext == '.gz':
                    flatfield = fabio.open(filename).data
                elif ext == '.npy':
                    flatfield = np.load(filename)
                elif ext == '.npz':
                    # Just grab the first array
                    flatfield = np.load(filename).items()[0][1]
                elif ext == '.mat':
                    a = list(loadmat(filename).values())
                    for v in a:
                        if np.size(v) > 1000:
                            # Avoid matlab strings and attributes, and get the array
                            flatfield = np.array(v)
                            break
                else:
                    print(ext)
                    print("What is this flatfield extension: %s ??" % (ext))
            print("Loaded FLATFIELD from: ", self.params['flatfield'])
        if flatfield is not None:
            self.flatfield = flatfield.astype(np.float32)
            self.flatfield /= self.flatfield.mean()

    def _load_dark(self):
        """
        Load dark if the corresponding parameter has been set.

        Returns:
            Nothing
        """
        dark = None
        if self.params['dark'] is not None:
            if self.params['dark'].find('.h5:') > 0 or self.params['dark'].find('.hdf5:') > 0:
                # hdf5 file with path to dark
                s = self.params['dark'].split(':')
                h5f = h5py.File(s[0], 'r')
                if s[1] not in h5f:
                    raise PtychoRunnerException(
                        "Error extracting dark from hdf5file: path %s not found in %s" % (s[1], s[0]))
                dark = h5f[s[1]].value
                h5f.close()
            else:
                filename = self.params['dark']
                ext = os.path.splitext(filename)[-1]
                if ext == '.edf' or ext == '.gz':
                    dark = fabio.open(filename).data
                elif ext == '.npy':
                    dark = np.load(filename)
                elif ext == '.npz':
                    # Just grab the first array
                    dark = np.load(filename).items()[0][1]
                elif ext == '.mat':
                    a = list(loadmat(filename).values())
                    for v in a:
                        if np.size(v) > 1000:
                            # Avoid matlab strings and attributes, and get the array
                            dark = np.array(v)
                            break
                else:
                    print(ext)
                    print("What is this dark extension: %s ??" % (ext))
            print("Loaded DARK from: ", self.params['dark'])
        if dark is not None:
            self.dark = dark.astype(np.float32)

    def center_crop_data(self):
        """
        Center and crop the loaded frames to a square, FFT-friendly size, and apply the detector corrections.

        Once the data has been loaded in self.load_data() (overloaded in child classes), this function can be
        called at the end of load_data. It:
          - evaluates params['xy'] (if given) on the raw positions to compute self.x, self.y
          - applies params['detector_orientation'] (transpose, flipud, fliplr, in this order, cumulatively)
            to the frames, mask, flatfield and dark
          - selects the region of interest according to params['roi'] ('auto': around the center of
            mass of the diffraction, 'all': full frames, or 'xmin,xmax,ymin,ymax')
          - reduces the size so that it is acceptable for the processing unit FFT, and rebins if requested
          - subtracts the dark, multiplies by the flatfield and sets masked pixels to 0

        Returns:
            Nothing. self.iobs, self.mask, self.dsize, self.x, self.y and self.params['roi_actual'] are
            updated. self.raw_data is released (set to None) unless params['algorithm'] is 'manual' or
            params['orientation_round_robin'] is not False.
        Raises:
            PtychoRunnerException: if the rebin factor is larger than the largest prime acceptable for the FFT.
        """
        if self.params['xy'] is not None:
            # TODO: move this to load_data_post_process ?
            # The expression is evaluated with the raw positions available as the local names 'x' and 'y'.
            # NOTE(review): eval() of a user-supplied string - acceptable only for trusted command-line input.
            x, y = self.raw_x, self.raw_y
            self.x, self.y = eval(self.params['xy'])
        else:
            self.x, self.y = self.raw_x, self.raw_y

        raw_data = self.raw_data
        mask = self.raw_mask
        # Neutral values when no flatfield/dark correction was loaded
        flatfield = self.flatfield if self.flatfield is not None else 1
        dark = self.dark if self.dark is not None else 0

        if self.params['detector_orientation'] is not None:
            # TODO: move this to load_data_post_process ?
            # User-supplied change of orientation
            # NOTE(review): eval() of a user-supplied string - acceptable only for trusted command-line input.
            do_transpose, do_flipud, do_fliplr = eval(self.params['detector_orientation'])

            if do_fliplr or do_flipud or do_transpose:
                print("Transpose: %d, flipud: %d, fliplr: %d" % (do_transpose, do_flipud, do_fliplr))

                def orient(a):
                    # The changes are cumulative and act on the last two (y, x) axes, so the same
                    # function works for the stack of frames and for the 2D mask/flatfield/dark
                    if do_transpose:
                        a = np.swapaxes(a, -1, -2)
                    if do_flipud:
                        a = a[..., ::-1, :]
                    if do_fliplr:
                        a = a[..., ::-1]
                    return a

                raw_data = orient(raw_data)
                if mask is not None:
                    mask = orient(mask)
                if self.flatfield is not None:
                    flatfield = orient(flatfield)
                if self.dark is not None:
                    dark = orient(dark)

        raw_data_sum = raw_data.sum(axis=0) - dark * len(raw_data)
        # Find image center
        ny, nx = raw_data_sum.shape
        # Unmodified copy for the liveplot: raw_data_sum is altered in-place below when roi='auto'
        raw_data_sum0 = raw_data_sum.copy()

        # Did user set the ROI to use ?
        if self.params['roi'] == 'auto':
            X, Y = np.meshgrid(np.arange(nx), np.arange(ny))
            if mask is not None:
                raw_data_sum[mask > 0] = 0

            # Try to remove hot pixels, but not if too much intensity is removed
            tmp = medfilt2d(raw_data_sum.astype(np.float32), 3)  # Should remove hot pixels
            if tmp.sum() > 0.3 * raw_data_sum.sum():
                raw_data_sum = tmp

            # Only use the 5% most intense pixels for the center of mass, if there are enough of them
            tmp = raw_data_sum > np.percentile(raw_data_sum, 95)
            if tmp.sum() > 100:
                raw_data_sum *= tmp

            raw_data_sum = np.ma.masked_array(raw_data_sum, mask)
            x0, y0 = (raw_data_sum * X).sum() / raw_data_sum.sum(), (raw_data_sum * Y).sum() / raw_data_sum.sum()
            print("Center of diffraction: X=%6.2f Y=%6.2f" % (x0, y0))
            x0n, y0n = (int(round(x0 + 0.5)), int(round(y0 + 0.5)))

            # Maximum window size
            if self.params['maxsize'] is not None:
                self.dsize = int(min(x0n, y0n, nx - x0n, ny - y0n, self.params['maxsize'] // 2)) * 2
            else:
                self.dsize = int(min(x0n, y0n, nx - x0n, ny - y0n)) * 2
        else:
            if self.params['roi'] == 'all':
                # Full frames: use the 2D frame shape (raw_data itself is a stack of frames)
                xmin, xmax, ymin, ymax = 0, nx, 0, ny
            else:
                vs = self.params['roi'].split(',')
                xmin, xmax, ymin, ymax = int(vs[0]), int(vs[1]), int(vs[2]), int(vs[3])
            x0n = int((xmax + xmin) // 2)
            y0n = int((ymax + ymin) // 2)
            self.dsize = min((xmax - xmin, ymax - ymin))

        # Rebin ?
        if self.params['rebin'] is not None:
            self.rebinf = self.params['rebin']
            self.dsize = self.dsize // self.rebinf

        # Compute acceptable size, depending on cuFFT or clFFT version, with both dimensions even
        prime_fft = self.processing_unit.max_prime_fft
        print("Largest prime number acceptable for FFT size:", prime_fft)
        if self.rebinf > prime_fft:
            raise PtychoRunnerException("Rebin factor (%d) must not be larger than the largest prime number "
                                        "acceptable for FFT size (%d)" % (self.rebinf, prime_fft))
        self.dsize = smaller_primes(self.dsize, prime_fft, [2])

        # Half-width of the crop area, in raw (non-rebinned) pixels
        ds2r = self.dsize // 2 * self.rebinf

        if self.params['liveplot']:
            # Plot crop area and highlight masked pixels
            plt.figure(99)
            plt.clf()
            vmax = np.log10(np.percentile(raw_data_sum0, 99.9))
            v = np.log10(raw_data_sum0 + 1e-6)
            v = v * (mask == 0) + vmax * 1.2 * (mask != 0)
            plt.imshow(v, vmin=0, vmax=vmax * 1.2, cmap=plt.cm.jet)
            plt.plot([x0n - ds2r, x0n + ds2r, x0n + ds2r, x0n - ds2r, x0n - ds2r],
                     [y0n - ds2r, y0n - ds2r, y0n + ds2r, y0n + ds2r, y0n - ds2r], 'r')
            plt.colorbar()
            plt.title("Sum of raw data [log scale 0-99.9%, cutoff=1], crop area and masked pixels")
            plt.xlim(0, nx)
            plt.ylim(0, ny)
            try:
                plt.draw()
                plt.pause(.002)
            except Exception:
                # Best-effort GUI refresh only: ignore backend errors.
                # Don't close window here, Tk + ipython --pylab crashes on this (somehow GUI update out of main loop).
                pass

        nb = len(raw_data)

        self.iobs = np.zeros((nb, self.dsize, self.dsize))
        for jj in range(nb):
            # TODO: dark should not be subtracted (messing with statistics) !
            img = (raw_data[jj] - dark) * flatfield
            img[img < 0] = 0  # Only needed because of dark subtraction...
            crop = img[y0n - ds2r:y0n + ds2r, x0n - ds2r:x0n + ds2r]
            self.iobs[jj] = rebin(crop, self.rebinf) if self.rebinf > 1 else crop

        self.mask = mask[y0n - ds2r:y0n + ds2r, x0n - ds2r:x0n + ds2r]
        if self.rebinf > 1:
            self.mask = rebin(self.mask, self.rebinf)

        # Set masked pixels to 0
        self.iobs *= (self.mask == 0)

        self.params['roi_actual'] = [x0n - ds2r, x0n + ds2r, y0n - ds2r, y0n + ds2r]

        if self.params['algorithm'] != 'manual' and self.params['orientation_round_robin'] is False:
            # Free memory
            self.raw_data = None

    def prepare(self):
        """
        Prepare object and probe.

        Computes the wavelength, the object (direct space) pixel size, the scan positions in pixels and the
        object shape. Then, either:
          - (no 'load' parameter) parses the 'object' parameter into self.obj_init_info, and either parses the
            'probe' parameter into self.probe_init_info, or loads self.probe0 from the 'loadprobe' file. The
            actual object (and probe, if not loaded) are created later, in run().
          - ('load' parameter given) loads self.obj0 and self.probe0 from a previous result file.
        A loaded probe (and object) is zoomed and centre-cropped/zero-padded if its pixel size differs
        (relative tolerance 1e-3) from the current object pixel size.

        Returns: nothing. Sets self.wavelength, self.angle_rad_per_pixel, self.xpix, self.ypix, self.ny, self.nx,
                 self.probe_init_info and, depending on the parameters, self.obj_init_info, self.data_info,
                 self.obj0, self.probe0 (and self.objprobe when 'load' is a .npz file).
        Raises: PtychoRunnerException if the scan range exceeds 8000 pixels (positions probably not in meters),
                or if the starting object type is not understood.
        """
        z = self.params['detectordistance']
        pixelsize = self.params['pixelsize'] * self.rebinf
        self.angle_rad_per_pixel = pixelsize / z
        # 12.3984 keV.Angstrom = h*c: energy in keV -> wavelength in meters
        self.wavelength = 12.3984 / self.params['nrj'] * 1e-10

        # pix size in reciprocal space
        pix_size_reciprocal = 1 / self.wavelength * self.angle_rad_per_pixel

        # pix size in direct space
        pix_size_direct = 1. / pix_size_reciprocal / self.dsize

        print("E=%6.3fkeV, zdetector=%6.3fm, pixel size=%6.2fum, pixel size(object)=%6.1fnm"
              % (self.params['nrj'], z, pixelsize * 1e6, pix_size_direct * 1e9))

        # scan positions in pixels, relative to center. X and Y from scan data, in meters
        self.xpix = (self.x - self.x.mean()) // pix_size_direct
        self.ypix = (self.y - self.y.mean()) // pix_size_direct

        dx = self.xpix.max() - self.xpix.min()
        dy = self.ypix.max() - self.ypix.min()
        if dx > 8000 or dy > 8000:
            raise PtychoRunnerException(
                "Scan range is to large: (dx, dy) = (%d, %d) in pixels !! Are scan positions in meters and not microns ?" % (
                    dx, dy))

        # Compute the size of the reconstructed object (obj)
        self.ny, self.nx = shape.calc_obj_shape(self.ypix, self.xpix, (self.dsize, self.dsize))
        self.probe_init_info = None
        oldpixelsize = None  # only needed if we load an old object or probe
        if self.params['load'] is None:
            # Initial object parameters
            s = self.params['object'].split(',')
            if s[0].lower() == 'random':
                a0, a1, p0, p1 = float(s[1]), float(s[2]), float(s[3]), float(s[4])
                self.obj_init_info = {'type': 'random', 'range': (a0, a1, p0, p1), 'shape': (self.ny, self.nx)}
                print("Using random object type with amplitude range: %5.2f-%5.2f and phase range: %5.2f-%5.2f" % (
                    a0, a1, p0, p1))
            else:
                raise PtychoRunnerException("Could not understand starting object type: %s" % self.params['object'])
            self.data_info = {'wavelength': self.wavelength, 'detector_dist': z, 'detector_pixel_size': pixelsize}

            if self.params['loadprobe'] is None:
                self.probe_init_info = self._prepare_probe_init_info(pix_size_direct)
            else:
                self.params['probe'] = None
                filename = self.params['loadprobe']
                if os.path.splitext(filename)[-1] == '.npz':
                    with np.load(filename) as tmp:
                        self.probe0 = tmp['probe']
                        oldpixelsize = self._prepare_npz_pixel_size(tmp, self.probe0, pix_size_direct)
                else:
                    self.probe0, _, oldpixelsize = self._prepare_load_h5(filename, read_object=False)
        else:
            # TODO: also import background if available
            self.params['loadprobe'] = None
            self.params['probe'] = None
            filename = self.params['load']
            if os.path.splitext(filename)[-1] == '.npz':
                self.objprobe = np.load(filename)
                self.obj0 = self.objprobe['obj']
                self.probe0 = self.objprobe['probe']
                oldpixelsize = self._prepare_npz_pixel_size(self.objprobe, self.probe0, pix_size_direct)
            else:
                self.probe0, self.obj0, oldpixelsize = self._prepare_load_h5(filename, read_object=True)

        # We loaded a probe and/or object, need to scale ?
        # NB: np.isclose returns a numpy bool, so it must be tested with 'not', not with 'is False'
        if oldpixelsize is not None and not np.isclose(oldpixelsize, pix_size_direct, 1e-3, 0):
            scale = oldpixelsize / pix_size_direct
            print("RESCALING by factor %5.2f, pixel sizes: %5e -> %5e" % (scale, oldpixelsize, pix_size_direct))
            self._prepare_rescale(scale)

    def _prepare_probe_init_info(self, pix_size_direct):
        """
        Parse self.params['probe'] into the probe initialisation dictionary later given to simulation.Simulation.

        Accepted forms (sizes in the same unit as pix_size_direct, i.e. meters):
          - 'focus,<aperture>,<focal_length>' with <aperture> either '<d>' or '<h>x<v>'
          - 'disc,<diameter>'
          - 'gaussian,<fwhm_h>x<fwhm_v>' or 'gauss,...' (a single '<fwhm>' is used for both directions)
          - '<h>x<v>,<focal_length>': focused rectangular opening (DEPRECATED, use 'focus')

        :param pix_size_direct: object pixel size, used to convert sizes to pixels
        :return: the probe initialisation dictionary
        """
        shape2d = (self.dsize, self.dsize)
        s = self.params['probe'].split(',')
        if s[0] == 'focus':
            s6 = s[1].split('x')
            if len(s6) == 1:
                aperture = (float(s6[0]),)
            else:
                aperture = (float(s6[0]), float(s6[1]))
            return {'type': 'focus', 'aperture': aperture, 'focal_length': float(s[2]), 'shape': shape2d}
        if s[0] == 'disc':
            return {'type': 'disc', 'radius_pix': (float(s[1]) / 2 / pix_size_direct), 'shape': shape2d}
        if s[0] == 'gaussian' or s[0] == 'gauss':
            s6 = s[1].split('x')
            fwhm_h = float(s6[0])
            fwhm_v = float(s6[1]) if len(s6) > 1 else fwhm_h
            # FWHM = 2*sqrt(2*ln(2))*sigma ~= 2.3548*sigma, converted to pixels
            sigma_pix = (fwhm_h / (pix_size_direct * 2.3548), fwhm_v / (pix_size_direct * 2.3548))
            return {'type': 'gauss', 'sigma_pix': sigma_pix, 'shape': shape2d}
        # Focused rectangular opening, without initial 'focus' keyword (DEPRECATED)
        s6 = s[0].split('x')
        return {'type': 'focus', 'aperture': (float(s6[0]), float(s6[1])), 'focal_length': float(s[1]),
                'shape': shape2d}

    def _prepare_npz_pixel_size(self, npz, probe, pix_size_direct):
        """
        Pixel size of a probe/object loaded from a .npz file.

        :param npz: the opened .npz file (as returned by np.load)
        :param probe: the loaded probe array
        :param pix_size_direct: the current object pixel size
        :return: self.params['loadpixelsize'] if given, else the 'pixelsize' array stored in the file
                 (averaged if it has several values), else the pixel size for which the probe covers the same
                 field of view as the current frames.
        """
        if self.params['loadpixelsize'] is not None:
            return self.params['loadpixelsize']
        if 'pixelsize' in npz.files:
            # TODO: take into account separate x and y pixel size
            return float(np.mean(npz['pixelsize']))
        return pix_size_direct * self.dsize / probe.shape[-1]

    def _prepare_load_h5(self, filename, read_object):
        """
        Read the probe (and optionally the object) from the last 'entry_N' group of an hdf5/CXI result file.

        :param filename: the hdf5 file name
        :param read_object: if True, also read 'object/data'
        :return: (probe, obj, pixel_size) - obj is None if read_object is False. pixel_size is
                 self.params['loadpixelsize'] if given, else the average of the probe x and y pixel sizes
                 stored in the file.
        """
        with h5py.File(filename, 'r') as h:
            # Find last entry in file
            i = 1
            while 'entry_%d' % i in h:
                i += 1
            entry = h['entry_%d' % (i - 1)]
            # ds[()] reads the whole dataset (works with all h5py versions, unlike the removed ds.value)
            probe = entry['probe/data'][()]
            obj = entry['object/data'][()] if read_object else None
            if self.params['loadpixelsize'] is not None:
                pixel_size = self.params['loadpixelsize']
            else:
                pixel_size = (entry['probe/x_pixel_size'][()] + entry['probe/y_pixel_size'][()]) / 2
        return probe, obj, pixel_size

    def _prepare_rescale(self, scale):
        """
        Zoom the loaded probe by `scale` and centre-crop/pad it to (self.dsize, self.dsize). If an object was
        loaded ('load' parameter), it is zoomed and centre-cropped/padded to (self.ny, self.nx) as well.
        Mode (first) dimensions are kept; 2D arrays stay 2D.

        :param scale: zoom factor, i.e. old pixel size / new pixel size
        """
        # resize probe (assumed square in the last two dimensions)
        nz = self.probe0.shape[0] if self.probe0.ndim == 3 else 1
        nold = self.probe0.shape[-1]
        oldprobe = self.probe0.reshape((nz, nold, nold))
        self.probe0 = self._prepare_zoom_crop_pad(oldprobe, scale, self.dsize, self.dsize)
        if nz == 1:
            self.probe0 = self.probe0.reshape((self.dsize, self.dsize))

        if self.params['load'] is not None:
            # Resize object as well
            nzo = self.obj0.shape[0] if self.obj0.ndim == 3 else 1
            nyo, nxo = self.obj0.shape[-2:]
            oldobj = self.obj0.reshape((nzo, nyo, nxo))
            self.obj0 = self._prepare_zoom_crop_pad(oldobj, scale, self.ny, self.nx)
            if nzo == 1:
                self.obj0 = self.obj0.reshape((self.ny, self.nx))

    @staticmethod
    def _prepare_zoom_crop_pad(a, scale, ny, nx):
        """
        Zoom the last two dimensions of a 3D complex array, then centre-crop or zero-pad them to (ny, nx),
        independently along each axis.

        :param a: 3D array (nz, ny_old, nx_old)
        :param scale: zoom factor applied to the last two dimensions
        :param ny, nx: target size of the last two dimensions
        :return: a new complex64 array of shape (nz, ny, nx)
        """
        a = zoom(a.real, (1, scale, scale)) + 1j * zoom(a.imag, (1, scale, scale))
        # Trim zoomed dimensions to even sizes so that the centring is symmetric
        nya, nxa = a.shape[-2:]
        a = a[:, :nya - nya % 2, :nxa - nxa % 2]
        nya, nxa = a.shape[-2:]

        def _span(n_src, n_dst):
            # (start in source, start in destination, copied length) along one axis
            if n_src < n_dst:
                return 0, n_dst // 2 - n_src // 2, n_src
            return n_src // 2 - n_dst // 2, 0, n_dst

        sy, dy, ly = _span(nya, ny)
        sx, dx, lx = _span(nxa, nx)
        out = np.zeros((a.shape[0], ny, nx), dtype=np.complex64)
        out[:, dy:dy + ly, dx:dx + lx] = a[:, sy:sy + ly, sx:sx + lx]
        return out

    def run(self):
        """
        Main work function, will run according to the set of algorithms specified in self.params.

        For each of the self.params['nbrun'] runs:
          - create the initial object (and probe), unless they were loaded in prepare()
          - create the PtychoData (self.data) and Ptycho (self.p) objects
          - defocus the probe (only once per scan) and/or rotate it, if requested
          - execute self.params['algorithm'] via run_algorithm(), unless the algorithm is 'manual', in which
            case this returns as soon as the first run has been set up.

        The first run number is self.params['run0'] if given, otherwise the first number for which no result
        file (<saveprefix>.npz, <saveprefix>-00.npz or <saveprefix>.cxi) exists yet.

        :return: nothing
        """
        # Create directory to save files. Format the whole prefix and take its directory (as in save_data_cxi),
        # so that prefixes without a directory, or without a '%' field in the directory part, also work.
        path = os.path.dirname(self.params['saveprefix'] % (self.scan, 0))
        if path:
            os.makedirs(path, exist_ok=True)

        if self.params['run0'] is None:
            # Look for existing saved files, and start from the first unused run number
            run0 = 1
            while any(os.path.isfile(self.params['saveprefix'] % (self.scan, run0) + suffix)
                      for suffix in (".npz", "-00.npz", ".cxi")):
                run0 += 1
        else:
            run0 = int(self.params['run0'])

        # Enable profiling ?
        if self.params['profiling']:
            self.processing_unit.enable_profiling(True)

        for run in range(run0, run0 + self.params['nbrun']):
            self._run = run
            print("\n", "#" * 100, "\n#", "\n# Scan: %3d Run: %g\n#\n" % (self.scan, run), "#" * 100)

            # Init object and probe according to parameters
            if self.params['load'] is None:
                init = simulation.Simulation(obj_info=self.obj_init_info, probe_info=self.probe_init_info,
                                             data_info=self.data_info)
                init.make_obj()
                self.obj0 = init.obj.values
                print("Making obj:", self.obj0.shape, self.ny, self.nx)

                if self.params['loadprobe'] is None:
                    init.make_probe()
                    self.probe0 = init.probe.values

            # self.iobs frames are rebinned by self.rebinf, so the detector pixel size must be rebinned too
            # (as in prepare(), where the object pixel size and shape were computed)
            self.data = PtychoData(iobs=self.iobs, positions=(self.y - self.y.mean(), self.x - self.x.mean()),
                                   detector_distance=self.params['detectordistance'], mask=None,
                                   pixel_size_detector=self.params['pixelsize'] * self.rebinf,
                                   wavelength=self.wavelength)

            # NOTE(review): defocus is applied only once per scan (defocus_done), while rotation is applied at
            # every run - with a probe regenerated or kept between runs, results may differ. TODO confirm.
            if self.params['defocus'] is not None and self.params['defocus'] != 0 and self.defocus_done is False:
                self.defocus_done = True  # Don't defocus again for multiple runs of the same scan
                pixel_size_object = self.data.pixel_size_object()[0]

                def _propagate(probe2d):
                    # Near-field propagation of one probe mode by the 'defocus' distance
                    self.w = wavefront.Wavefront(d=np.fft.fftshift(probe2d.astype(np.complex64)),
                                                 wavelength=self.wavelength, pixel_size=pixel_size_object)
                    self.w = PropagateNearField_Wavefront(self.params['defocus']) * self.w
                    return self.w.get(shift=True)

                if self.probe0.ndim == 2:
                    self.probe0 = _propagate(self.probe0)
                else:
                    # Propagate all modes
                    for i in range(len(self.probe0)):
                        self.probe0[i] = _propagate(self.probe0[i])

            if self.params['rotate'] is not None:
                # Rotate probe (real and imaginary parts separately) in the plane of the last two axes
                angle = self.params['rotate']
                print("ROTATING probe by %6.2f degrees" % (angle))
                rot_re = rotate(self.probe0.real, angle, reshape=False, axes=(-2, -1))
                rot_im = rotate(self.probe0.imag, angle, reshape=False, axes=(-2, -1))
                self.probe0 = rot_re + 1j * rot_im

            self.p = Ptycho(probe=self.probe0, obj=self.obj0, data=self.data, background=None)
            self.p = ScaleObjProbe() * self.p

            self._algo_s = ""
            self._stepnum = 0
            if self.params['algorithm'].lower() == 'manual':
                return
            self.run_algorithm(self.params['algorithm'])

    def run_algorithm(self, algo_string):
        """
        Run a single or suite of algorithms in a given run.

        Two syntaxes are accepted:
          - operator-based (current), steps applied RIGHT-TO-LEFT, e.g. 'ML**100' or
            'analysis,ML**100,DM**100,nbprobe=3,DM**100' or 'analysis,ML**100*DM**100,nbprobe=3,DM**100'
          - DEPRECATED (detected when there is no '*' and a step like '100DM' or 'DM100' appears), steps
            applied LEFT-TO-RIGHT, e.g. 'probe=1,nbprobe=3,100DM,40ML,analysis'

        After the algorithms, the result is saved (unless saving after each step with save='all', or the
        algorithm is 'analyze' or 'manual'), profiling information is printed if enabled (OpenCL only), and
        the probe FWHM is printed.

        :param algo_string: a single or suite of algorithm steps to use (see above)
        :return: Nothing
        """
        use_old_algo_string = '*' not in algo_string and any(
            s in algo_string.lower() for s in ['0ap', 'ap0', '0dm', 'dm0', '0ml', 'ml0'])

        if use_old_algo_string:
            print("\n", "#" * 100, "\n#",
                  "\n# WARNING: You are using the old algorithm strings, which are DEPRECATED\n#"
                  "\n#      5s sleep - remember to read the updated command-line help !\n"
                  "# If you were writing: algorithm=probe=1,nbprobe=3,100DM,40ML,analysis\n"
                  "# You should now use:  algorithm=analysis,ML**40,DM**100,nbprobe=3,probe=1 (order right-to-left !)\n"
                  "# Or alternatively:    algorithm=analysis,ML**40*DM**100,nbprobe=3,probe=1\n"
                  + "#" * 100)
            time.sleep(5)
            for algo in algo_string.split(','):
                self._algo_s = algo if self._algo_s == "" else self._algo_s + ',' + algo
                print("\n", "#" * 100, "\n#", "\n#         Run: %g , Algorithm: %s\n#\n" % (self._run, algo), "#" * 100)
                realoptim = self._run_algorithm_old_step(algo)
                self._run_algorithm_after_step(realoptim)
        else:
            # Using new operator-based algorithm (right-to-left)
            t0 = time.perf_counter()
            for algo in reversed(algo_string.split(',')):
                self._algo_s = algo if self._algo_s == "" else algo + ',' + self._algo_s
                print("\n", "#" * 100, "\n#", "\n#         Run: %g , Algorithm: %s\n#\n" % (self._run, algo), "#" * 100)
                # Anything which is not a parameter change is an operator string to apply
                realoptim = not self._run_algorithm_parameter_step(algo)
                if realoptim:
                    self._run_algorithm_operator_step(algo)
                self._run_algorithm_after_step(realoptim)
            print("\nTotal elapsed time for algorithms: %8.2fs " % (time.perf_counter() - t0))

        if self.params['save'] != 'all' and self.params['algorithm'].lower() not in ['analyze', 'manual']:
            self.save(self._run)

        if self.params['profiling'] and 'cl_event_profiling' in dir(self.processing_unit):
            # Profiling can only work with OpenCL
            self._run_algorithm_print_profiling()
        self.print_probe_fwhm()

    def _run_algorithm_old_step(self, algo):
        """
        Apply one step of a DEPRECATED algorithm string, e.g. '100DM', 'ML40', 'nbprobe=3', 'analysis'.

        :param algo: the algorithm step
        :return: True if the step was an optimisation (AP, DM or ML), False otherwise
        """
        show_obj_probe = self.params['liveplot']
        if show_obj_probe:
            show_obj_probe = self.params['verbose']
        a = algo.lower()
        if 'ap' in a:
            nbcycle = self._run_algorithm_old_nb_cycle(a, 'ap')
            print("Updating background:", self.update_background)
            self.p = AP(update_object=self.update_object, update_probe=self.update_probe,
                        update_background=self.update_background,
                        show_obj_probe=show_obj_probe,
                        calc_llk=self.params['verbose'], fig_num=100) ** nbcycle * self.p
            return True
        if 'dm' in a:
            nbcycle = self._run_algorithm_old_nb_cycle(a, 'dm')
            # TODO: enable background update once it's efficient with DM
            self.p = DM(update_object=self.update_object, update_probe=self.update_probe,
                        show_obj_probe=show_obj_probe,
                        calc_llk=self.params['verbose'], fig_num=100) ** nbcycle * self.p
            return True
        if 'ml' in a:
            nbcycle = self._run_algorithm_old_nb_cycle(a, 'ml')
            # TODO: enable background update once it's efficient with ML
            self.p = ML(update_object=self.update_object, update_probe=self.update_probe,
                        show_obj_probe=show_obj_probe, reg_fac_obj=self.params['regularization'],
                        calc_llk=self.params['verbose'], fig_num=100) ** nbcycle * self.p
            return True
        if not self._run_algorithm_parameter_step(algo):
            print("ERROR: did not understand algorithm step:", algo)
        return False

    @staticmethod
    def _run_algorithm_old_nb_cycle(algo_lower, key):
        """
        Number of cycles in a DEPRECATED step like '100dm' or 'dm100'.

        :param algo_lower: the lower-case algorithm step, which must contain key
        :param key: 'ap', 'dm' or 'ml'
        :return: the number before (or, if none, after) the key; 1 if no number is given
        """
        before, after = algo_lower.split(key)[:2]
        if before.strip():
            return int(before)
        if after.strip():
            return int(after)
        return 1

    def _run_algorithm_parameter_step(self, algo):
        """
        Apply a step which is not an optimisation algorithm: 'ortho', 'nbprobe=N', 'object=N', 'probe=N',
        'background=N', 'regularization=x', 'analyze' or 'analysis', 'verbose=N', 'live_plot=N', 'fig_num=N'.

        :param algo: the algorithm step (case-insensitive)
        :return: True if the step was recognised and applied, False otherwise
        """
        a = algo.lower()
        if 'ortho' in a:
            self.p = OrthoProbe(verbose=True) * self.p
        elif 'nbprobe=' in a:
            self._run_algorithm_nb_probe(int(a.split('nbprobe=')[-1]))
        elif 'object=' in a:
            self.update_object = int(a.split('object=')[-1])
        elif 'probe=' in a:
            self.update_probe = int(a.split('probe=')[-1])
        elif 'background=' in a:
            self.update_background = int(a.split('background=')[-1])
        elif 'regularization=' in a:
            self.params['regularization'] = float(a.split('regularization=')[-1])
        elif 'analyze' in a or 'analysis' in a:
            self._run_algorithm_analysis()
        elif 'verbose=' in a:
            self.params['verbose'] = int(a.split('verbose=')[-1])
        elif 'live_plot=' in a:
            self.params['liveplot'] = int(a.split('live_plot=')[-1])
        elif 'fig_num=' in a:
            self.params['fig_num'] = int(a.split('fig_num=')[-1])
        else:
            return False
        return True

    def _run_algorithm_nb_probe(self, nb_probe):
        """
        Change the number of probe modes: existing modes are kept (up to nb_probe), and new modes are
        initialised with a random phase and an amplitude of 0-4% of the first mode's amplitude.

        :param nb_probe: the new number of probe modes
        """
        pr = self.p.get_probe()
        nz, ny, nx = pr.shape
        if nb_probe == nz:
            return

        pr1 = np.empty((nb_probe, ny, nx), dtype=np.complex64)
        nkeep = min(nz, nb_probe)
        pr1[:nkeep] = pr[:nkeep]
        for i in range(nz, nb_probe):
            n = abs(pr[0]) * np.random.uniform(0, 0.04, (ny, nx))
            pr1[i] = n * np.exp(1j * np.random.uniform(0, 2 * np.pi, (ny, nx)))

        self.p.set_probe(pr1)

    def _run_algorithm_analysis(self):
        """
        Analyse the probe (modes and focus) using AnalyseProbe. If self.params['saveplot'] is set, plots are
        saved with the current run/step prefix and, on POSIX systems, 'latest-probe-*.png' symbolic links
        are updated to point to them.
        """
        probe = self.p.get_probe()
        if self.params['saveplot']:
            steps = "-%02d" % (self._stepnum - 1)
            save_prefix = self.params['saveprefix'] % (self.scan, self._run) + steps
        else:
            save_prefix = None
        self.p = AnalyseProbe(modes=True, focus=True, verbose=True,
                              save_prefix=save_prefix, show_plot=False) * self.p
        if self.params['saveplot'] and os.name == 'posix':
            if probe.shape[0] > 1:
                self._run_algorithm_link_latest(save_prefix + '-probe-modes.png', 'latest-probe-modes.png')
            self._run_algorithm_link_latest(save_prefix + '-probe-z.png', 'latest-probe-z.png')

    @staticmethod
    def _run_algorithm_link_latest(filename, link_name):
        """
        Create or replace a symbolic link named link_name, in the directory of filename, pointing (relatively)
        to filename. This is best-effort: OS errors are reported but not raised.

        :param filename: the target file path
        :param link_name: the name of the link, without directory
        """
        d, f = os.path.split(filename)
        link = os.path.join(d, link_name)
        try:
            if os.path.lexists(link):
                os.remove(link)
            os.symlink(f, link)
        except OSError as ex:
            print("Could not create symbolic link %s -> %s: %s" % (link, f, str(ex)))

    def _run_algorithm_operator_step(self, algo):
        """
        Evaluate an operator-based algorithm string (e.g. 'ML**40*DM**100') and apply it to self.p.

        The expression (lower-cased) can use the pre-configured operators 'ap', 'dm', 'ml' and 'showobjprobe',
        as well as any operator available in this module. Errors are printed but not raised.

        :param algo: the operator string
        """
        show_obj_probe = self.params['liveplot']
        if show_obj_probe:
            show_obj_probe = self.params['verbose']
        fig_num = self.params['fig_num']

        # First create basic operators - these local names are used by the eval() below
        ap = AP(update_object=self.update_object, update_probe=self.update_probe,
                update_background=self.update_background,
                show_obj_probe=show_obj_probe,
                calc_llk=self.params['verbose'], fig_num=fig_num)
        dm = DM(update_object=self.update_object, update_probe=self.update_probe,
                show_obj_probe=show_obj_probe,
                calc_llk=self.params['verbose'], fig_num=fig_num)
        ml = ML(update_object=self.update_object, update_probe=self.update_probe,
                show_obj_probe=show_obj_probe, reg_fac_obj=self.params['regularization'],
                calc_llk=self.params['verbose'], fig_num=fig_num)
        showobjprobe = ShowObjProbe(fig_num=fig_num)

        try:
            # SECURITY NOTE: eval() executes arbitrary code from the algorithm parameter - only use with
            # trusted input (e.g. the user's own command line)
            ops = eval(algo.lower())
            self.p = ops * self.p
        except Exception as ex:
            print('\n\n Caught exception for scan %d: %s    \n' % (self.scan, str(ex)))
            print(traceback.format_exc())
            print('Could not interpret operator-based algorithm (see above error): ', algo)
            # TODO: print valid examples of algorithms

    def _run_algorithm_after_step(self, realoptim):
        """
        Post-processing after an algorithm step: if it was an optimisation step, optionally re-centre the
        probe (moving the object accordingly) and save the result if self.params['save'] == 'all'.

        :param realoptim: True if the step was a real optimisation (not just a parameter change)
        """
        if not realoptim:
            return
        if self.params['autocenter']:
            pr, obj = center_array_2d(self.p.get_probe(), other_arrays=self.p.get_obj(), iz=0)
            self.p.set_obj(obj)
            self.p.set_probe(pr)
        if self.params['save'] == 'all':
            self.save(self._run, self._stepnum, self._algo_s)
            self._stepnum += 1

    def _run_algorithm_print_profiling(self):
        """
        Print OpenCL kernel profiling statistics (mean and total duration, Gflop/s, Gb/s) recorded in
        self.processing_unit.cl_event_profiling, sorted by decreasing total time.
        """
        print("\n", "#" * 100, "\n#", "\n#         Profiling info\n#\n", "#" * 100)
        dt = 0
        vv = []
        for s in self.processing_unit.cl_event_profiling:
            events = self.processing_unit.cl_event_profiling[s]
            v = np.array([(e.event.profile.end - e.event.profile.start) for e in events])
            gf = np.array([e.gflops() for e in events])
            gb = np.array([e.gbs() for e in events])
            vv.append((s, v.mean() * 1e-3, len(v), v.sum() * 1e-6, gf.mean(), gb.mean()))
            dt += v.sum() * 1e-6
        vv.sort(key=lambda x: x[3], reverse=True)
        for v in vv:
            print("dt(%25s)=%9.3f µs , nb=%6d, dt_sum=%10.3f ms [%4.1f%%], %8.3f Gflop/s, %8.3f Gb/s"
                  % (v[0], v[1], v[2], v[3], v[3] / dt * 100, v[4], v[5]))
        print("                                                    Total: dt=%11.3f ms" % (dt))

    def print_probe_fwhm(self):
        """
        Print a separator banner, then the estimated FWHM statistics of the current probe at the sample
        position (computed by analysis.probe_fwhm using the object-plane pixel size of self.data).

        Returns:
            Nothing
        """
        banner = "#" * 100
        print("\n", banner, "\n")
        print("Probe statistics at sample position:")
        current_probe = self.p.get_probe()
        object_pixel_size = self.data.pixel_size_object()[0]
        analysis.probe_fwhm(current_probe, object_pixel_size)

    def save_data_cxi(self, nexus_compatible=False, crop=True):
        """
        Save the scan data using the CXI format (see http://cxidb.org)

        The data is written to a 'data.cxi' file in the directory of
        self.params['saveprefix'] % (self.scan, 0), which is created if needed. An existing
        file is never overwritten.

        Args:
            nexus_compatible: if True, save using a NeXus-compatible CXI file. This is not implemented:
                a message is printed and nothing is written.
            crop: if True, only the already-cropped data will be saved (self.iobs, or self.raw_data if
                self.iobs is None), with the pixel size multiplied by self.rebinf. Otherwise, the original
                raw data is saved with the original pixel size.
        Returns:
            Nothing
        """
        # TODO: handle the case where maxframe and/or moduloframe is used: either save with a specific name, or just warn user, or...
        path = os.path.dirname(self.params['saveprefix'] % (self.scan, 0))
        if path:
            # dirname is '' when saveprefix has no directory component, and os.makedirs('') would raise
            os.makedirs(path, exist_ok=True)
        filename = os.path.join(path, "data.cxi")
        if os.path.isfile(filename):
            print("Data CXI file already exists, no overwriting: ", filename)
            return
        if nexus_compatible:
            print("NeXus-compatible CXI export is not implemented, nothing saved to: ", filename)
            return
        print("Saving raw data to CXI file: ", filename)

        if crop:
            data = self.raw_data if self.iobs is None else self.iobs
            pixel_size = self.params['pixelsize'] * self.rebinf
        else:
            data = self.raw_data
            pixel_size = self.params['pixelsize']
        # Frames are stacked along the first axis; each frame is stored as its own HDF5 chunk
        _, ny, nx = data.shape

        # Remember how import was done (each argument followed by a space)
        command = "".join(arg + " " for arg in sys.argv)

        # TODO: make sure all data fields are recorded and stored in self.params by all running scripts
        # The context manager guarantees the file is closed even if writing fails
        with h5py.File(filename, "w") as f:
            f.create_dataset("cxi_version", data=140)
            entry_1 = f.create_group("entry_1")
            entry_1.create_dataset("program_name", data="PyNX %s" % (__version__))
            entry_1.create_dataset("start_time", data=time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime(time.time())))

            sample_1 = entry_1.create_group("sample_1")
            geometry_1 = sample_1.create_group("geometry_1")
            # Scan positions as a (3, nb_positions) array: rows x, y, and z left at 0
            xyz = np.zeros((3, self.x.size), dtype=np.float32)
            xyz[0] = self.x
            xyz[1] = self.y
            geometry_1.create_dataset("translation", data=xyz)

            data_1 = entry_1.create_group("data_1")
            # data_1.create_dataset("scanid", data=self.scan) # TODO: save scan number ?
            data_1["translation"] = h5py.SoftLink('/entry_1/sample_1/geometry_1/translation')
            if self.raw_data_monitor is not None:
                # The array must be passed as 'data=': the 2nd positional argument of create_dataset is the shape
                data_1.create_dataset("monitor", data=self.raw_data_monitor)

            instrument_1 = entry_1.create_group("instrument_1")
            instrument_1.create_dataset("name", data=self.params['instrument'])

            source_1 = instrument_1.create_group("source_1")
            # 1.60218e-16 J per keV, so self.params['nrj'] is presumably in keV
            source_1.create_dataset("energy", data=self.params['nrj'] * 1.60218e-16)  # in J

            detector_1 = instrument_1.create_group("detector_1")
            detector_1.create_dataset("data", data=data, chunks=(1, ny, nx), shuffle=True, compression="gzip")
            detector_1.create_dataset("distance", data=self.params['detectordistance'])
            detector_1.create_dataset("x_pixel_size", data=pixel_size)
            detector_1.create_dataset("y_pixel_size", data=pixel_size)
            if self.raw_mask is not None and self.raw_mask.sum() != 0:
                # NOTE(review): this is the raw (uncropped) mask, also when crop=True - confirm it is intended
                detector_1.create_dataset("mask", data=self.raw_mask, chunks=True, shuffle=True, compression="gzip")
                print("Also saved mask to hdf5 file")
            # Basis vector - this is the default CXI convention, so could be skipped
            # This corresponds to a 'top, left' origin convention
            basis_vectors = np.zeros((2, 3), dtype=np.float32)
            basis_vectors[0, 1] = -pixel_size
            basis_vectors[1, 0] = -pixel_size
            detector_1.create_dataset("basis_vectors", data=basis_vectors)

            detector_1["translation"] = h5py.SoftLink('/entry_1/sample_1/geometry_1/translation')
            data_1["data"] = h5py.SoftLink('/entry_1/instrument_1/detector_1/data')

            process_1 = data_1.create_group("process_1")
            process_1.create_dataset("command", data=command)

    def save(self, run, stepnum=None, algostring=None):
        """
        Save the result of the optimization, and (if self.params['saveplot'] is set) the corresponding plot.
        This is an internal function.

        If 'npz' appears in self.params['output_format'] (case-insensitive), the object (first mode, after
        phase.shift_phase_zero), probe, object pixel size, scan areas and (if its sum is > 0) background are
        saved to self.params['saveprefix'] % (self.scan, run) + steps + '.npz'.
        Otherwise object and probe are appended to self.params['saveprefix'] % (self.scan, run) + '.cxi'
        using self.p.save_obj_probe_cxi(), together with the command line, algorithm string, parameters
        and log-likelihood values.
        On POSIX systems a relative 'latest.npz' or 'latest.cxi' symbolic link is also updated (best-effort).

        :param run:  the run number (integer)
        :param stepnum: the step number in the set of algorithm steps
        :param algostring: the string corresponding to all the algorithms ran so far, e.g. '100DM,100AP,100ML'
        :return: Nothing
        """
        steps = "" if stepnum is None else "-%02d" % stepnum

        def link_latest(target, link_name):
            # Best-effort equivalent of 'ln -sf basename dir/link_name', without going through a shell
            # (file names with spaces or shell metacharacters are handled correctly).
            if os.name != 'posix':
                return
            dirname, basename = os.path.split(target)
            link = os.path.join(dirname, link_name)
            try:
                if os.path.lexists(link):
                    os.remove(link)
                os.symlink(basename, link)
            except OSError:
                pass

        if 'npz' in self.params['output_format'].lower():
            filename = self.params['saveprefix'] % (self.scan, run) + steps + ".npz"
            print("\n", "#" * 100, "\n#",
                  "\n#         Saving object and probe to: %s\n#\n" % filename, "#" * 100)

            # Shift back the phase range to [0...], if object phase is not wrapped.
            # TODO: handle objects with multiple modes
            obj = phase.shift_phase_zero(self.p.get_obj()[0], percent=2, origin=0, mask=self.p.scan_area_obj)
            kwargs = {'obj': obj, 'probe': self.p.get_probe(), 'pixelsize': self.data.pixel_size_object(),
                      'scan_area_obj': self.p.scan_area_obj, 'scan_area_probe': self.p.scan_area_probe}

            background = self.p.get_background()
            if background is not None and background.sum() > 0:
                kwargs['background'] = background

            np.savez_compressed(filename, **kwargs)
            link_latest(filename, 'latest.npz')
        else:
            # Save as CXI file. No step suffix: successive results are appended to the same file (append=True)
            filename = self.params['saveprefix'] % (self.scan, run) + ".cxi"
            print("\n", "#" * 100, "\n#",
                  "\n#         Saving object and probe to: %s\n#\n" % filename, "#" * 100)

            process = {"command": "".join(arg + " " for arg in sys.argv)}

            if algostring is not None:
                process["algorithm"] = algostring

            process["parameters_all"] = "".join("%s = %s\n" % (k, str(v)) for k, v in self.params.items()
                                                if v is not None and k not in ['output_format'])
            process["program"] = "PyNX"
            process["version"] = __version__
            # Log-likelihoods normalised by the number of observed points
            process["note"] = {'llk(Poisson)': self.p.llk_poisson / self.p.nb_obs,
                               'llk(Gaussian)': self.p.llk_gaussian / self.p.nb_obs,
                               'llk(Euclidian)': self.p.llk_euclidian / self.p.nb_obs,
                               'nb_photons_calc': self.p.nb_photons_calc}

            self.p.save_obj_probe_cxi(filename, sample_name=None, experiment_id=None,
                                      instrument=self.params['instrument'], note=None,
                                      process=process, append=True)
            link_latest(filename, 'latest.cxi')
        if self.params['saveplot']:
            self.save_plot(run, stepnum, algostring)

    def save_plot(self, run, stepnum=None, algostring=None, display_plot=False):
        """
        Save the plot to a png file.

        The figure shows the object (amplitude & phase, or only its phase if
        self.params['saveplot'] == 'object_phase'), the probe, a colour wheel, the sorted list of
        non-None parameters (except a few excluded keys) and the probe width estimates from
        analysis.probe_fwhm(). It is saved to self.params['saveprefix'] % (self.scan, run) + steps + '.png'.
        On POSIX systems a relative 'latest.png' symbolic link is also updated (best-effort) in the same
        directory.

        :param run:  the run number (integer)
        :param stepnum: the step number in the set of algorithm steps
        :param algostring: the string corresponding to all the algorithms ran so far, e.g. '100DM,100AP,100ML'.
                           If None, self.params['algorithm'] is used.
        :param display_plot: if True, the saved plot will also be displayed
        :return: Nothing
        """
        steps = "" if stepnum is None else "-%02d" % stepnum
        if algostring is None:
            algostring = self.params['algorithm']

        # TODO: minimise grad phase also for secondary modes ?
        obj = phase.minimize_grad_phase(self.p.get_obj()[0], center_phase=0, global_min=False,
                                        mask=~self.p.scan_area_obj)[0]
        probe = phase.minimize_grad_phase(self.p.get_probe()[0], center_phase=0, global_min=False,
                                          mask=~self.p.scan_area_probe)[0]

        fig = None
        if display_plot:
            try:
                fig = plt.figure(101, figsize=(10, 6))
            except Exception:
                # no GUI or $DISPLAY: fall back to an off-screen figure
                fig = None
        if fig is None:
            fig = Figure(figsize=(10, 6))
        fig.clf()

        # Object pixel size: [0] is used along y, [1] along x (scaled by 1e6 for µm, 1e9 for nm below)
        pixel_size = self.data.pixel_size_object()

        # Object (left panel)
        ax = fig.add_axes((0.1, 0.1, 0.4, 0.8))
        nyo, nxo = obj.shape[-2:]
        tmpx = nxo / 2 * pixel_size[1] * 1e6
        tmpy = nyo / 2 * pixel_size[0] * 1e6
        if self.params['saveplot'] == 'object_phase':
            # Show only object phase
            obja = np.angle(obj)
            smin, smax = np.percentile(np.ma.masked_array(obja, ~self.p.scan_area_obj).compressed(), (2, 98))
            if smax - smin < np.pi:
                # Narrow phase range: grey scale. The colormap is given by name because
                # plt.cm.get_cmap() was removed in matplotlib 3.9
                cm_phase = 'gray'
            else:
                # Wide phase range: show the full [0, 2pi] range with plot_utils.cm_phase
                smin, smax = 0, 2 * np.pi
                cm_phase = plot_utils.cm_phase
            ax.imshow(obja, vmax=smax, vmin=smin, extent=(tmpx, -tmpx, -tmpy, tmpy), cmap=cm_phase)
            ax.set_title("Object phase [%5.2f-%5.2f radians]" % (smin, smax))
        else:
            # Show object as RGBA/HSV, brightness scaled from 0 to the max amplitude inside the scan area
            smin, smax = 0, np.ma.masked_array(abs(obj), ~self.p.scan_area_obj).max()
            ax.imshow(plot_utils.complex2rgbalin(obj, smax=smax, smin=smin), extent=(tmpx, -tmpx, -tmpy, tmpy))
            ax.set_title("Object amplitude & phase")
            ax.text(0.002, 0.97, "brightness scaling: 0-max=[%5.2f-%5.2f]" % (smin, smax),
                    horizontalalignment='left', verticalalignment='bottom',
                    fontsize=6, transform=ax.transAxes)

        ax.set_xlim(tmpx, -tmpx)
        ax.set_ylim(-tmpy, tmpy)
        ax.set_xlabel(u"x(µm)")
        ax.set_ylabel(u"y(µm)")
        # Thin black line through self.p.scan_area_points (presumably the scan area outline)
        ax.plot(self.p.scan_area_points[0] * pixel_size[1] * 1e6,
                self.p.scan_area_points[1] * pixel_size[0] * 1e6, 'k-', linewidth=0.3)

        # Probe (right panel), axes size scaled by the probe/object array size ratio
        rx, ry = probe.shape[1] / obj.shape[1], probe.shape[0] / obj.shape[0]
        ax = fig.add_axes((0.6 + 0.4 * (1 - rx) / 2, 0.1, 0.4 * rx, 0.8 * ry))
        ny, nx = probe.shape[-2:]
        tmpx = nx / 2 * pixel_size[1] * 1e6
        tmpy = ny / 2 * pixel_size[0] * 1e6
        smax = abs(probe * self.p.scan_area_probe).max()

        ax.imshow(plot_utils.complex2rgbalin(probe, smax=smax), extent=(-tmpx, tmpx, -tmpy, tmpy))
        ax.set_xlim(tmpx, -tmpx)
        ax.set_ylim(-tmpy, tmpy)
        ax.set_xlabel(u"x(µm)")
        ax.set_ylabel(u"y(µm)")
        ax.set_title("Probe amplitude & phase")

        # Colour wheel legend for the complex colour maps
        ax = fig.add_axes((0.47, 0.01, 0.06, 0.06), facecolor='w')
        plot_utils.colorwheel(ax=ax)

        fig.suptitle("Scan #%d, %d frames, pixelsize=%5.1fnm, LLK=%8.3f\n algo=%s" %
                     (self.scan, len(self.x), pixel_size[0] * 1e9,
                      self.p.llk_poisson / self.p.nb_obs, algostring), fontsize=9)

        # Height of one line of 6-point text (+1pt spacing), in figure-fraction units (72 points per inch)
        dy = (6 + 1) / 72 / fig.get_size_inches()[1]
        y0 = 0.95 - 1.5 * dy
        hidden_keys = {'liveplot', 'livescan', 'saveplot', 'scan', 'algorithm', 'save', 'saveprefix', 'nbrun',
                       'run0', 'pixelsize', 'imgcounter', 'epoch', 'data2cxi', 'verbose'}
        n = 1
        for k in sorted(self.params.keys()):
            v = self.params[k]
            if v is None or k in hidden_keys:
                continue
            fig.text(0.505, y0 - n * dy, "%s = %s" % (k, str(v)), fontsize=6, horizontalalignment='left',
                     stretch='condensed')
            n += 1
        fig.text(dy, dy, "PyNX v%s, finished at %s" % (__version__, time.strftime("%Y/%m/%d %H:%M:%S")),
                 fontsize=6, horizontalalignment='left', stretch='condensed')

        # Add probe full width information
        fwhmyx, fw20yx, fws = analysis.probe_fwhm(self.p.get_probe(), pixel_size[0], verbose=False)
        fig.text(0.6, dy, "FWHM : %7.2f(H)x%7.2f(V) nm**2 [peak]" % (fwhmyx[1] * 1e9, fwhmyx[0] * 1e9), fontsize=6,
                 horizontalalignment='left', stretch='condensed')
        fig.text(0.6, 2 * dy, "FW20%%: %7.2f(H)x%7.2f(V)nm**2 [extended]" % (fw20yx[1] * 1e9, fw20yx[0] * 1e9),
                 fontsize=6,
                 horizontalalignment='left', stretch='condensed')
        fig.text(0.6, 3 * dy, "FW (stat):  %7.2f nm" % (fws * 1e9), fontsize=6,
                 horizontalalignment='left', stretch='condensed')

        # Add beam direction
        ax = fig.add_axes((0.55, 0.2, 0.05, 0.05), facecolor='w')
        ax.set_axis_off()
        ax.text(0, 0, 'x\n Beam\n(// z)', horizontalalignment='center', verticalalignment='center')  # fontweight='bold'
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        if display_plot:
            try:
                # Not supported on all backends (e.g. nbagg)
                plt.draw()
                plt.pause(.001)
            except Exception:
                pass

        canvas = FigureCanvasAgg(fig)
        filename = self.params['saveprefix'] % (self.scan, run) + steps + '.png'
        canvas.print_figure(filename, dpi=150)
        if os.name == 'posix':
            # Best-effort equivalent of 'ln -sf basename dir/latest.png', without going through a shell
            dirname, basename = os.path.split(filename)
            link = os.path.join(dirname, 'latest.png')
            try:
                if os.path.lexists(link):
                    os.remove(link)
                os.symlink(basename, link)
            except OSError:
                pass

    def plot_llk_history(self):
        """
        Plot the optimization history stored in self.p.history, using plot_utils.plot_llk_history(),
        in matplotlib figure number 110 (10x6 inches).

        Returns:
            Nothing
        """
        # TODO: minimal implementation, the figure is neither saved nor explicitly shown here
        plt.figure(num=110, figsize=(10, 6))
        plot_utils.plot_llk_history(self.p.history)


class PtychoRunner:
    """
    Class to process a series of scans with a series of algorithms, given from the command-line.

    Beamline/instrument-specific runners can derive from this class and override
    parse_arg_beamline() and check_params_beamline().
    """

    def __init__(self, argv, params, ptycho_runner_scan_class):
        """
        Store and parse the command-line arguments, then check the resulting parameters.

        :param argv: the command-line parameters
        :param params: parameters for the optimization, with some default values. A deep copy is used,
                       so the caller's dictionary is not modified.
        :param ptycho_runner_scan_class: the class to use to run the analysis. It is instantiated as
                                         ptycho_runner_scan_class(params, scan) for each scan in process_scans().
        """
        self.params = copy.deepcopy(params)
        self.argv = argv
        self.PtychoRunnerScan = ptycho_runner_scan_class
        self.parse_arg()
        self.check_params()
        self.help_text = helptext_generic

    def parse_arg(self):
        """
        Parses the arguments given on a command line, and store the interpreted values in self.params.

        Arguments are either single keywords (e.g. 'liveplot') or 'key=value' pairs, with
        case-insensitive keys. Arguments not interpreted here are passed to parse_arg_beamline(),
        and a warning is printed if they are not interpreted there either. Single keywords
        containing '.py' are silently ignored.

        Returns: nothing
        Raises: PtychoRunnerException if moduloframe=n1,n2 does not satisfy 0<=n2<n1
        """
        for arg in self.argv:
            if arg.lower() in ['liveplot', 'livescan', 'data2cxi', 'orientation_round_robin', 'profiling']:
                # Bare boolean flags
                self.params[arg.lower()] = True
                continue
            s = arg.find('=')
            if 0 < s < len(arg) - 1:
                k = arg[:s].lower()
                v = arg[s + 1:]
                print(k, v)
                # NB: 'scan' is kept as a string, to be able to interpret e.g. scan="range(12,22)"
                if k == 'mask':
                    k = 'loadmask'
                if k in ['algorithm', 'saveprefix', 'load', 'object', 'probe', 'loadprobe', 'save', 'monitor',
                         'scan', 'loadmask', 'detector_orientation', 'xy', 'flatfield', 'roi', 'data2cxi',
                         'saveplot', 'dark', 'output_format']:
                    self.params[k] = v
                elif k in ['gpu']:
                    # If several GPU are listed
                    self.params[k] = v.split(',')
                elif k in ['nbrun', 'run0', 'maxframe', 'maxsize', 'verbose', 'rebin']:
                    self.params[k] = int(v)
                elif k in ['defocus', 'loadpixelsize', 'rotate']:
                    self.params[k] = float(v)
                elif k in ['autocenter']:
                    # Only an explicit false value is taken into account, any other value keeps the default
                    if v.lower() in ['false', '0']:
                        self.params[k] = False
                elif k == 'moduloframe':
                    # moduloframe=n1 or moduloframe=n1,n2 -> stored as the tuple (n1, n2), with n2=0 by default
                    vs = v.split(',')
                    n1 = int(vs[0])
                    if len(vs) == 1:
                        self.params[k] = (n1, 0)
                    else:
                        n2 = int(vs[1])
                        if n2 >= n1 or n2 < 0:
                            raise PtychoRunnerException('Parameter moduloframe: n1=%d, n2=%d must satisfy 0<=n2<n1'
                                                        % (n1, n2))
                        self.params[k] = (n1, n2)
                elif not self.parse_arg_beamline(k, v):
                    print("WARNING: argument not interpreted: %s=%s" % (k, v))
            elif 'saveplot' in arg:
                self.params['saveplot'] = True
            elif '.cxi' in arg:
                self.params['cxifile'] = arg
            elif '.py' not in arg and not self.parse_arg_beamline(arg.lower(), None):
                # Warn only when the beamline-specific parser did NOT interpret the argument
                print("WARNING: argument not interpreted: %s" % (arg))

    def parse_arg_beamline(self, k, v):
        """
        Parse argument in a beamline-specific way. This function only parses single arguments.
        If an argument is recognized and interpreted, the corresponding value is added to self.params

        This method should be superseded in a beamline/instrument-specific child class.

        :param k: the lower-case argument name (or the whole lower-case argument if it has no '=')
        :param v: the argument value as a string, or None for an argument without '='
        Returns:
            True if the argument is interpreted, false otherwise
        """
        return False

    def check_params(self):
        """
        Check if self.params includes a minimal set of valid parameters

        check_params_beamline() is called first. A default saveprefix ('ResultsScan%04d/Run%04d')
        is set if none was given.

        Returns: Nothing. Will raise a PtychoRunnerException if necessary
        """
        self.check_params_beamline()
        if self.params['probe'] is None and self.params['load'] is None and self.params['loadprobe'] is None and \
                self.params['data2cxi'] is False:
            raise PtychoRunnerException('Missing argument: either probe=, load= or loadprobe= is required')
        if self.params['scan'] is None and self.params['cxifile'] is None and self.params['data'] is None and \
                'h5meta' not in self.params and self.params['livescan'] is None:
            raise PtychoRunnerException('Missing argument: no scan # given')
        # if self.params['gpu'] is None :
        #    raise PtychoRunnerException('Missing argument: no gpu name given (e.g. gpu=Titan)')
        if self.params['saveprefix'] is None:
            self.params['saveprefix'] = 'ResultsScan%04d/Run%04d'
            print("No saveprefix given, using default: ", self.params['saveprefix'])

    def check_params_beamline(self):
        """
        Check if self.params includes a minimal set of valid parameters, specific to a beamline.
        Derived implementations can also set default values when appropriate.

        Returns: Nothing. Will raise an exception if necessary
        """
        pass

    def process_scans(self):
        """
        Run all the analysis on the supplied scan list

        For each scan, a self.PtychoRunnerScan(self.params, scan) object is created (stored as self.ws),
        the processing unit is prepared and the data loaded. Then either the data is exported to CXI
        (if data2cxi is set), or the data is centered/cropped and the algorithms are run - for every
        combination of xy and detector_orientation if orientation_round_robin is set.
        A PtychoRunnerException for one scan is reported (with the help text) and processing continues
        with the next scan.

        :return: Nothing
        """
        if (self.params['cxifile'] is not None or self.params['data'] is not None or 'h5meta' in self.params) and \
                self.params['scan'] is None:
            # Only when reading a CXI, ptypy or a single hdf5 metadata file (from id16) we accept a dummy scan value
            vscan = [0]
        else:
            # NOTE(security): the command-line scan string is evaluated with eval(), so that e.g.
            # scan="range(12,25)" or scan=12,23,45 work. Never pass untrusted input here.
            vscan = eval(self.params['scan'])
            if isinstance(vscan, int):
                vscan = [vscan]

        cxifile0 = self.params['cxifile']
        try:
            for scan in vscan:
                try:
                    if cxifile0 is not None and '%' in cxifile0:
                        # cxifile is a pattern, formatted with each scan number
                        self.params['cxifile'] = cxifile0 % scan
                        print('Loading CXIFile:', self.params['cxifile'])
                    self.ws = self.PtychoRunnerScan(self.params, scan)
                    self.ws.prepare_processing_unit()
                    self.ws.load_data()
                    if self.params['data2cxi']:
                        if self.params['data2cxi'] == 'crop':
                            self.ws.center_crop_data()
                            self.ws.save_data_cxi(crop=True)
                        else:
                            self.ws.save_data_cxi()
                    elif self.params['orientation_round_robin']:
                        # Try the 8 x/y conventions combined with the 8 (transp, flipud, fliplr) orientations
                        for xy in ['x,y', 'x,-y', '-x,y', '-x,-y', 'y,x', 'y,-x', '-y,x', '-y,-x']:
                            self.params['xy'] = xy
                            for transp in range(2):
                                for flipud in range(2):
                                    for fliplr in range(2):
                                        self.params['detector_orientation'] = '%d,%d,%d' % (transp, flipud, fliplr)
                                        self.ws.center_crop_data()
                                        self.ws.prepare()
                                        self.ws.run()
                    else:
                        self.ws.center_crop_data()
                        self.ws.prepare()
                        self.ws.run()
                except PtychoRunnerException as ex:
                    print(self.help_text)
                    print('\n\n Caught exception for scan %d: %s    \n' % (scan, str(ex)))
        finally:
            # Restore the original (possibly pattern) cxifile value, even on an unexpected exception
            self.params['cxifile'] = cxifile0
