/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/

/** Accumulate the un-normalised object update, and the matching probe normalisation, from Psi for ONE frame
* located at (cx, cy) in the object array.
*
* Intended to be run elementwise over i in [0, nx*ny), i.e. one frame size: each call handles one probe pixel,
* for all object and probe modes. Frames are processed separately to avoid memory conflicts, because the
* relative shift between frames is unknown.
*
* For every object mode, the value added to objnew is sum_over_probe_modes( conj(probe) * psi * exp(i*f*(x^2+y^2)) ),
* where (x, y) is the pixel offset from the array centre multiplied by px. objnorm receives the sum over probe
* modes of |probe|^2, once per pixel (the same normalisation applies to all object modes).
*
* psi is indexed as pixel + stack_size * (probe mode + object mode * nbprobe) * nx*ny, with the pixel
* coordinates shifted by half the array size relative to the probe (nx and ny are assumed to be even).
* NOTE(review): only the first frame of each mode is read here - presumably the caller offsets psi to select
* the frame; confirm against the calling code.
*
* TODO: this must be optimised (coalesced memory access), as performance is terrible (for fast GPUs)
*/
void UpdateObjQuadPhase(const int i, __global float2* psi, __global float2 *objnew, __global float2* probe,
                        __global float* objnorm, const int cx,  const int cy, const float px, const float f,
                        const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
                        const int nbobj, const int nbprobe)
{
  // Number of pixels in one probe/psi frame and in one object layer
  const int nxy = nx * ny;
  const int nxyo = nxo * nyo;

  // Column and row of this pixel in the probe array
  const int col = i % nx;
  const int row = i / nx;

  // Same pixel in the Psi array, which is shifted by half the array size (nx, ny assumed even)
  const int psi_row = row - ny/2 + ny * (row<(ny/2));
  const int psi_col = col - nx/2 + nx * (col<(nx/2));
  const int psi_pix = psi_col + psi_row * nx;

  // Probe intensity summed over all probe modes
  // TODO: avoid multiple access of probe value (maybe cached ?)
  float probe_int = 0;
  for(int m = 0; m < nbprobe; m++)
  {
    const float2 pm = probe[i + m * nxy];
    probe_int += dot(pm, pm);
  }

  // Pixel in the first object layer, taking into account the frame shift
  const int obj_pix = cx + col + nxo * (cy + row);
  objnorm[obj_pix] += probe_int; // All the object modes have the same probe normalization.

  // Quadratic phase factor applied after far field propagation
  const float y = (row - ny/2) * px;
  const float x = (col - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s = native_sin(tmp);
  const float c = native_cos(tmp);

  for(int omode = 0; omode < nbobj; omode++)
  {
    float2 acc = (float2)(0.0f, 0.0f);
    for(int m = 0; m < nbprobe; m++)
    {
      const float2 pm = probe[i + m * nxy]; // TODO: avoid multiple access of probe value (maybe cached ?)
      float2 ps = psi[psi_pix + stack_size * (m + omode * nbprobe) * nxy];
      // ps * exp(i*tmp)
      ps = (float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
      // conj(probe) * ps
      acc += (float2)(pm.x*ps.x + pm.y*ps.y , pm.x*ps.y - pm.y*ps.x);
    }
    objnew[obj_pix + omode * nxyo] += acc;
  }
}

/** Accumulate the un-normalised object update, and the matching probe normalisation, from Psi for a stack
* of npsi frames.
*
* The update is written to N=stack_size separate object layers (one per frame of the stack) to avoid memory
* conflicts, because the relative shift between the frames is unknown. The layers are meant to be summed
* afterwards.
* Intended to be run elementwise over i in [0, nx*ny), i.e. one frame size: each call handles one probe pixel,
* for all frames and all modes.
*
* Memory layout (layers of nxo*nyo pixels each):
*   - objnormN: layer j holds the normalisation of frame j (shared by all object modes)
*   - objnewN : layer (j + iobjmode * stack_size) holds the update of object mode iobjmode from frame j
*   - psi     : frame (j + stack_size * (iprobe + iobjmode * nbprobe)), with nx*ny pixels per frame
* cx[j], cy[j] give the shift of frame j in the object array.
*
* For every frame and object mode, the value added to objnewN is
* sum_over_probe_modes( conj(probe) * psi * exp(i*f*(x^2+y^2)) ), where (x, y) is the pixel offset from the
* array centre multiplied by px. The normalisation added to objnormN is the sum over probe modes of |probe|^2,
* added exactly once per frame.
*/
void UpdateObjQuadPhaseN(const int i, __global float2* psi, __global float2 *objnewN, __global float2* probe,
                         __global float* objnormN, __global int *cx, __global int *cy, const float px, const float f,
                         const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
                         const int nbobj, const int nbprobe, const int npsi)
{
  // Coordinate in probe array
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxyo = nxo * nyo;
  const int nxy = nx * ny;

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Probe intensity summed over all probe modes
  float prn=0;
  for(int iprobe=0;iprobe<nbprobe;iprobe++)
  {// TODO: avoid multiple access of probe value (maybe cached ?)
    const float2 pr = probe[i + iprobe*nxy];
    prn += dot(pr,pr);
  }

  // Apply Quadratic phase factor after far field propagation
  const float y = (pry - ny/2) * px;
  const float x = (prx - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  for(int j=0;j<npsi;j++)
  {
    // Pixel inside one object layer, taking into account the shift of frame j
    const int ipix = cx[j] + prx + nxo * (cy[j] + pry);

    // All the object modes have the same probe normalization, so it is added once per frame
    // (and not once per object mode).
    objnormN[ipix + j * nxyo] += prn ;

    for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
    {
      // Each (frame, object mode) pair has its own layer: frame index runs fastest, with a stride of
      // stack_size layers between object modes.
      const int iobj = ipix + (j + iobjmode * stack_size) * nxyo;

      float2 o=0;
      for(int iprobe=0;iprobe<nbprobe;iprobe++)
      {
        const float2 pr = probe[i + iprobe*nxy]; // TODO: avoid multiple access of probe value (maybe cached ?)
        float2 ps=psi[ipsi + nxy * (j + stack_size * (iprobe + iobjmode * nbprobe)) ];
        // ps * exp(i*tmp)
        ps=(float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
        // conj(probe) * ps
        o += (float2) (pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x);
      }
      objnewN[iobj] += o ;
    }
  }
}


// Accumulate the un-normalised probe update, and the matching object normalisation, from Psi for a stack
// of npsi frames. Each call handles one probe pixel i (column i % nx, row i / nx), for all frames and modes.
//
// For every probe mode, the update is sum_over_frames_and_object_modes( conj(obj) * psi * exp(i*f*(x^2+y^2)) ),
// where (x, y) is the pixel offset from the array centre multiplied by px. The normalisation is the sum over
// frames and object modes of |obj|^2, and is the same for all probe modes.
// If firstpass is non-zero, probe and probenorm are overwritten, otherwise the new values are added to them.
//
// cx[j], cy[j] give the shift of frame j in the object array. psi is indexed by
// frame (j + stack_size * (probe mode + object mode * nbprobe)), with nx*ny pixels per frame.
// obj and probe are centered arrays, Psi is fft-shifted.
// TODO: this must be optimised (coalesced memory access), as performance is terrible !
void UpdateProbeQuadPhase(const int i, __global float2 *obj, __global float2* probe, __global float2* psi,
                          __global float* probenorm, __global int* cx,  __global int* cy,
                          const float px, const float f, const char firstpass, const int npsi, const int stack_size,
                          const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe)
{
  // Number of pixels in one probe/psi frame and in one object layer
  const int nxy = nx * ny;
  const int nxyo = nxo * nyo;

  // Column and row of this pixel in the probe array
  const int col = i % nx;
  const int row = i / nx;

  // Same pixel in the Psi array, which is shifted by half the array size (nx, ny assumed even)
  const int psi_row = row - ny/2 + ny * (row<(ny/2));
  const int psi_col = col - nx/2 + nx * (col<(nx/2));
  const int psi_pix = psi_col + psi_row * nx;

  // Object intensity seen by this probe pixel, summed over all frames and object modes
  float obj_int = 0;
  for(int j = 0; j < npsi; j++)
  {
    const int obj_pix = cx[j] + col + nxo * (cy[j] + row);
    for(int omode = 0; omode < nbobj; omode++)
    {
      const float2 ov = obj[obj_pix + omode * nxyo]; // TODO 1: store object values to avoid repeated memory read
      obj_int += dot(ov, ov);
    }
  }

  // all modes have the same normalization
  if(firstpass)
  {
    probenorm[i] = obj_int;
  }
  else
  {
    probenorm[i] += obj_int;
  }

  // Quadratic phase factor applied after far field propagation
  const float y = (row - ny/2) * px;
  const float x = (col - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s = native_sin(tmp);
  const float c = native_cos(tmp);

  for(int pmode = 0; pmode < nbprobe; pmode++)
  {
    const int probe_pix = i + pmode * nxy;
    float2 acc = (float2)(0.0f, 0.0f);
    for(int j = 0; j < npsi; j++)
    {
      const int obj_pix = cx[j] + col + nxo * (cy[j] + row); // TODO 1
      for(int omode = 0; omode < nbobj; omode++)
      {
        float2 ps = psi[psi_pix + (j + stack_size * (pmode + omode * nbprobe)) * nxy];
        // ps * exp(i*tmp)
        ps = (float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);

        const float2 ov = obj[obj_pix + omode * nxyo]; // TODO 1

        // conj(obj) * ps
        acc += (float2)(ov.x*ps.x + ov.y*ps.y , ov.x*ps.y - ov.y*ps.x);
      }
    }
    if(firstpass)
    {
      probe[probe_pix] = acc;
    }
    else
    {
      probe[probe_pix] += acc;
    }
  }
}


// Collapse the stack of N=stack_size object update layers, and the stack of N normalisation layers, for
// object pixel iobj (this must be done in a separate step to avoid memory access conflicts).
//   objnorm[iobj]                 = sum over layers k of objnormN[iobj + k * nxyo]
//   obj[iobj + iobjmode * nxyo]   = sum over layers k of objN[iobj + (k + iobjmode * stack_size) * nxyo]
// The outputs are overwritten (not accumulated), and the stacked arrays are left unchanged.
void SumN(const int iobj, __global float2 *obj, __global float2 *objN, __global float* objnorm, __global float* objnormN, const int stack_size, const int nxyo, const int nbobj)
{
  // Normalisation: the same norm applies to all object modes, so there is a single stack to sum
  float norm_sum = 0;
  for(int k = 0; k < stack_size; k++)
  {
    norm_sum += objnormN[iobj + k * nxyo];
  }
  objnorm[iobj] = norm_sum;

  // Object update: one stack of stack_size layers per object mode
  for(int omode = 0; omode < nbobj; omode++)
  {
    const int first_layer = omode * stack_size;
    float2 obj_sum = (float2)(0.0f, 0.0f);
    for(int k = 0; k < stack_size; k++)
    {
      obj_sum += objN[iobj + (k + first_layer) * nxyo];
    }
    obj[iobj + omode * nxyo] = obj_sum;
  }
}

// Sum, for object pixel iobj, the stack of N=stack_size normalisation layers into the first layer:
// layers 1 .. stack_size-1 are added together, and the result is added to objnormN[iobj].
// The other layers are left unchanged.
void SumNnorm(const int iobj, __global float* objnormN, const int stack_size, const int nxyo)
{
  float upper_layers = 0;
  for(int k = 1; k < stack_size; k++)
  {
    upper_layers += objnormN[iobj + k * nxyo];
  }
  objnormN[iobj] += upper_layers;
}


// Normalise the object update for pixel i, for all object modes:
//   obj = (obj_unnorm + reg * obj) / max(objnorm + reg, 1e-12)
// The previous object value is read from obj before being overwritten. The same normalisation objnorm[i]
// is used for all modes, which are stored nxyo elements apart.
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
void ObjNorm(const int i, __global float2 *obj_unnorm, __global float* objnorm, __global float2 *obj, const float reg, const int nxyo, const int nbobj)
{
  // Lower bound on the denominator, to avoid a division by zero
  const float denom = fmax(objnorm[i] + reg, 1e-12f);
  for(int omode = 0; omode < nbobj; omode++)
  {
    const int k = i + omode*nxyo;
    obj[k] = (obj_unnorm[k] + reg * obj[k]) / denom ;
  }
}

// Normalise the object for pixel i directly from the stack of N=stack_size update layers, to avoid one
// extra memory r/w. For each object mode:
//   obj = (sum over layers j of obj_newN[i + (j + iobjmode * stack_size) * nxyo] + reg * obj) / max(obj_norm + reg, 1e-12)
// obj_norm[i] is read as a single value (not a stack), and is the same for all object modes.
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
void ObjNormN(const int i, __global float* obj_norm, __global float2 *obj_newN, __global float2 *obj, const float reg, const int nxyo, const int nbobj, const int stack_size)
{
  // Lower bound on the denominator, to avoid a division by zero
  const float denom = fmax(obj_norm[i] + reg, 1e-12f);
  for(int omode = 0; omode < nbobj; omode++)
  {
    // Sum the stack_size layers belonging to this object mode
    const int first_layer = omode * stack_size;
    float2 update = (float2)(0.0f, 0.0f);
    for(int j = 0; j < stack_size; j++)
    {
      update += obj_newN[i + (j + first_layer) * nxyo];
    }

    const int k = i + omode*nxyo;
    obj[k] = (update + reg * obj[k]) / denom ;
  }
}

