/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/


/** Elementwise kernel function to compute an update of the object from Psi (the probe update is done
* separately, in UpdateProbeQuadPhase). The update of the object is accumulated using atomic operations
* to avoid memory access conflicts.
* This should be called with a first argument array with a size of nx*ny, i.e. one frame size. Each parallel
* kernel execution treats one pixel, for all frames and all modes.
*/

/** Accumulate the object update for one probe pixel, for all frames and all modes.
*
* Psi is first multiplied by the quadratic phase factor exp(i*f*(x^2+y^2)), with
* x = (prx - nx/2)*px and y = (pry - ny/2)*px. The subpixel interpolation is ignored
* for this phase factor.
*
* For each frame j, and for each run of consecutive modes sharing the same object mode
* (obj_idx) and the same beam offset (beamx, beamy):
*   objnew  += cpad * sum(conj(probe) * Psi)
*   objnorm += cpad * sum(dot(probe, probe))
* Both are added at object position (cx[j] + beamx + prx, cy[j] + beamy + pry), through
* bilinear_atomic_add_c / bilinear_atomic_add_f.
* Frames with cx[j] > 1e8 are skipped (direct beam frames).
*
* Parameters:
*   i: pixel index within one frame (0 <= i < nx*ny), with prx = i % nx, pry = i / nx
*   psi: Psi values, read as psi[ipsi + (j + stack_size*imode)*nx*ny], where ipsi is the
*        fft-shifted (origin at (0,0)) index of pixel (prx, pry)
*   objnew, objnorm: accumulation arrays for the object update and its normalisation
*   probe: probe modes, read as probe[i + probe_idx[imode]*nx*ny]
*   cx, cy: per-frame object coordinates
*   px: pixel size used for the quadratic phase
*   f: quadratic phase coefficient
*   padding, padding_window: if padding > 0, a Tukey window (cpad) damps the contribution
*        of pixels close to the frame border
*   obj_idx, probe_idx, beamx, beamy: per-mode object index, probe index and beam offset
*   nbmode: number of modes; nothing is done if nbmode <= 0
*   npsi: number of frames to process
*/
__device__ void UpdateObjQuadPhaseAtomic(const int i, complexf* psi, complexf *objnew, complexf* probe,
                        float* objnorm, float* cx,  float* cy, const float px, const float f,
                        const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
                        const int nbmode, const int npsi, const int padding,
                        const int padding_window, const bool interp, int* obj_idx, int* probe_idx,
                        float* beamx, float* beamy)
{
  // Without any mode there is nothing to accumulate. Returning early also prevents the final
  // flush below from being called with an invalid (-1) object index and uninitialised offsets.
  if(nbmode <= 0) return;

  // Coordinates in the probe array
  const int prx = i % nx;
  const int pry = i / nx;

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Use a window (Tukey or erfc) for padded data in near field
  float cpad = 1.0f;

  if(padding > 0)
  { // Tukey dampening
    const int p0 = padding-padding_window;
    const int p1 = padding+padding_window;
//    const int p0 = padding;
//    const int p1 = padding+2*padding_window;
    // Padding factor is 0 for pixels closer than p0=padding-padding_window to the border, and rises
    // smoothly (raised cosine) to 1 at p1=padding+padding_window pixels from the border.
    if(prx<p0)            cpad = 0;
    else if(prx<p1)       cpad *= 0.5f * (1.0f - __cosf(   (prx-p0) * 1.57079632679f / padding_window));
    if(prx>=(nx-p0))      cpad = 0;
    else if(prx>=(nx-p1)) cpad *= 0.5f * (1.0f - __cosf((nx-p0-prx) * 1.57079632679f / padding_window));
    if(pry<p0)            cpad = 0;
    else if(pry<p1)       cpad *= 0.5f * (1.0f - __cosf(   (pry-p0) * 1.57079632679f / padding_window));
    if(pry>=(ny-p0))      cpad = 0;
    else if(pry>=(ny-p1)) cpad *= 0.5f * (1.0f - __cosf((ny-p0-pry) * 1.57079632679f / padding_window));
  }

/*
  if(padding > 0)
  { // erfc dampening
    const float sigma = padding_window*0.5f;  // ~full-width of the erfc function
    if(prx<(2*padding))     cpad *= 0.5f * erfcf((padding-prx)/sigma);
    else if(prx>=(nx-2*padding)) cpad *= 0.5f * erfcf((prx-(nx-padding))/sigma);
    if(pry<(2*padding))     cpad *= 0.5f * erfcf((padding-pry)/sigma);
    else if(pry>=(ny-2*padding)) cpad *= 0.5f * erfcf((pry-(ny-padding))/sigma);
  }
*/

  // Apply Quadratic phase factor after far field propagation (ignores subpixel interpolation)
  const float y = (pry - ny/2) * px;
  const float x = (prx - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  float s, c;
  __sincosf(tmp , &s, &c);

  for(int j=0;j<npsi;j++)
  {
    if(cx[j]>1e8f) continue;  // direct beam frames: no contribution to the object

    // For a given object mode (and beam offset), accumulate values for consecutive probe modes
    // before writing. More efficient if probe modes are contiguous (they should be).
    complexf o=0;
    float prn=0;
    int iobj_last = -1;
    float bx_last = -1, by_last = -1;

    for(int imode=0 ; imode < nbmode ; imode++)
    {
      const int iobj = obj_idx[imode];
      const float bx = beamx[imode];
      const float by = beamy[imode];

      if((iobj_last>=0) && ((iobj != iobj_last) || (bx != bx_last) || (by != by_last)))
      {
        // new object mode, so save accumulated values
        // Distribute the computed o on the 4 corners of the interpolated object
        bilinear_atomic_add_c(objnew, o * cpad,    cx[j] + bx_last + prx, cy[j] + by_last + pry, iobj_last, nxo, nyo, interp);
        bilinear_atomic_add_f(objnorm, prn * cpad, cx[j] + bx_last + prx, cy[j] + by_last + pry, iobj_last, nxo, nyo, interp);
        o = 0;
        prn = 0;
      }
      iobj_last = iobj;
      bx_last = bx;
      by_last = by;

      const int iprobe = probe_idx[imode];
      // TODO: check if we need to cache probe values explicitly
      const complexf pr = probe[i + iprobe * nx * ny];
      prn += dot(pr,pr);

      complexf ps = psi[ipsi + (j + stack_size * imode) * (nx * ny)];

      // Multiply Psi by exp(i*tmp)
      ps = complexf(ps.real()*c - ps.imag()*s , ps.imag()*c + ps.real()*s);
      // o += conj(pr) * ps
      o += complexf(pr.real()*ps.real() + pr.imag()*ps.imag() , pr.real()*ps.imag() - pr.imag()*ps.real());
    }
    // Last object mode update (nbmode > 0, so iobj_last, bx_last and by_last are set)
    bilinear_atomic_add_c(objnew, o * cpad,    cx[j] + bx_last + prx, cy[j] + by_last + pry, iobj_last, nxo, nyo, interp);
    bilinear_atomic_add_f(objnorm, prn * cpad, cx[j] + bx_last + prx, cy[j] + by_last + pry, iobj_last, nxo, nyo, interp);
  }
}

// Same for probe update, without need for atomic operations
// Probe update from Psi for one probe pixel, for all probe modes.
// This uses the same quadratic phase factor exp(i*f*(x^2+y^2)) as the object update.
// Each call only writes probe pixel i, so no atomic operations are needed.
// For every probe mode, the new (un-normalised) probe is the sum over frames and modes of
// conj(obj)*Psi, and its normalisation is the sum of dot(obj,obj). The object is read with
// bilinear(...) at (cx[j]+beamx+prx, cy[j]+beamy+pry).
// Frames with cx[j] > 1e8 (direct beam frames) instead contribute Psi itself, with a weight of 1.
// Values are accumulated in a local cache of NB_PROBE_CACHE probe modes, then multiplied by the
// Tukey padding window cpad. They are written (firstpass != 0) or added (firstpass == 0) to
// probe / probenorm at index i + iprobe*nx*ny.
__device__ void UpdateProbeQuadPhase(const int i, complexf *obj, complexf* probe, complexf* psi, float* probenorm,
                                     float* cx,  float* cy, const float px, const float f, const char firstpass,
                                     const int npsi, const int stack_size, const int nx, const int ny, const int nxo,
                                     const int nyo, const int nbprobe, const int nbmode,
                                     const int padding, const int padding_window, const bool interp,
                                     int* obj_idx, int* probe_idx, float* beamx, float* beamy)
{
  const int prx = i % nx;
  const int pry = i / nx;

  // obj and probe are centered arrays, Psi is fft-shifted

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Apply Quadratic phase factor after far field propagation
  const float y = (pry - ny/2) * px;
  const float x = (prx - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  float s, c;
  __sincosf(tmp , &s, &c);

  // Use a window (Tukey or erfc) for padded data in near field,
  // preventing update of the probe far from the observation window
  float cpad = 1.0f;
  if(padding > 0)
  { // Tukey dampening
    const int p0 = padding-padding_window;
    const int p1 = padding+padding_window;
    // Padding factor is 0 for pixels closer than p0=padding-padding_window to the border, and rises
    // smoothly (raised cosine) to 1 at p1=padding+padding_window pixels from the border.
    if(prx<p0)            cpad = 0;
    else if(prx<p1)       cpad *= 0.5f * (1.0f - __cosf(   (prx-p0) * 1.57079632679f / padding_window));
    if(prx>=(nx-p0))      cpad = 0;
    else if(prx>=(nx-p1)) cpad *= 0.5f * (1.0f - __cosf((nx-p0-prx) * 1.57079632679f / padding_window));
    if(pry<p0)            cpad = 0;
    else if(pry<p1)       cpad *= 0.5f * (1.0f - __cosf(   (pry-p0) * 1.57079632679f / padding_window));
    if(pry>=(ny-p0))      cpad = 0;
    else if(pry>=(ny-p1)) cpad *= 0.5f * (1.0f - __cosf((ny-p0-pry) * 1.57079632679f / padding_window));
  }

  // TODO: find a way to use a dynamically allocated cache..
  // Write only probe and probe norm values at the end
  // NB: this macro stays defined for the rest of the translation unit.
  #define NB_PROBE_CACHE 8
  complexf prnew[NB_PROBE_CACHE]; // new probe value, before normalisation
  float prn[NB_PROBE_CACHE];  // probe normalisation

  // Process probe modes by groups of NB_PROBE_CACHE
  for(int iprobe0=0; iprobe0<nbprobe ; iprobe0+=NB_PROBE_CACHE)
  {
    // Init local cache
    for(int iprobe=0; iprobe<NB_PROBE_CACHE; iprobe++)
    {
      prnew[iprobe] = complexf(0.0f,0.0f);
      prn[iprobe] = 0;
    }

    for(int j=0;j<npsi;j++)
    {
      // Frame-dependent values, hoisted out of the mode loop
      const float cxj = cx[j];
      const float cyj = cy[j];
      const bool direct_beam = cxj > 1e8f;  // direct beam frame: no object

      for(int imode=0 ; imode < nbmode ; imode++)
      {
        // Only modes whose probe belongs to the current cache group
        const int iprobe = probe_idx[imode] - iprobe0;
        if((iprobe<0) || (iprobe>=NB_PROBE_CACHE)) continue;

        // Multiply Psi by exp(i*tmp)
        complexf ps = psi[ipsi + (j + stack_size * imode) * nx * ny];
        ps = complexf(ps.real()*c - ps.imag()*s , ps.imag()*c + ps.real()*s);
        if(direct_beam)
        {
          prnew[iprobe] += ps;
          prn[iprobe] += 1.0f;
        }
        else
        {
          const int iobj = obj_idx[imode];
          const complexf o = bilinear(obj, cxj+beamx[imode]+prx, cyj+beamy[imode]+pry, iobj, nxo, nyo, interp, false);
          // prnew += conj(o) * ps
          prnew[iprobe] += complexf(o.real()*ps.real() + o.imag()*ps.imag() , o.real()*ps.imag() - o.imag()*ps.real());
          prn[iprobe] += dot(o,o);
        }
      }
    }
    // Store values - the last group may hold fewer than NB_PROBE_CACHE probe modes
    const int m = nbprobe - iprobe0 > NB_PROBE_CACHE ? NB_PROBE_CACHE : nbprobe - iprobe0;
    if(firstpass)
      for(int iprobe=0; iprobe<m; iprobe++)
      {
        probe[i + (iprobe+iprobe0) * nx * ny] = prnew[iprobe] * cpad;
        probenorm[i + (iprobe+iprobe0) * nx * ny] = prn[iprobe] * cpad;
      }
    else
      for(int iprobe=0; iprobe<m; iprobe++)
      {
        probe[i + (iprobe+iprobe0) * nx * ny] += prnew[iprobe] * cpad;
        probenorm[i + (iprobe+iprobe0) * nx * ny] += prn[iprobe] * cpad;
      }
  }
}

// Sum the stack of N object normalisation arrays to the first array
// Fold the normalisation layers 1 .. stack_size-1 of objnormN (consecutive layers are nxyo
// values apart) into layer 0, for object pixel iobj. Layers are summed in increasing order
// into a local accumulator, which is then added to layer 0 once.
__device__ void SumNnorm(const int iobj, float* objnormN, const int stack_size, const int nxyo)
{
  float acc = 0.0f;
  int layer = 1;
  while(layer < stack_size)
  {
    acc += objnormN[iobj + layer * nxyo];
    ++layer;
  }
  objnormN[iobj] += acc;
}

// Normalize object.
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
// Normalise object pixel i, with regularisation as in Marchesini et al, Inverse problems 29 (2013),
// 115009, eq (14):
//   obj = (obj_unnorm + r*obj) / max(objnorm + r, 1e-12), with r = regmax[0]*inertia.
// The denominator is floored at 1e-12 to avoid a division by zero.
__device__ void ObjNorm(const int i, float* objnorm, complexf* obj_unnorm, complexf *obj,
                        float *regmax, const float inertia)
{
  const float r = regmax[0] * inertia;
  const complexf numerator = obj_unnorm[i] + r * obj[i];
  const float denominator = fmaxf(objnorm[i] + r, 1e-12f);
  obj[i] = numerator / denominator;
}

// Normalise object
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
// Additional restraint on imaginary part to steer towards a zero phase in a given area
// Normalise object pixel i.
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
// Additional restraint on imaginary part to steer towards a zero phase in a given area: the imaginary
// part gets an extra regmax[0]*zero_phase_mask[i] in its denominator.
//   reg = regmax[0]*inertia, o = reg*obj + obj_new
//   obj.real = o.real / (obj_norm + reg)
//   obj.imag = o.imag / (obj_norm + regmax[0]*(inertia + zero_phase_mask))
// Both denominators are floored at 1e-12, as in ObjNorm. Without the floor, a pixel with zero
// norm and no regularisation would give NaN/inf.
__device__ void ObjNormZeroPhaseMask(const int i, float* obj_norm, complexf *obj_new, complexf *obj,
                                     float *zero_phase_mask, float *regmax, const float inertia)
{
  const float reg = regmax[0] * inertia;
  const float norm_real = fmaxf(obj_norm[i] + reg, 1e-12f);
  const float norm_imag = fmaxf(obj_norm[i] + regmax[0] * (inertia + zero_phase_mask[i]), 1e-12f);
  const complexf o = reg * obj[i] + obj_new[i];
  obj[i] = complexf(o.real() / norm_real, o.imag() / norm_imag) ;
}


/*
// TODO: ObjNormZeroPhaseMask needs to be re-implemented
// Normalise object directly from the stack of N layers of object and norm, to avoid one extra memory r/w
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
// Additional restraint on imaginary part to steer towards a zero phase in a given area
__device__ void ObjNormZeroPhaseMaskN(const int i, float* obj_norm, complexf *obj_newN, complexf *obj,
                                      float *zero_phase_mask, float *regmax, const float inertia,
                                      const int nxyo)
{
  const float reg = regmax[0] * inertia;
  const float norm_real = obj_norm[i] + reg; // The same norm applies to all object modes
  const float norm_imag = obj_norm[i] + regmax[0] * (inertia + zero_phase_mask[i % nxyo]);
  obj[i] = complexf(o.real() / norm_real, o.imag() / norm_imag) ;
}
*/
