/** Kernels for the Alternating Projections Ptychography algorithm.
* A stack of BLOCKSIZE (=16) images is processed at the same time for better efficiency,
* and both object and probe modes are taken into account.
*/


/** Rescale the calculated amplitudes of one pixel so that the intensity summed over
* all NBMODE modes matches the observed intensity.
*
* The pixel handled is get_global_id(0); mode m of that pixel is stored at
* dcalc[get_global_id(0) + m*NXYZ].
*
* \param iobs: observed intensity for this pixel. Negative values are clamped to 0.
* \param dcalc: calculated complex amplitudes, rescaled in place.
*/
__kernel
void ApplyAmplitude_Real(float iobs, __global float2* dcalc)
{
  const unsigned long pixel = get_global_id(0);

  // Cache the amplitude of every mode and accumulate the total calculated intensity
  float2 amp[NBMODE];
  float icalc = 0;
  for(unsigned int m = 0; m < NBMODE; m++)
  {
    const float2 a = dcalc[pixel + m*NXYZ];
    amp[m] = a;
    icalc += dot(a, a);
  }

  // Floor the calculated intensity before taking its inverse square root.
  // TODO: the 1e-12f floor is arbitrary
  icalc = fmax(icalc, 1e-12f);

  // Common scale factor sqrt(max(iobs,0)) / sqrt(icalc), shared by all modes
  const float scale = native_sqrt(fmax(iobs, 0)) * native_rsqrt(icalc);
  for(unsigned int m = 0; m < NBMODE; m++)
  {
    dcalc[pixel + m*NXYZ] = (float2) (scale*amp[m].x , scale*amp[m].y);
  }
}

/** Apply the observed amplitude to all modes, one work-item per observed pixel.
*
* \param iobs: observed intensities, read at index get_global_id(0).
* \param dcalc: calculated complex amplitudes, rescaled in place by ApplyAmplitude_Real().
*/
__kernel
void ApplyAmplitude(__global float *iobs, __global float2* dcalc)
{
  const unsigned long pixel = get_global_id(0);
  const float observed = iobs[pixel];
  ApplyAmplitude_Real(observed, dcalc);
}

/** Apply the observed amplitude to all modes, skipping masked pixels.
*
* \param iobs: observed intensities, read at index get_global_id(0).
* \param dcalc: calculated complex amplitudes, rescaled in place by ApplyAmplitude_Real().
* \param mask: read at index (get_global_id(0) % NXY); the pixel is only processed when the entry is 0.
*/
__kernel
void ApplyAmplitudeMask(__global float *iobs, __global float2* dcalc, __global char* mask)
{
  const unsigned long pixel = get_global_id(0);

  // Any non-zero mask entry leaves the calculated amplitudes untouched
  if(mask[pixel % NXY] != 0) return;

  ApplyAmplitude_Real(iobs[pixel], dcalc);
}

/** Apply the observed amplitude to all modes, skipping masked pixels and
* subtracting a background from the observed intensity.
*
* \param iobs: observed intensities, read at index get_global_id(0).
* \param dcalc: calculated complex amplitudes, rescaled in place by ApplyAmplitude_Real().
* \param mask: read at index (get_global_id(0) % NXY); the pixel is only processed when the entry is 0.
* \param background: read at index (get_global_id(0) % NXY) and subtracted from the observed intensity.
*/
__kernel
void ApplyAmplitudeMaskBackground(__global float *iobs, __global float2* dcalc, __global char* mask, __global float *background)
{
  const unsigned long pixel = get_global_id(0);
  const unsigned long inframe = pixel % NXY;

  // Any non-zero mask entry leaves the calculated amplitudes untouched
  if(mask[inframe] != 0) return;

  ApplyAmplitude_Real(iobs[pixel] - background[inframe], dcalc);
}

/** Apply the observed amplitude to all modes, after subtracting a background
* from the observed intensity.
*
* \param iobs: observed intensities, read at index get_global_id(0).
* \param dcalc: calculated complex amplitudes, rescaled in place by ApplyAmplitude_Real().
* \param background: read at index (get_global_id(0) % NXY) and subtracted from the observed intensity.
*/
__kernel
void ApplyAmplitudeBackground(__global float *iobs, __global float2* dcalc, __global float *background)
{
  const unsigned long pixel = get_global_id(0);
  const float corrected = iobs[pixel] - background[pixel % NXY];
  ApplyAmplitude_Real(corrected, dcalc);
}


/** Compute Psi = object * probe for a stack of frames, for all object and probe modes.
*
* One work-item handles the probe pixel get_global_id(0) = prx + NX*pry.
*
* \param obj: object modes (centered array), mode m starting at offset m*NXYO.
* \param probe: probe modes (centered array), mode m starting at offset m*NXY.
* \param psi: output (fft-shifted). The frame i of (object mode, probe mode) is the layer
*   i + BLOCKSIZE*(iprobe + iobjmode*NBPROBE), each layer holding NXY values.
* \param cxy: frame shifts in the object: x shifts at [0..BLOCKSIZE-1], y shifts at [BLOCKSIZE..2*BLOCKSIZE-1].
* \param npsi: number of valid frames in the stack; layers npsi..BLOCKSIZE-1 are set to zero.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void ObjectProbeMult(__global float2 *obj, __global float2* probe, __global float2* psi, __global int* cxy, const int npsi)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);

  // Column and row of this work-item in the probe array
  const int prx = gid % NX;
  const int pry = gid / NX;

  // Each work-item of the group loads the shift of one frame into local memory
  __local int shiftx[BLOCKSIZE];
  __local int shifty[BLOCKSIZE];
  shiftx[lid] = cxy[lid];
  shifty[lid] = cxy[lid + BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // obj and probe are centered arrays, Psi is fft-shifted: swap the two halves along each axis.
  // Assumes NX and NY are multiples of 2
  const int iy = (pry < (NY/2)) ? (pry - NY/2 + NY) : (pry - NY/2);
  const int ix = (prx < (NX/2)) ? (prx - NX/2 + NX) : (prx - NX/2);
  const int ipsi = ix + iy * NX;

  float2 objval[BLOCKSIZE];
  for(int iobjmode = 0; iobjmode < NBOBJ; iobjmode++)
  {
    // Object values seen by this probe pixel, for every valid frame
    for(int i = 0; i < npsi; i++)
    {
      const int iobj = shiftx[i] + prx + NXO*(shifty[i] + pry);
      objval[i] = obj[iobj + iobjmode*NXYO];
    }

    for(int iprobe = 0; iprobe < NBPROBE; iprobe++)
    {
      const float2 p = probe[gid + iprobe*NXY];
      // First psi layer for this (object mode, probe mode) pair
      const int layer0 = BLOCKSIZE * (iprobe + iobjmode * NBPROBE);

      // Valid frames: complex product object * probe
      for(int i = 0; i < npsi; i++)
      {
        const float2 ob = objval[i];
        psi[ipsi + (i + layer0) * NXY] = (float2)(ob.x*p.x - ob.y*p.y , ob.x*p.y + ob.y*p.x);
      }

      // Remaining (dummy) frames of the stack are zeroed
      for(int i = npsi; i < BLOCKSIZE; i++)
      {
        psi[ipsi + (i + layer0) * NXY] = (float2)0;
      }
    }
  }
}

/** Compute Psi = object * probe * exp(i*f*(x^2+y^2)) for a stack of frames, for all object and probe modes.
*
* One work-item handles the probe pixel get_global_id(0) = prx + NX*pry, with
* x = (prx - NX/2)*px and y = (pry - NY/2)*py.
*
* \param obj: object modes (centered array), mode m starting at offset m*NXYO.
* \param probe: probe modes (centered array), mode m starting at offset m*NXY.
* \param psi: output (fft-shifted). The frame i of (object mode, probe mode) is the layer
*   i + BLOCKSIZE*(iprobe + iobjmode*NBPROBE), each layer holding NXY values.
* \param cxy: frame shifts in the object: x shifts at [0..BLOCKSIZE-1], y shifts at [BLOCKSIZE..2*BLOCKSIZE-1].
* \param px, py: scale factors applied to the centered pixel coordinates.
* \param f: factor multiplying (x^2+y^2) to get the phase.
* \param npsi: number of valid frames in the stack; layers npsi..BLOCKSIZE-1 are set to zero.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void ObjectProbeMultQuadPhase(__global float2 *obj, __global float2* probe, __global float2* psi,
                              __global int* cxy, const float px, const float py, const float f, const int npsi)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);

  // Column and row of this work-item in the probe array
  const int prx = gid % NX;
  const int pry = gid / NX;

  // Each work-item of the group loads the shift of one frame into local memory
  __local int shiftx[BLOCKSIZE];
  __local int shifty[BLOCKSIZE];
  shiftx[lid] = cxy[lid];
  shifty[lid] = cxy[lid + BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // obj and probe are centered arrays, Psi is fft-shifted: swap the two halves along each axis.
  // Assumes NX and NY are multiples of 2
  const int iy = (pry < (NY/2)) ? (pry - NY/2 + NY) : (pry - NY/2);
  const int ix = (prx < (NX/2)) ? (prx - NX/2 + NX) : (prx - NX/2);
  const int ipsi = ix + iy * NX;

  // Quadratic phase factor, applied before far field propagation
  const float y = (pry - NY/2) * py;
  const float x = (prx - NX/2) * px;
  const float phase = f*(x*x+y*y);
  // WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation),
  // native sin and cos may be wrong.
  const float sinphase = native_sin(phase);
  const float cosphase = native_cos(phase);

  float2 objval[BLOCKSIZE];
  for(int iobjmode = 0; iobjmode < NBOBJ; iobjmode++)
  {
    // Object values seen by this probe pixel, for every valid frame
    for(int i = 0; i < npsi; i++)
    {
      const int iobj = shiftx[i] + prx + NXO*(shifty[i] + pry);
      objval[i] = obj[iobj + iobjmode*NXYO];
    }

    for(int iprobe = 0; iprobe < NBPROBE; iprobe++)
    {
      const float2 p = probe[gid + iprobe*NXY];
      // First psi layer for this (object mode, probe mode) pair
      const int layer0 = BLOCKSIZE * (iprobe + iobjmode * NBPROBE);

      // Valid frames: complex product object * probe, then multiplication by the phase factor
      for(int i = 0; i < npsi; i++)
      {
        const float2 ob = objval[i];
        const float2 prod = (float2)(ob.x*p.x - ob.y*p.y , ob.x*p.y + ob.y*p.x);
        psi[ipsi + (i + layer0) * NXY] = (float2)(prod.x*cosphase - prod.y*sinphase , prod.y*cosphase + prod.x*sinphase);
      }

      // Dummy frames at the end of the stack (to have a multiple of 16) must be zeroed,
      // or Chi2 would be incorrect
      for(int i = npsi; i < BLOCKSIZE; i++)
      {
        psi[ipsi + (i + layer0) * NXY] = (float2)0;
      }
    }
  }
}

/** Accumulate the object update from a stack of frames.
*
* For each valid frame i and each object mode, sum over the probe modes of conj(probe)*psi
* and store/accumulate it in the layer (i + iobjmode*BLOCKSIZE) of objN (NXYO values per layer).
* The sum over the probe modes of |probe|^2 is stored/accumulated in the layer i of objnormN.
*
* One work-item handles the probe pixel get_global_id(0) = prx + NX*pry.
*
* Note that firstpass is not sufficient to erase all previous content, as the entire object is
* probably not covered by the first set of 16 frames. firstpass however avoids a few memory transfers.
*
* \param objN: stack of per-frame object updates (output).
* \param probe: probe modes (centered array), mode m starting at offset m*NXY.
* \param psi: input (fft-shifted), frame i of (object mode, probe mode) in the layer
*   i + BLOCKSIZE*(iprobe + iobjmode*NBPROBE).
* \param objnormN: stack of per-frame object normalizations (output).
* \param cxy: frame shifts in the object: x shifts at [0..BLOCKSIZE-1], y shifts at [BLOCKSIZE..2*BLOCKSIZE-1].
* \param npsi: number of valid frames in the stack.
* \param firstpass: if non-zero, the outputs are overwritten instead of accumulated.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void UpdateObj(__global float2 *objN, __global float2* probe, __global float2* psi, __global float* objnormN,
               __global int* cxy, const int npsi, const char firstpass)
{
  // The global id is prx + NX*pry, so the row is obtained by dividing by the row length NX
  // (dividing by NY is only correct for square arrays).
  const int prx = get_global_id(0) % NX;
  const int pry = get_global_id(0) / NX;
  const int lid = get_local_id(0);

  // obj and probe are centered arrays, Psi is fft-shifted
  // Each work-item of the group loads the shift of one frame into local memory
  __local int cx[BLOCKSIZE];
  __local int cy[BLOCKSIZE];
  cx[lid]=cxy[lid];
  cy[lid]=cxy[lid+BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - NY/2 + NY * (pry<(NY/2));
  const int ix = prx - NX/2 + NX * (prx<(NX/2));
  const int ipsi  = ix + iy * NX ;

  // Probe values of all modes at this pixel, and their summed squared modulus
  float2 pr[NBPROBE];
  float prn=0;

  for(int iprobe=0;iprobe<NBPROBE;iprobe++)
  {
    pr[iprobe] = probe[get_global_id(0) + iprobe*NXY];
    prn += dot(pr[iprobe],pr[iprobe]);
  }

  for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
  {
    for(int i=0;i<npsi;i++)
    {
      float2 o=0;
      // iobj0: position in the layer i of the stack; iobj: same position for this object mode
      const int iobj0 = cx[i]+prx + NXO*(cy[i]+pry) + i * NXYO;
      const int iobj  = iobj0 + iobjmode*BLOCKSIZE * NXYO;
      for(int iprobe=0;iprobe<NBPROBE;iprobe++)
      {
        const float2 ps=psi[ipsi + (i + BLOCKSIZE * (iprobe + iobjmode * NBPROBE) ) * NXY];
        // o += conj(probe) * psi
        o += (float2) (pr[iprobe].x*ps.x + pr[iprobe].y*ps.y , pr[iprobe].x*ps.y - pr[iprobe].y*ps.x);
      }
      if(firstpass)
      {
        objN[iobj] = o ;
        if(iobjmode==0) objnormN[iobj0] = prn ; // All the object modes have the same probe normalization. TODO: move to outside loop ?
      }
      else
      {
        objN[iobj] += o ;
        if(iobjmode==0) objnormN[iobj0] += prn ; // All the object modes have the same probe normalization. TODO: move to outside loop ?
      }
    }
  }
}

/** Accumulate the object update from a stack of frames, multiplying psi by a quadratic phase factor.
*
* For each valid frame i and each object mode, sum over the probe modes of
* conj(probe) * psi * exp(i*f*(x^2+y^2)), stored/accumulated in the layer (i + iobjmode*BLOCKSIZE)
* of objN (NXYO values per layer). The sum over the probe modes of |probe|^2 is stored/accumulated
* in the layer i of objnormN.
*
* One work-item handles the probe pixel get_global_id(0) = prx + NX*pry, with
* x = (prx - NX/2)*px and y = (pry - NY/2)*py.
*
* \param objN: stack of per-frame object updates (output).
* \param probe: probe modes (centered array), mode m starting at offset m*NXY.
* \param psi: input (fft-shifted), frame i of (object mode, probe mode) in the layer
*   i + BLOCKSIZE*(iprobe + iobjmode*NBPROBE).
* \param objnormN: stack of per-frame object normalizations (output).
* \param cxy: frame shifts in the object: x shifts at [0..BLOCKSIZE-1], y shifts at [BLOCKSIZE..2*BLOCKSIZE-1].
* \param px, py: scale factors applied to the centered pixel coordinates.
* \param f: factor multiplying (x^2+y^2) to get the phase.
* \param npsi: number of valid frames in the stack.
* \param firstpass: if non-zero, the outputs are overwritten instead of accumulated.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void UpdateObjQuadPhase(__global float2 *objN, __global float2* probe, __global float2* psi, __global float* objnormN,
                        __global int* cxy, const float px, const float py, const float f, const int npsi, const char firstpass)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);

  // Column and row of this work-item in the probe array
  const int prx = gid % NX;
  const int pry = gid / NX;

  // Each work-item of the group loads the shift of one frame into local memory
  __local int shiftx[BLOCKSIZE];
  __local int shifty[BLOCKSIZE];
  shiftx[lid] = cxy[lid];
  shifty[lid] = cxy[lid + BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // obj and probe are centered arrays, Psi is fft-shifted: swap the two halves along each axis.
  // Assumes NX and NY are multiples of 2
  const int iy = (pry < (NY/2)) ? (pry - NY/2 + NY) : (pry - NY/2);
  const int ix = (prx < (NX/2)) ? (prx - NX/2 + NX) : (prx - NX/2);
  const int ipsi = ix + iy * NX;

  // Probe values of all modes at this pixel, and their summed squared modulus
  float2 prval[NBPROBE];
  float prnorm = 0;
  for(int iprobe = 0; iprobe < NBPROBE; iprobe++)
  {
    const float2 p = probe[gid + iprobe*NXY];
    prval[iprobe] = p;
    prnorm += dot(p, p);
  }

  // Quadratic phase factor, applied after far field propagation
  const float y = (pry - NY/2) * py;
  const float x = (prx - NX/2) * px;
  const float phase = f*(x*x+y*y);
  // WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation),
  // native sin and cos may be wrong.
  const float sinphase = native_sin(phase);
  const float cosphase = native_cos(phase);

  for(int iobjmode = 0; iobjmode < NBOBJ; iobjmode++)
  {
    for(int i = 0; i < npsi; i++)
    {
      // inorm: position in the layer i of the stack; iobj: same position for this object mode
      const int inorm = shiftx[i] + prx + NXO*(shifty[i] + pry) + i * NXYO;
      const int iobj = inorm + iobjmode*BLOCKSIZE * NXYO;

      // Sum over the probe modes of conj(probe) * (psi * phase factor)
      float2 acc = 0;
      for(int iprobe = 0; iprobe < NBPROBE; iprobe++)
      {
        const float2 raw = psi[ipsi + (i + BLOCKSIZE * (iprobe + iobjmode * NBPROBE) ) * NXY];
        const float2 ps = (float2)(raw.x*cosphase - raw.y*sinphase , raw.y*cosphase + raw.x*sinphase);
        const float2 p = prval[iprobe];
        acc += (float2) (p.x*ps.x + p.y*ps.y , p.x*ps.y - p.y*ps.x);
      }

      if(!firstpass)
      {
        objN[iobj] += acc;
      }
      else
      {
        objN[iobj] = acc;
      }

      // All the object modes have the same probe normalization, so it is only written for the first one
      if(iobjmode == 0)
      {
        if(!firstpass) objnormN[inorm] += prnorm;
        else objnormN[inorm] = prnorm;
      }
    }
  }
}


/** Accumulate the probe update from a stack of frames.
*
* For each probe mode, sum over the valid frames and the object modes of conj(object)*psi,
* stored/accumulated in probe. The sum over the valid frames and the object modes of |object|^2
* is stored/accumulated in probenorm.
*
* One work-item handles the probe pixel get_global_id(0) = prx + NX*pry.
*
* \param obj: object modes (centered array), mode m starting at offset m*NXYO.
* \param probe: updated probe modes (output), mode m starting at offset m*NXY.
* \param psi: input (fft-shifted), frame i of (object mode, probe mode) in the layer
*   i + BLOCKSIZE*(iprobemode + iobjmode*NBPROBE).
* \param probenorm: probe normalization (output), one value per probe pixel.
* \param cxy: frame shifts in the object: x shifts at [0..BLOCKSIZE-1], y shifts at [BLOCKSIZE..2*BLOCKSIZE-1].
* \param npsi: number of valid frames in the stack.
* \param firstpass: if non-zero, the outputs are overwritten instead of accumulated.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void UpdateProbe(__global float2 *obj, __global float2* probe, __global float2* psi, __global float* probenorm,
                 __global int* cxy, const int npsi, const char firstpass)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);
  const int ipix = gid;

  // Column and row of this work-item in the probe array
  const int prx = gid % NX;
  const int pry = gid / NX;

  // Each work-item of the group loads the shift of one frame into local memory
  __local int shiftx[BLOCKSIZE];
  __local int shifty[BLOCKSIZE];
  shiftx[lid] = cxy[lid];
  shifty[lid] = cxy[lid + BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // obj and probe are centered arrays, Psi is fft-shifted: swap the two halves along each axis.
  // Assumes NX and NY are multiples of 2
  const int iy = (pry < (NY/2)) ? (pry - NY/2 + NY) : (pry - NY/2);
  const int ix = (prx < (NX/2)) ? (prx - NX/2 + NX) : (prx - NX/2);
  const int ipsi = ix + iy * NX;

  // Object values seen by this probe pixel (frame i, object mode m stored at i + m*BLOCKSIZE),
  // and their summed squared modulus
  float2 objval[NBOBJ*BLOCKSIZE];
  float objnorm = 0;
  for(int i = 0; i < npsi; i++)
  {
    const int iobj0 = shiftx[i] + prx + NXO*(shifty[i] + pry);
    for(int iobjmode = 0; iobjmode < NBOBJ; iobjmode++)
    {
      const float2 ob = obj[iobj0 + iobjmode * NXYO];
      objval[i + iobjmode*BLOCKSIZE] = ob;
      objnorm += dot(ob, ob);
    }
  }

  // All probe modes have the same normalization
  if(!firstpass) probenorm[ipix] += objnorm;
  else probenorm[ipix] = objnorm;

  for(int iprobemode = 0; iprobemode < NBPROBE; iprobemode++)
  {
    // Sum over frames and object modes of conj(object) * psi
    float2 acc = 0;
    for(int i = 0; i < npsi; i++)
    {
      for(int iobjmode = 0; iobjmode < NBOBJ; iobjmode++)
      {
        const float2 ob = objval[i + iobjmode*BLOCKSIZE];
        const float2 ps = psi[ipsi + (i + BLOCKSIZE * (iprobemode + iobjmode * NBPROBE) ) * NXY];

        acc += (float2) (ob.x*ps.x + ob.y*ps.y , ob.x*ps.y - ob.y*ps.x);
      }
    }

    const int iout = ipix + iprobemode * NXY;
    if(!firstpass) probe[iout] += acc;
    else probe[iout] = acc;
  }
}

/** Accumulate the probe update from a stack of frames, multiplying psi by a quadratic phase factor.
*
* For each probe mode, sum over the valid frames and the object modes of
* conj(object) * psi * exp(i*f*(x^2+y^2)), stored/accumulated in probe. The sum over the valid
* frames and the object modes of |object|^2 is stored/accumulated in probenorm.
*
* One work-item handles the probe pixel get_global_id(0) = prx + NX*pry, with
* x = (prx - NX/2)*px and y = (pry - NY/2)*py.
*
* \param obj: object modes (centered array), mode m starting at offset m*NXYO.
* \param probe: updated probe modes (output), mode m starting at offset m*NXY.
* \param psi: input (fft-shifted), frame i of (object mode, probe mode) in the layer
*   i + BLOCKSIZE*(iprobemode + iobjmode*NBPROBE).
* \param probenorm: probe normalization (output), one value per probe pixel.
* \param cxy: frame shifts in the object: x shifts at [0..BLOCKSIZE-1], y shifts at [BLOCKSIZE..2*BLOCKSIZE-1].
* \param px, py: scale factors applied to the centered pixel coordinates.
* \param f: factor multiplying (x^2+y^2) to get the phase.
* \param npsi: number of valid frames in the stack.
* \param firstpass: if non-zero, the outputs are overwritten instead of accumulated.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void UpdateProbeQuadPhase(__global float2 *obj, __global float2* probe, __global float2* psi, __global float* probenorm,
                          __global int* cxy, const float px, const float py, const float f, const int npsi, const char firstpass)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);
  const int ipix = gid;

  // Column and row of this work-item in the probe array
  const int prx = gid % NX;
  const int pry = gid / NX;

  // Each work-item of the group loads the shift of one frame into local memory
  __local int shiftx[BLOCKSIZE];
  __local int shifty[BLOCKSIZE];
  shiftx[lid] = cxy[lid];
  shifty[lid] = cxy[lid + BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // obj and probe are centered arrays, Psi is fft-shifted: swap the two halves along each axis.
  // Assumes NX and NY are multiples of 2
  const int iy = (pry < (NY/2)) ? (pry - NY/2 + NY) : (pry - NY/2);
  const int ix = (prx < (NX/2)) ? (prx - NX/2 + NX) : (prx - NX/2);
  const int ipsi = ix + iy * NX;

  // Object values seen by this probe pixel (frame i, object mode m stored at i + m*BLOCKSIZE),
  // and their summed squared modulus
  float2 objval[NBOBJ*BLOCKSIZE];
  float objnorm = 0;
  for(int i = 0; i < npsi; i++)
  {
    const int iobj0 = shiftx[i] + prx + NXO*(shifty[i] + pry);
    for(int iobjmode = 0; iobjmode < NBOBJ; iobjmode++)
    {
      const float2 ob = obj[iobj0 + iobjmode * NXYO];
      objval[i + iobjmode*BLOCKSIZE] = ob;
      objnorm += dot(ob, ob);
    }
  }

  // All probe modes have the same normalization
  if(!firstpass) probenorm[ipix] += objnorm;
  else probenorm[ipix] = objnorm;

  // Quadratic phase factor, applied after far field propagation
  const float y = (pry - NY/2) * py;
  const float x = (prx - NX/2) * px;
  const float phase = f*(x*x+y*y);
  // WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation),
  // native sin and cos may be wrong.
  const float sinphase = native_sin(phase);
  const float cosphase = native_cos(phase);

  for(int iprobemode = 0; iprobemode < NBPROBE; iprobemode++)
  {
    // Sum over frames and object modes of conj(object) * (psi * phase factor)
    float2 acc = 0;
    for(int i = 0; i < npsi; i++)
    {
      for(int iobjmode = 0; iobjmode < NBOBJ; iobjmode++)
      {
        const float2 ob = objval[i + iobjmode*BLOCKSIZE];
        const float2 raw = psi[ipsi + (i + BLOCKSIZE * (iprobemode + iobjmode * NBPROBE) ) * NXY];
        const float2 ps = (float2)(raw.x*cosphase - raw.y*sinphase , raw.y*cosphase + raw.x*sinphase);

        acc += (float2) (ob.x*ps.x + ob.y*ps.y , ob.x*ps.y - ob.y*ps.x);
      }
    }

    const int iout = ipix + iprobemode * NXY;
    if(!firstpass) probe[iout] += acc;
    else probe[iout] = acc;
  }
}


/** Sum & zero the stack of BLOCKSIZE updated objects
* (this must be done in this step to avoid memory access conflicts).
*
* One work-item handles the object pixel get_global_id(0). The stack depth is BLOCKSIZE,
* the same constant used to index the layers of objN.
*
* \param objN: stack of per-frame object updates, layer (i + iobjmode*BLOCKSIZE) holding NXYO values.
*   Summed over i for each object mode, then set to zero.
* \param objnormN: stack of per-frame normalizations, layer i holding NXYO values.
*   Summed over i, then set to zero.
* \param obj: output, summed object update, mode m starting at offset m*NXYO.
* \param objnorm: output, summed normalization (one value per object pixel).
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void SumNZero(__global float2 *objN, __global float* objnormN, __global float2 *obj, __global float* objnorm)
{
  const int iobj = get_global_id(0);

  // Collapse the normalization stack and reset it
  float n=0;
  for(int i=0;i<BLOCKSIZE;i++)
  {
    n += objnormN[iobj + i*NXYO];
    objnormN[iobj + i*NXYO] = 0;
  }
  objnorm[iobj]=n; // The same norm applies to all object modes

  // Collapse the object stack of each mode and reset it
  for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
  {
     float2 o=0;
     for(int i=0;i<BLOCKSIZE;i++)
     {
       const int ilayer = iobj + (i + iobjmode*BLOCKSIZE) * NXYO;
       o += objN[ilayer];
       objN[ilayer] = (float2)0;
     }
     obj[iobj + iobjmode*NXYO]=o;
  }
}

/** Sum the stack of BLOCKSIZE updated objects
* (this must be done in this step to avoid memory access conflicts).
*
* One work-item handles the object pixel get_global_id(0). The stack depth is BLOCKSIZE,
* the same constant used to index the layers of objN. The stacks are left unchanged.
*
* \param objN: stack of per-frame object updates, layer (i + iobjmode*BLOCKSIZE) holding NXYO values.
* \param objnormN: stack of per-frame normalizations, layer i holding NXYO values.
* \param obj: output, summed object update, mode m starting at offset m*NXYO.
* \param objnorm: output, summed normalization (one value per object pixel).
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void SumN(__global float2 *objN, __global float* objnormN, __global float2 *obj, __global float* objnorm)
{
  const int iobj = get_global_id(0);

  // Collapse the normalization stack
  float n=0;
  for(int i=0;i<BLOCKSIZE;i++)
  {
    n += objnormN[iobj + i*NXYO];
  }
  objnorm[iobj]=n; // The same norm applies to all object modes

  // Collapse the object stack of each mode
  for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
  {
     float2 o=0;
     for(int i=0;i<BLOCKSIZE;i++)
     {
       o += objN[iobj + (i + iobjmode*BLOCKSIZE) * NXYO];
     }
     obj[iobj + iobjmode*NXYO]=o;
  }
}

/** Normalize the object: obj = (obj_unnorm + reg * obj) / (objnorm + reg), for every object mode.
*
* The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
*
* One work-item handles the object pixel get_global_id(0).
*
* \param obj_unnorm: un-normalized object update, mode m starting at offset m*NXYO.
* \param objnorm: normalization, one value per object pixel (shared by all modes).
* \param obj: previous object on input, normalized object on output.
* \param reg: regularization factor.
*/
__kernel
void ObjNorm(__global float2 *obj_unnorm, __global float* objnorm, __global float2 *obj, const float reg)
{
  const int ipix = get_global_id(0);

  // The denominator is the same for all object modes
  const float denom = objnorm[ipix] + reg;

  for(int iobjmode = 0; iobjmode < NBOBJ; iobjmode++)
  {
    const int idx = ipix + iobjmode*NXYO;
    obj[idx] = (obj_unnorm[idx] + reg * obj[idx]) / denom;
  }
}

/** Normalize the probe: probe = (probe_unnorm + reg * probe) / (probenorm + reg), for every probe mode.
*
* One work-item handles the probe pixel get_global_id(0).
*
* \param probe_unnorm: un-normalized probe update, mode m starting at offset m*NXY.
* \param probenorm: normalization, one value per probe pixel (shared by all modes).
* \param probe: previous probe on input, normalized probe on output.
* \param reg: regularization factor.
*/
__kernel
void ProbeNorm(__global float2 *probe_unnorm, __global float* probenorm, __global float2 *probe, const float reg)
{
  const int ipix = get_global_id(0);

  // The denominator is the same for all probe modes
  const float denom = probenorm[ipix] + reg;

  for(int iprobemode = 0; iprobemode < NBPROBE; iprobemode++)
  {
    const int idx = ipix + iprobemode*NXY;
    probe[idx] = (probe_unnorm[idx] + reg * probe[idx]) / denom;
  }
}

/** Normalize the probe, taking into account a probe mask:
* probe = mask * (probe_unnorm + reg * probe) / (probenorm + reg), for every probe mode.
*
* One work-item handles the probe pixel get_global_id(0).
*
* \param probe_unnorm: un-normalized probe update, mode m starting at offset m*NXY.
* \param probenorm: normalization, one value per probe pixel (shared by all modes).
* \param probe: previous probe on input, normalized and masked probe on output.
* \param reg: regularization factor.
* \param mask: multiplicative mask, one value per probe pixel (shared by all modes).
*/
__kernel
void ProbeNormMask(__global float2 *probe_unnorm, __global float* probenorm, __global float2 *probe, const float reg, __global float* mask)
{
  const int ipix = get_global_id(0);

  // Mask weight and denominator are the same for all probe modes
  const float weight = mask[ipix];
  const float denom = probenorm[ipix] + reg;

  for(int iprobemode = 0; iprobemode < NBPROBE; iprobemode++)
  {
    const int idx = ipix + iprobemode*NXY;
    probe[idx] = weight * (probe_unnorm[idx] + reg * probe[idx]) / denom;
  }
}
