/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/

/** Elementwise kernel to compute an update of the object and probe from Psi. The update of the object is
* done separately for each frame to avoid memory conflicts because of the unknown shift between the frames.
* This should be called with a first argument array with a size of nx*ny, i.e. one frame size. Each parallel
* kernel execution treats one pixel, for one frame and all modes.
* NOTE: this must be optimised (coalesced memory access), as performance is terrible (for fast GPUs)
* NOTE: it is actually faster to use UpdateObjQuadPhaseN (copy to N arrays in parallel, then sum)
*/
/*
void UpdateObjQuadPhase(const int i, __global float2* psi, __global float2 *objnew, __global float2* probe,
                        __global float* objnorm, const float cx,  const float cy, const float px, const float f,
                        const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
                        const int nbobj, const int nbprobe, const char interp)
{
  if(cx>1e8) return;  // direct beam frames
  // Coordinates in the probe array
  const int prx = i % nx;
  const int pry = i / nx;

  // Pixel coordinates in the Psi array - same as prx, pry, but, fft-shifted (origin at (0,0)). nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  //if(i<140) printf("CL prx=%2d pry=%2d cx=%3d cy=%3d stack_size=%d nx=%d ny=%d nxo=%d nyo=%d nbobj=%d nbprobe=%d, lid=(%d,%d,%d) gid=(%d,%d,%d)\\n", prx, pry, cx, cy, stack_size, nx, ny, nxo, nyo, nbobj, nbprobe,get_local_id(0),get_local_id(1),get_local_id(2),get_global_id(0),get_global_id(1),get_global_id(2));

  float prn=0;

  for(int iprobe=0;iprobe<nbprobe;iprobe++)
  {
    // We could probably ignore interpolation for the normalisation which is averaged over several scan positions.
    const float2 pr = bilinear(probe, prx, pry, iprobe, nx, ny, interp, false);
    prn += dot(pr,pr);
  }

  // Apply Quadratic phase factor after far field propagation (not interpolated ?)
  const float y = (pry - ny/2) * px;
  const float x = (prx - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    float2 o=0;
    for(int iprobe=0;iprobe<nbprobe;iprobe++)
    {
      const float2 pr = probe[i + iprobe * nx * ny];
      float2 ps = psi[ipsi + stack_size * (iprobe + iobjmode * nbprobe) * nx * ny];
      ps=(float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
      o += (float2) (pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x);
    }
    bilinear_atomic_add_c(objnew, o, cx + prx, cy + pry, iobjmode, nxo, nyo, interp);
    if(iobjmode==0)
      bilinear_atomic_add_f(objnorm, prn, cx + prx, cy + pry, iobjmode, nxo, nyo, interp);
  }
}
*/

/** Elementwise kernel to compute an update of the object (and of its normalisation) from Psi.
* The update is cumulated into objnew and objnorm through the bilinear_atomic_add_c and
* bilinear_atomic_add_f helpers, to avoid memory access conflicts.
* This should be called with a first argument array with a size of nx*ny, i.e. one frame size. Each parallel
* kernel execution treats one pixel, for all frames and all modes.
*
* Arguments:
*  - i: index of the pixel in a frame of nx*ny pixels (centred probe array coordinates)
*  - psi: Psi array, read at the fft-shifted pixel; frame j of mode imode starts at
*         (j + stack_size * imode) * nx * ny
*  - objnew: array receiving the accumulated conj(probe) * Psi products, divided by sqrt(scale[j])
*  - probe: probe modes, nx*ny values each
*  - objnorm: array receiving the accumulated |probe|^2 values
*  - cx, cy: per-frame shifts. Frames with cx[j] > 1e8 are skipped (direct beam frames)
*  - px, f: Psi is multiplied by exp(1j * f * (x^2 + y^2)), where x and y are the pixel offsets from the
*           frame centre multiplied by px (presumably the pixel size - to be confirmed against the caller)
*  - stack_size: number of frames separating two consecutive modes in psi
*  - nx, ny: frame size, assumed to be multiples of 2
*  - nxo, nyo: forwarded to the bilinear helpers (presumably the object array size)
*  - nbmode: number of entries read in obj_idx, probe_idx, beamx and beamy
*  - npsi: number of frames to process
*  - scale: per-frame scale factor
*  - interp: forwarded unchanged to the bilinear helpers
*  - obj_idx, probe_idx: object and probe mode index for each mode
*  - beamx, beamy: per-mode shift added to cx[j], cy[j]
*
* Nothing is written for a frame when no mode was processed (nbmode <= 0).
*/

void UpdateObjAtomic(const int i, __global float2* psi, __global float2 *objnew, __global float2* probe,
                     __global float* objnorm, __global float* cx,  __global float* cy, const float px, const float f,
                     const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
                     const int nbmode, const int npsi, __global float *scale, const char interp,
                     __global int* obj_idx, __global int* probe_idx,
                     __global float* beamx, __global float* beamy)
{
  // Coordinates in the probe array
  const int prx = i % nx;
  const int pry = i / nx;

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Apply Quadratic phase factor after far field propagation (ignores subpixel interpolation)
  const float y = (pry - ny/2) * px;
  const float x = (prx - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  for(int j=0;j<npsi;j++)
  {
    if(cx[j]>1e8) continue;  // direct beam frames

    // Per-frame values which do not depend on the mode
    const float cxj = cx[j];
    const float cyj = cy[j];
    const float sqrt_scale = native_sqrt(scale[j]);

    // For a given object mode, accumulate values for different probe modes before writing
    // More efficient if probe modes are contiguous (they should be)
    float2 o=0;
    // Object normalisation
    float prn=0;
    // Object mode and beam shift of the values currently accumulated; iobj_last<0 means nothing is pending
    int iobj_last = -1;
    float bx_last = 0, by_last = 0;

    for(int imode=0 ; imode < nbmode ; imode++)
    {
      const int iobj = obj_idx[imode];
      const float bx = beamx[imode];
      const float by = beamy[imode];

      if((iobj_last>=0) && ((iobj != iobj_last) || (bx != bx_last) || (by != by_last)))
      {
        // new object mode, or different beam position, so save accumulated values
        // Distribute the computed o on the 4 corners of the interpolated object
        bilinear_atomic_add_c(objnew, o / sqrt_scale, cxj + bx_last + prx, cyj + by_last + pry, iobj_last, nxo, nyo, interp);
        bilinear_atomic_add_f(objnorm, prn, cxj + bx_last + prx, cyj + by_last + pry, iobj_last, nxo, nyo, interp);
        o = 0;
        prn = 0;
      }
      iobj_last = iobj;
      bx_last = bx;
      by_last = by;

      const int iprobe = probe_idx[imode];

      const float2 pr = probe[i + iprobe * nx * ny];
      float2 ps = psi[ipsi + (j + stack_size * imode ) * (nx * ny)];

      // Psi multiplied by the quadratic phase factor, then by the conjugate of the probe
      ps = (float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
      o += (float2)(pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x);
      prn += dot(pr,pr);
    }

    // Final object mode update. Skipped when no mode was processed, as there would be
    // neither a valid object index nor a valid beam shift to write to.
    if(iobj_last >= 0)
    {
      bilinear_atomic_add_c(objnew, o / sqrt_scale, cxj + bx_last + prx, cyj + by_last + pry, iobj_last, nxo, nyo, interp);
      bilinear_atomic_add_f(objnorm, prn, cxj + bx_last + prx, cyj + by_last + pry, iobj_last, nxo, nyo, interp);
    }
  }
}

// Probe update from Psi. Each execution treats one probe pixel, for all frames and all modes:
// it accumulates conj(obj) * Psi and |obj|^2, where Psi is first divided by sqrt(scale[j]) and
// multiplied by the quadratic phase factor exp(1j * f * (x^2 + y^2)).
// For frames with cx[j] > 1e8 (direct beam frames) the object is not read: Psi is accumulated
// as is and the norm is incremented by 1.
// If firstpass is non-zero, probe and probenorm are overwritten, otherwise the values are added.
void UpdateProbeQuadPhase(const int i, __global float2 *obj, __global float2* probe, __global float2* psi,
                          __global float* probenorm, __global float* cx,  __global float* cy,
                          const float px, const float f, const char firstpass, const int npsi, const int stack_size,
                          const int nx, const int ny, const int nxo, const int nyo, const int nbprobe, const int nbmode,
                          __global float* scale, const char interp, __global int* obj_idx, __global int* probe_idx,
                          __global float* beamx, __global float* beamy)
{
  // Position of this pixel in the probe frame
  const int col = i % nx;
  const int row = i / nx;

  // obj and probe are centered arrays, Psi is fft-shifted

  // Matching pixel in the Psi array (origin at (0,0)). nx and ny are assumed to be multiples of 2
  const int half_nx = nx/2;
  const int half_ny = ny/2;
  const int psi_row = (row < half_ny) ? (row - half_ny + ny) : (row - half_ny);
  const int psi_col = (col < half_nx) ? (col - half_nx + nx) : (col - half_nx);
  const int psi_pix = psi_col + psi_row * nx;

  // Quadratic phase factor applied after far field propagation
  const float yc = (row - half_ny) * px;
  const float xc = (col - half_nx) * px;
  const float phase = f*(xc*xc+yc*yc);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float sin_ph = native_sin(phase);
  const float cos_ph = native_cos(phase);

  // TODO: find a way to use __local memory - pyopencl does not support LocalMemory arguments for ElementwiseKernel..
  //   ... this caching mode is ugly...
  // Accumulate in private arrays, probe and probenorm are only written at the end of each batch
  #define NB_PROBE_CACHE 8
  float2 probe_acc[NB_PROBE_CACHE];
  float norm_acc[NB_PROBE_CACHE];

  // Probe modes are handled by batches of NB_PROBE_CACHE, with a full loop over Psi for each batch
  for(int batch0=0; batch0<nbprobe; batch0+=NB_PROBE_CACHE)
  {
    // Reset the accumulators
    for(int k=0; k<NB_PROBE_CACHE; k++)
    {
      probe_acc[k] = (float2)(0.0f, 0.0f);
      norm_acc[k] = 0.0f;
    }

    for(int j=0; j<npsi; j++)
    {
      for(int imode=0; imode<nbmode; imode++)
      {
        // Slot of this mode's probe in the current batch - only modes falling in the batch are used
        const int k = probe_idx[imode] - batch0;
        if((k>=0) && (k<NB_PROBE_CACHE))
        {
          float2 ps = psi[psi_pix + (j + stack_size * imode) * nx * ny] / native_sqrt(scale[j]);
          ps = (float2)(ps.x*cos_ph - ps.y*sin_ph , ps.y*cos_ph + ps.x*sin_ph);

          if(cx[j]>1e8)
          {
            // Direct beam frame: no object contribution
            probe_acc[k] += ps;
            norm_acc[k] += 1.0f;
          }
          else
          {
            const int iobj = obj_idx[imode];
            const float bx = beamx[imode];
            const float by = beamy[imode];
            const float2 o = bilinear(obj, cx[j]+bx+col, cy[j]+by+row, iobj, nxo, nyo, interp, false);
            // conj(o) * ps
            probe_acc[k] += (float2) (o.x*ps.x + o.y*ps.y , o.x*ps.y - o.y*ps.x);
            norm_acc[k] += dot(o,o);
          }
        }
      }
    }

    // Number of valid probe modes in this batch (the last batch may be incomplete)
    int nstore = nbprobe - batch0;
    if(nstore > NB_PROBE_CACHE) nstore = NB_PROBE_CACHE;

    // Store values
    for(int k=0; k<nstore; k++)
    {
      const int dest = i + (k+batch0) * nx * ny;
      if(firstpass)
      {
        probe[dest] = probe_acc[k];
        probenorm[dest] = norm_acc[k];
      }
      else
      {
        probe[dest] += probe_acc[k];
        probenorm[dest] += norm_acc[k];
      }
    }
  }
}

// Normalize object, in place in obj:
//   obj[i] = (obj_unnorm[i] + reg * obj[i]) / max(objnorm[i] + reg, 1e-12), with reg = normmax[0] * inertia
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
void ObjNorm(const int i, __global float2 *obj_unnorm, __global float* objnorm, __global float2 *obj,
             __global float *normmax, const float inertia)
{
  // Regularization weight, used both in the numerator and in the denominator
  const float damping = normmax[0] * inertia;

  // Un-normalised update, mixed with the previous object value
  const float2 previous = obj[i];
  const float2 numerator = obj_unnorm[i] + damping * previous;

  // The denominator is kept above 1e-12
  const float denominator = fmax(objnorm[i] + damping, 1e-12f);

  obj[i] = numerator / denominator;
}

// Normalize object, in place in obj, with a regularization weight reg = normmax[0] * inertia.
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
// This version includes a scaling factor to compensate for average frame scaling: the denominator
// (objnorm[i] + reg) is divided by sqrt(scale_sum[0] / nb_frame) before being kept above 1e-12.
void ObjNormScale(const int i, __global float2 *obj_unnorm, __global float* objnorm, __global float2 *obj,
                  __global float *normmax, const float inertia, __global float* scale_sum, const int nb_frame)
{
  // Regularization weight, used both in the numerator and in the denominator
  const float damping = normmax[0] * inertia;

  // Un-normalised update, mixed with the previous object value
  const float2 previous = obj[i];
  const float2 numerator = obj_unnorm[i] + damping * previous;

  // Average frame scale
  const float mean_scale = scale_sum[0] / nb_frame;
  const float denominator = fmax((objnorm[i] + damping) / native_sqrt(mean_scale), 1e-12f);

  obj[i] = numerator / denominator;
}
