/** Kernels for the Difference Map Ptychography algorithm.
* Uses a stack of BLOCKSIZE(=16) images handled at the same time for more efficiency,
* and takes into account object and probe modes.
*/

/** Computes 2*Object*Probe - Psi for one stack of frames.
*
* Each work-item processes one probe pixel (its 1D global id) and loops over all
* object modes, probe modes and the frames of the stack. The result is written
* to psi; oldpsi supplies the Psi value that is subtracted.
*
* obj, probe: centered arrays; psi, oldpsi: fft-shifted arrays indexed as
*   pixel + NXY * (frame + BLOCKSIZE * (probe mode + NBPROBE * object mode)).
* cxy: BLOCKSIZE x-shifts followed by BLOCKSIZE y-shifts of the frames in the object.
* npsi: number of valid frames in the stack; the remaining frames (up to BLOCKSIZE)
*   are set to zero in psi.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void ObjectProbeMult_Psi(__global float2 *obj, __global float2* probe, __global float2* psi, __global float2* oldpsi,
                            __global int* cxy, const int npsi)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);

  // Column / row of this work-item's pixel in the centered probe array
  const int col = gid % NX;
  const int row = gid / NX;

  // Each work-item of the group copies one pair of frame shifts to local memory
  __local int shift_x[BLOCKSIZE];
  __local int shift_y[BLOCKSIZE];
  shift_x[lid] = cxy[lid];
  shift_y[lid] = cxy[lid+BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // Same pixel in the fft-shifted Psi array (origin at (0,0)); NX and NY are assumed even
  const int row_shifted = (row<(NY/2)) ? (row - NY/2 + NY) : (row - NY/2);
  const int col_shifted = (col<(NX/2)) ? (col - NX/2 + NX) : (col - NX/2);
  const int pixel = col_shifted + row_shifted * NX ;

  float2 obj_val[BLOCKSIZE];
  for(int mode_o=0; mode_o<NBOBJ; mode_o++)
  {
    // Gather the object values seen by this pixel for every valid frame
    for(int k=0; k<npsi; k++)
    {
      const int obj_pixel = shift_x[k]+col + NXO*(shift_y[k]+row);
      obj_val[k] = obj[obj_pixel + mode_o*NXYO];
    }

    for(int mode_p=0; mode_p<NBPROBE; mode_p++)
    {
      // Twice the probe value, so that the product below is directly 2*O*P
      const float2 p2 = 2 * probe[gid + mode_p*NXY];
      const int stack0 = pixel + BLOCKSIZE * (mode_p + mode_o * NBPROBE) * NXY;

      // Valid frames: 2*O*P - old Psi
      for(int k=0; k<npsi; k++)
      {
        const int frame = stack0 + k * NXY;
        const float2 ov = obj_val[k];
        psi[frame] = (float2)(ov.x*p2.x - ov.y*p2.y , ov.x*p2.y + ov.y*p2.x) - oldpsi[frame];
      }

      // Frames past npsi are zeroed
      for(int k=npsi; k<BLOCKSIZE; k++)
      {
        const int frame = stack0 + k * NXY;
        psi[frame] = (float2)0;
      }
    }
  }
}

/** Computes 2*Object*Probe*exp(i*phase) - Psi for one stack of frames, where
* phase = f*(x*x + y*y) is a quadratic phase (Psi is already multiplied by it).
*
* Each work-item processes one probe pixel (its 1D global id) and loops over all
* object modes, probe modes and the frames of the stack.
*
* obj, probe: centered arrays; psi, oldpsi: fft-shifted arrays indexed as
*   pixel + NXY * (frame + BLOCKSIZE * (probe mode + NBPROBE * object mode)).
* cxy: BLOCKSIZE x-shifts followed by BLOCKSIZE y-shifts of the frames in the object.
* px, py: scale factors converting the centered pixel indices to x and y
*   (presumably the pixel sizes - to be confirmed against the caller).
* f: quadratic phase coefficient (+pi/lambdaz here).
* npsi: number of valid frames in the stack; the remaining frames (up to BLOCKSIZE)
*   are set to zero in psi.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void ObjectProbeMult_PsiQuadPhase(__global float2 *obj, __global float2* probe, __global float2* psi, __global float2* oldpsi,
                                  __global int* cxy, const float px, const float py, const float f, const int npsi)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);

  // Column / row of this work-item's pixel in the centered probe array
  const int col = gid % NX;
  const int row = gid / NX;

  // Each work-item of the group copies one pair of frame shifts to local memory
  __local int shift_x[BLOCKSIZE];
  __local int shift_y[BLOCKSIZE];
  shift_x[lid] = cxy[lid];
  shift_y[lid] = cxy[lid+BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // Same pixel in the fft-shifted Psi array (origin at (0,0)); NX and NY are assumed even
  const int row_shifted = (row<(NY/2)) ? (row - NY/2 + NY) : (row - NY/2);
  const int col_shifted = (col<(NX/2)) ? (col - NX/2 + NX) : (col - NX/2);
  const int pixel = col_shifted + row_shifted * NX ;

  // Quadratic phase factor, applied before far field propagation
  const float xq = (col - NX/2) * px;
  const float yq = (row - NY/2) * py;
  const float phase = f*(xq*xq+yq*yq); // f is +pi/lambdaz here
  // WARNING: native sin/cos may be inaccurate for large arguments (e.g. > 2^15, implementation-dependent)
  const float sn=native_sin(phase);
  const float cs=native_cos(phase);

  float2 obj_val[BLOCKSIZE];
  for(int mode_o=0; mode_o<NBOBJ; mode_o++)
  {
    // Gather the object values seen by this pixel for every valid frame
    for(int k=0; k<npsi; k++)
    {
      const int obj_pixel = shift_x[k]+col + NXO*(shift_y[k]+row);
      obj_val[k] = obj[obj_pixel + mode_o*NXYO];
    }

    for(int mode_p=0; mode_p<NBPROBE; mode_p++)
    {
      // Twice the probe value, so that the product below is directly 2*O*P
      const float2 p2 = 2 * probe[gid + mode_p*NXY];
      const int stack0 = pixel + BLOCKSIZE * (mode_p + mode_o * NBPROBE) * NXY;

      // Valid frames: rotate 2*O*P by the quadratic phase, then subtract old Psi
      for(int k=0; k<npsi; k++)
      {
        const int frame = stack0 + k * NXY;
        const float2 ov = obj_val[k];
        float2 op2=(float2)(ov.x*p2.x - ov.y*p2.y , ov.x*p2.y + ov.y*p2.x);
        psi[frame] = (float2)(op2.x*cs - op2.y*sn , op2.y*cs + op2.x*sn) - oldpsi[frame];
      }

      // Dummy frames padding the stack to a multiple of 16 must be zero, or Chi2 would be incorrect
      for(int k=npsi; k<BLOCKSIZE; k++)
      {
        const int frame = stack0 + k * NXY;
        psi[frame] = (float2)0;
      }
    }
  }
}


/** Difference Map update of Psi for one stack of frames:
* Psi(n+1) = Psi(n) - P*O + Psi_calc, where Psi_calc is (2*P*O - Psi(n)) after
* applying the Fourier constraints.
*
* Each work-item processes one probe pixel (its 1D global id) and loops over all
* object modes, probe modes and the frames of the stack. psi is updated in place.
*
* obj, probe: centered arrays; psicalc, psi: fft-shifted arrays indexed as
*   pixel + NXY * (frame + BLOCKSIZE * (probe mode + NBPROBE * object mode)).
* cxy: BLOCKSIZE x-shifts followed by BLOCKSIZE y-shifts of the frames in the object.
* npsi: number of valid frames in the stack; the remaining frames (up to BLOCKSIZE)
*   are set to zero in psi.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void DM_UpdatePsi(__global float2 *obj, __global float2* probe, __global float2* psicalc, __global float2* psi,
                  __global int* cxy, const int npsi)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);

  // Column / row of this work-item's pixel in the centered probe array
  const int col = gid % NX;
  const int row = gid / NX;

  // Each work-item of the group copies one pair of frame shifts to local memory
  __local int shift_x[BLOCKSIZE];
  __local int shift_y[BLOCKSIZE];
  shift_x[lid] = cxy[lid];
  shift_y[lid] = cxy[lid+BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // Same pixel in the fft-shifted Psi array (origin at (0,0)); NX and NY are assumed even
  const int row_shifted = (row<(NY/2)) ? (row - NY/2 + NY) : (row - NY/2);
  const int col_shifted = (col<(NX/2)) ? (col - NX/2 + NX) : (col - NX/2);
  const int pixel = col_shifted + row_shifted * NX ;

  float2 obj_val[BLOCKSIZE];
  for(int mode_o=0; mode_o<NBOBJ; mode_o++)
  {
    // Gather the object values seen by this pixel for every valid frame
    for(int k=0; k<npsi; k++)
    {
      const int obj_pixel = shift_x[k]+col + NXO*(shift_y[k]+row);
      obj_val[k] = obj[obj_pixel + mode_o*NXYO];
    }

    for(int mode_p=0; mode_p<NBPROBE; mode_p++)
    {
      const float2 pv = probe[gid + mode_p*NXY];
      const int stack0 = pixel + BLOCKSIZE * (mode_p + mode_o * NBPROBE) * NXY;

      // Valid frames: Psi += Psi_calc - O*P
      for(int k=0; k<npsi; k++)
      {
        const int frame = stack0 + k * NXY;
        const float2 ov = obj_val[k];
        psi[frame] += psicalc[frame] - (float2)(ov.x*pv.x - ov.y*pv.y , ov.x*pv.y + ov.y*pv.x) ;
      }

      // Frames past npsi are zeroed
      for(int k=npsi; k<BLOCKSIZE; k++)
      {
        const int frame = stack0 + k * NXY;
        psi[frame] = (float2)0;
      }
    }
  }
}

/** Difference Map update of Psi for one stack of frames, with a quadratic phase:
* Psi(n+1) = Psi(n) - P*O*exp(i*phase) + Psi_calc, where phase = f*(x*x + y*y) and
* Psi_calc is (2*P*O - Psi(n)) after applying the Fourier constraints.
*
* Each work-item processes one probe pixel (its 1D global id) and loops over all
* object modes, probe modes and the frames of the stack. psi is updated in place.
*
* obj, probe: centered arrays; psicalc, psi: fft-shifted arrays indexed as
*   pixel + NXY * (frame + BLOCKSIZE * (probe mode + NBPROBE * object mode)).
* cxy: BLOCKSIZE x-shifts followed by BLOCKSIZE y-shifts of the frames in the object.
* px, py: scale factors converting the centered pixel indices to x and y
*   (presumably the pixel sizes - to be confirmed against the caller).
* f: quadratic phase coefficient (-pi/lambdaz here).
* npsi: number of valid frames in the stack; the remaining frames (up to BLOCKSIZE)
*   are set to zero in psi.
*/
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void DM_UpdatePsiQuadPhase(__global float2 *obj, __global float2* probe, __global float2* psicalc, __global float2* psi,
                           __global int* cxy, const float px, const float py, const float f, const int npsi)
{
  const size_t gid = get_global_id(0);
  const int lid = get_local_id(0);

  // Column / row of this work-item's pixel in the centered probe array
  const int col = gid % NX;
  const int row = gid / NX;

  // Each work-item of the group copies one pair of frame shifts to local memory
  __local int shift_x[BLOCKSIZE];
  __local int shift_y[BLOCKSIZE];
  shift_x[lid] = cxy[lid];
  shift_y[lid] = cxy[lid+BLOCKSIZE];

  barrier(CLK_LOCAL_MEM_FENCE);

  // Same pixel in the fft-shifted Psi array (origin at (0,0)); NX and NY are assumed even
  const int row_shifted = (row<(NY/2)) ? (row - NY/2 + NY) : (row - NY/2);
  const int col_shifted = (col<(NX/2)) ? (col - NX/2 + NX) : (col - NX/2);
  const int pixel = col_shifted + row_shifted * NX ;

  // Quadratic phase factor, applied before far field propagation
  const float xq = (col - NX/2) * px;
  const float yq = (row - NY/2) * py;
  const float phase = f*(xq*xq+yq*yq); // f is -pi/lambdaz here
  // WARNING: native sin/cos may be inaccurate for large arguments (e.g. > 2^15, implementation-dependent)
  const float sn=native_sin(phase);
  const float cs=native_cos(phase);

  float2 obj_val[BLOCKSIZE];
  for(int mode_o=0; mode_o<NBOBJ; mode_o++)
  {
    // Gather the object values seen by this pixel for every valid frame
    for(int k=0; k<npsi; k++)
    {
      const int obj_pixel = shift_x[k]+col + NXO*(shift_y[k]+row);
      obj_val[k] = obj[obj_pixel + mode_o*NXYO];
    }

    for(int mode_p=0; mode_p<NBPROBE; mode_p++)
    {
      const float2 pv = probe[gid + mode_p*NXY];
      const int stack0 = pixel + BLOCKSIZE * (mode_p + mode_o * NBPROBE) * NXY;

      // Valid frames: Psi += Psi_calc - O*P rotated by the quadratic phase
      for(int k=0; k<npsi; k++)
      {
        const int frame = stack0 + k * NXY;
        const float2 ov = obj_val[k];
        float2 prod=(float2)(ov.x*pv.x - ov.y*pv.y , ov.x*pv.y + ov.y*pv.x);
        psi[frame] += psicalc[frame] - (float2)(prod.x*cs - prod.y*sn , prod.y*cs + prod.x*sn);
      }

      // Dummy frames padding the stack to a multiple of 16 are zeroed
      for(int k=npsi; k<BLOCKSIZE; k++)
      {
        const int frame = stack0 + k * NXY;
        psi[frame] = (float2)0;
      }
    }
  }
}
