// Calculate corrected Psi values (Psi * (1 - Iobs / Icalc)), before inverse FT for the calculation of gradient vs object and probe
// These kernels are called for NX*NY*NZ points
// Scale the calculated amplitudes of every mode at one point by f = 1 - iobs / icalc,
// where icalc = (sum over modes of |psi|^2) + background.
// The point handled is get_global_id(0); successive modes are NXYZ elements apart in psi.
//   psi:        complex amplitudes, updated in place
//   iobs:       observed intensity for this point
//   background: background intensity added to the calculated intensity
__kernel
void PsiCorr_Real(__global float2* psi, const float iobs, const float background)
{
  const unsigned long i2=get_global_id(0);

  // Private copy of each mode, so global memory is read only once per amplitude
  float2 ps[NBMODE];
  float ps2=0;
  for(unsigned int mode=0 ; mode<NBMODE ; mode++)
  {
    ps[mode]=psi[i2 + mode*NXYZ];
    ps2 += dot(ps[mode],ps[mode]);
  }

  // Floor the complete denominator (modes + background) rather than the mode intensity alone,
  // so that a background which cancels or exceeds the mode intensity cannot produce a zero or
  // negative divisor. This is the same guard as used in PsiCorrBackgroundGradient.
  const float icalc = fmax(ps2 + background,1e-12f); // TODO: KLUDGE ? 1e-12f is arbitrary
  const float f = 1 - iobs/ icalc;

  for(unsigned int mode=0 ; mode<NBMODE ; mode++)
  {
    psi[i2 + mode*NXYZ] = (float2) (f*ps[mode].x , f*ps[mode].y);
  }
}

// Psi correction without mask or background: each work-item corrects the point
// get_global_id(0) using its own observed intensity and a zero background.
__kernel
void PsiCorr(__global float2* psi, __global float* iobs)
{
  const float obs = iobs[get_global_id(0)];
  PsiCorr_Real(psi, obs, 0.0f);
}

// Psi correction with a mask and no background.
// The mask is indexed modulo NXY, i.e. the same NXY mask values are reused for every frame.
// A point is corrected only where the mask value is 0.
// NOTE(review): masked points keep their Psi values unchanged (they are not zeroed) - confirm
// this is what the subsequent gradient calculation expects.
__kernel
void PsiCorrMask(__global float2* psi, __global float* iobs, __global char* mask)
{
  const unsigned long idx = get_global_id(0);
  if(mask[idx % NXY] != 0) return;

  PsiCorr_Real(psi, iobs[idx], 0.0f);
}

// Psi correction with both a mask and a background.
// Mask and background are indexed modulo NXY, so one NXY array of each serves all frames.
// Points with a non-zero mask value are skipped and left untouched.
__kernel
void PsiCorrMaskBackground(__global float2* psi, __global float* iobs, __global char* mask, __global float* background)
{
  const unsigned long idx = get_global_id(0);
  const unsigned long ipix = idx % NXY;
  if(mask[ipix] != 0) return;

  PsiCorr_Real(psi, iobs[idx], background[ipix]);
}

// Psi correction with a background and no mask.
// The background is indexed modulo NXY, so one NXY array serves all frames.
__kernel
void PsiCorrBackground(__global float2* psi, __global float* iobs, __global float* background)
{
  const unsigned long idx = get_global_id(0);
  const float bkg = background[idx % NXY];
  PsiCorr_Real(psi, iobs[idx], bkg);
}

// Same as PsiCorr kernels, but also calculate the LLK gradient vs the background at the same time
// Note that these kernels are called for NX*NY points, so must loop over the stack of frames
// Correct Psi for one detector pixel over a stack of npsi frames, and accumulate the sum of the
// correction factors (1 - iobs/icalc) over the frames as the gradient versus the background.
//   psi:         complex amplitudes, updated in place (modes NXYZ apart, frames NXY apart)
//   iobs:        observed intensities (frames NXY apart)
//   background:  background value for each pixel
//   dbackground: output gradient versus background for each pixel
//   npsi:        number of frames to loop over
//   firstpass:   if non-zero the gradient is overwritten, otherwise it is added to
__kernel
void PsiCorrBackgroundGradient(__global float2* psi, __global float* iobs, __global float* background, __global float* dbackground,
                               const int npsi, const char firstpass)
{
  const unsigned long ipix = get_global_id(0);
  const float bkg = background[ipix];
  float grad_sum = 0.0f;

  for(int iframe=0; iframe<npsi; iframe++)
  {
    // Offset of this pixel in the current frame, shared by psi (first mode) and iobs
    const unsigned long i0 = ipix + iframe*NXY;

    float2 amp[NBMODE];
    float icalc = 0;
    for(unsigned int mode=0; mode<NBMODE; mode++)
    {
      amp[mode] = psi[i0 + mode*NXYZ];
      icalc += dot(amp[mode], amp[mode]);
    }

    // Calculated intensity including background, floored to avoid a division by zero
    icalc = fmax(icalc + bkg, 1e-12f); // TODO: KLUDGE ? 1e-12f is arbitrary
    const float scale = 1 - iobs[i0] / icalc;
    grad_sum += scale;

    for(unsigned int mode=0; mode<NBMODE; mode++)
    {
      psi[i0 + mode*NXYZ] = scale * amp[mode];
    }
  }

  // Overwrite on the first pass, accumulate on the following ones
  dbackground[ipix] = firstpass ? grad_sum : dbackground[ipix] + grad_sum;
}

// Masked version of PsiCorrBackgroundGradient: pixels with a zero mask value are processed
// normally, pixels with a non-zero mask value are skipped.
//   mask:      indexed modulo NXY
//   firstpass: if non-zero, dbackground is overwritten instead of accumulated
// For a masked pixel on the first pass, the gradient is explicitly set to zero. Without this the
// entry would never be written, and later accumulating passes would add to whatever the buffer
// held before.
// NOTE(review): Psi is left unchanged at masked pixels - confirm this is intended.
__kernel
void PsiCorrBackgroundGradientMask(__global float2* psi, __global float* iobs, __global float* background, __global float* dbackground,
                               __global char* mask, const int npsi, const char firstpass)
{
  const unsigned long i2 = get_global_id(0);
  if(mask[i2 % NXY] == 0)
  {
    PsiCorrBackgroundGradient(psi, iobs, background, dbackground, npsi, firstpass);
  }
  else if(firstpass)
  {
    dbackground[i2] = 0.0f;
  }
}


// Object gradient. This is actually the conjugate of the gradient (steepest descent).
// For every probe pixel, frame and object mode, computes sum over probe modes of conj(probe)*psi
// and stores (or accumulates) its negative in objgradN.
//   objgradN:  output, one NXYO slab per frame slot, BLOCKSIZE slabs per object mode
//   probe:     probe modes, NXY elements apart
//   psi:       corrected Psi, indexed (frame + BLOCKSIZE*(probe mode + object mode*NBPROBE))*NXY
//   cxy:       BLOCKSIZE x shifts followed by BLOCKSIZE y shifts of the probe in the object
//   npsi:      number of frames to process - presumably <= BLOCKSIZE, as cx/cy hold BLOCKSIZE entries
//   firstpass: if non-zero objgradN is overwritten, otherwise it is accumulated
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void ObjGrad(__global float2 *objgradN, __global float2 *probe, __global float2* psi,
             __global int* cxy, const int npsi, const char firstpass)
{
  const int prx = get_global_id(0) % NX;
  const int pry = get_global_id(0) / NX;
  const int lid = get_local_id(0);

  // obj and probe are centered arrays, Psi is fft-shifted
  __local int cx[BLOCKSIZE];
  __local int cy[BLOCKSIZE];
  // Strided load so that all BLOCKSIZE shifts are cached, and none written out of bounds,
  // whatever the relation between BLOCKSIZE and the work-group size.
  for(int j=lid; j<BLOCKSIZE; j+=(int)get_local_size(0))
  {
    cx[j]=cxy[j];
    cy[j]=cxy[j+BLOCKSIZE];
  }

  barrier(CLK_LOCAL_MEM_FENCE);

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - NY/2 + NY * (pry<(NY/2));
  const int ix = prx - NX/2 + NX * (prx<(NX/2));
  const int ipsi  = ix + iy * NX ;

  // Private copy of the probe modes at this pixel
  float2 pr[NBPROBE];

  for(int iprobe=0;iprobe<NBPROBE;iprobe++)
  {
    pr[iprobe] = probe[get_global_id(0) + iprobe*NXY];
  }

  for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
  {
    for(int i=0;i<npsi;i++)
    {
      float2 objgrad=0;
      // Each frame writes to its own slab (i * NXYO), so overlapping frames do not collide
      const int iobj0 = cx[i]+prx + NXO*(cy[i]+pry) + i * NXYO;
      const int iobj  = iobj0 + iobjmode*BLOCKSIZE * NXYO;
      for(int iprobe=0;iprobe<NBPROBE;iprobe++)
      {
        const float2 ps=psi[ipsi + (i + BLOCKSIZE * (iprobe + iobjmode * NBPROBE) ) * NXY];
        // conj(probe) * psi
        objgrad += (float2) (pr[iprobe].x*ps.x + pr[iprobe].y*ps.y ,  pr[iprobe].x*ps.y - pr[iprobe].y*ps.x );
      }
      if(firstpass)
        objgradN[iobj] = -objgrad ;
      else
        objgradN[iobj] -= objgrad ;
    }
  }
}

// Object gradient with a quadratic phase factor applied to Psi.
// This is actually the conjugate of the gradient (steepest descent).
// Same as ObjGrad, except that every Psi value is first multiplied by exp(i*f*(x^2+y^2)),
// with x = (prx - NX/2)*px and y = (pry - NY/2)*py.
//   objgradN:  output, one NXYO slab per frame slot, BLOCKSIZE slabs per object mode
//   probe:     probe modes, NXY elements apart
//   psi:       corrected Psi, indexed (frame + BLOCKSIZE*(probe mode + object mode*NBPROBE))*NXY
//   cxy:       BLOCKSIZE x shifts followed by BLOCKSIZE y shifts of the probe in the object
//   px, py, f: scale factors of the quadratic phase
//   npsi:      number of frames to process - presumably <= BLOCKSIZE, as cx/cy hold BLOCKSIZE entries
//   firstpass: if non-zero objgradN is overwritten, otherwise it is accumulated
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void ObjGradQuadPhase(__global float2 *objgradN, __global float2 *probe, __global float2* psi,
                      __global int* cxy, const float px, const float py, const float f, const int npsi, const char firstpass)
{
  const int prx = get_global_id(0) % NX;
  const int pry = get_global_id(0) / NX;
  const int lid = get_local_id(0);

  // obj and probe are centered arrays, Psi is fft-shifted
  __local int cx[BLOCKSIZE];
  __local int cy[BLOCKSIZE];
  // Strided load so that all BLOCKSIZE shifts are cached, and none written out of bounds,
  // whatever the relation between BLOCKSIZE and the work-group size.
  for(int j=lid; j<BLOCKSIZE; j+=(int)get_local_size(0))
  {
    cx[j]=cxy[j];
    cy[j]=cxy[j+BLOCKSIZE];
  }

  barrier(CLK_LOCAL_MEM_FENCE);

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - NY/2 + NY * (pry<(NY/2));
  const int ix = prx - NX/2 + NX * (prx<(NX/2));
  const int ipsi  = ix + iy * NX ;

  // Private copy of the probe modes at this pixel
  float2 pr[NBPROBE];

  for(int iprobe=0;iprobe<NBPROBE;iprobe++)
  {
    pr[iprobe] = probe[get_global_id(0) + iprobe*NXY];
  }

  // Apply Quadratic phase factor after far field propagation
  const float y = (pry - NY/2) * py;
  const float x = (prx - NX/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
  {
    for(int i=0;i<npsi;i++)
    {
      float2 objgrad=0;
      // Each frame writes to its own slab (i * NXYO), so overlapping frames do not collide
      const int iobj0 = cx[i]+prx + NXO*(cy[i]+pry) + i * NXYO;
      const int iobj  = iobj0 + iobjmode*BLOCKSIZE * NXYO;
      for(int iprobe=0;iprobe<NBPROBE;iprobe++)
      {
        float2 ps=psi[ipsi + (i + BLOCKSIZE * (iprobe + iobjmode * NBPROBE) ) * NXY];
        // psi * exp(i*tmp)
        ps=(float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
        // conj(probe) * psi
        objgrad += (float2) (pr[iprobe].x*ps.x + pr[iprobe].y*ps.y , pr[iprobe].x*ps.y - pr[iprobe].y*ps.x );
      }
      if(firstpass)
        objgradN[iobj] = -objgrad ;
      else
        objgradN[iobj] -= objgrad ;
    }
  }
}

// Probe gradient. This is actually the conjugate of the gradient (steepest descent).
// For every probe pixel and probe mode, computes the sum over frames and object modes of
// conj(obj)*psi and stores (or accumulates) its negative in probegrad.
//   probegrad: output, probe modes NXY elements apart
//   obj:       object modes, NXYO elements apart
//   psi:       corrected Psi, indexed (frame + BLOCKSIZE*(probe mode + object mode*NBPROBE))*NXY
//   cxy:       BLOCKSIZE x shifts followed by BLOCKSIZE y shifts of the probe in the object
//   npsi:      number of frames to process - presumably <= BLOCKSIZE, as cx/cy hold BLOCKSIZE entries
//   firstpass: if non-zero probegrad is overwritten, otherwise it is accumulated
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void ProbeGrad(__global float2 *probegrad, __global float2 *obj, __global float2* psi,
               __global int* cxy, const int npsi, const char firstpass)
{
  const int prx = get_global_id(0) % NX;
  const int pry = get_global_id(0) / NX;
  const int lid = get_local_id(0);
  const int iprobe=   get_global_id(0);

  // obj and probe are centered arrays, Psi is fft-shifted
  __local int cx[BLOCKSIZE];
  __local int cy[BLOCKSIZE];
  // Strided load so that all BLOCKSIZE shifts are cached, and none written out of bounds,
  // whatever the relation between BLOCKSIZE and the work-group size.
  for(int j=lid; j<BLOCKSIZE; j+=(int)get_local_size(0))
  {
    cx[j]=cxy[j];
    cy[j]=cxy[j+BLOCKSIZE];
  }

  barrier(CLK_LOCAL_MEM_FENCE);

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - NY/2 + NY * (pry<(NY/2));
  const int ix = prx - NX/2 + NX * (prx<(NX/2));
  const int ipsi  = ix + iy * NX ;

  // Private copy of the object values seen by this probe pixel, for each frame and object mode,
  // so that they are read only once and reused for all probe modes
  float2 o[NBOBJ*BLOCKSIZE];

  for(int i=0;i<npsi;i++)
  {
    const int iobj0 = cx[i] + prx + NXO*(cy[i] + pry);
    for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
    {
      const int ii = i+iobjmode*BLOCKSIZE;
      const int iobj  = iobj0 + iobjmode * NXYO;
      o[ii] = obj[iobj];
    }
  }

  for(int iprobemode=0;iprobemode<NBPROBE;iprobemode++)
  {
    float2 p=0;
    for(int i=0;i<npsi;i++)
    {
      for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
      {
        const int ii = i+iobjmode*BLOCKSIZE;
        const float2 ps = psi[ipsi + (i + BLOCKSIZE * (iprobemode + iobjmode * NBPROBE) ) * NXY];

        // conj(obj) * psi
        p += (float2) (o[ii].x*ps.x + o[ii].y*ps.y , o[ii].x*ps.y - o[ii].y*ps.x);
      }
    }
    if(firstpass)
      probegrad[iprobe + iprobemode * NXY] = -p ;
    else
      probegrad[iprobe + iprobemode * NXY] -= p ;
  }
}

// Probe gradient with a quadratic phase factor applied to Psi.
// This is actually the conjugate of the gradient (steepest descent).
// Same as ProbeGrad, except that every Psi value is first multiplied by exp(i*f*(x^2+y^2)),
// with x = (prx - NX/2)*px and y = (pry - NY/2)*py.
//   probegrad: output, probe modes NXY elements apart
//   obj:       object modes, NXYO elements apart
//   psi:       corrected Psi, indexed (frame + BLOCKSIZE*(probe mode + object mode*NBPROBE))*NXY
//   cxy:       BLOCKSIZE x shifts followed by BLOCKSIZE y shifts of the probe in the object
//   px, py, f: scale factors of the quadratic phase
//   npsi:      number of frames to process - presumably <= BLOCKSIZE, as cx/cy hold BLOCKSIZE entries
//   firstpass: if non-zero probegrad is overwritten, otherwise it is accumulated
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void ProbeGradQuadPhase(__global float2 *probegrad, __global float2 *obj, __global float2* psi,
                        __global int* cxy, const float px, const float py, const float f, const int npsi, const char firstpass)
{
  const int prx = get_global_id(0) % NX;
  const int pry = get_global_id(0) / NX;
  const int lid = get_local_id(0);
  const int iprobe=   get_global_id(0);

  // obj and probe are centered arrays, Psi is fft-shifted
  __local int cx[BLOCKSIZE];
  __local int cy[BLOCKSIZE];
  // Strided load so that all BLOCKSIZE shifts are cached, and none written out of bounds,
  // whatever the relation between BLOCKSIZE and the work-group size.
  for(int j=lid; j<BLOCKSIZE; j+=(int)get_local_size(0))
  {
    cx[j]=cxy[j];
    cy[j]=cxy[j+BLOCKSIZE];
  }

  barrier(CLK_LOCAL_MEM_FENCE);

  // Coordinates in Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - NY/2 + NY * (pry<(NY/2));
  const int ix = prx - NX/2 + NX * (prx<(NX/2));
  const int ipsi  = ix + iy * NX ;

  // Private copy of the object values seen by this probe pixel, for each frame and object mode,
  // so that they are read only once and reused for all probe modes
  float2 o[NBOBJ*BLOCKSIZE];

  for(int i=0;i<npsi;i++)
  {
    const int iobj0 = cx[i] + prx + NXO*(cy[i] + pry);
    for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
    {
      const int ii = i+iobjmode*BLOCKSIZE;
      const int iobj  = iobj0 + iobjmode * NXYO;
      o[ii] = obj[iobj];
    }
  }

  // Apply Quadratic phase factor after far field propagation
  const float y = (pry - NY/2) * py;
  const float x = (prx - NX/2) * px;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  for(int iprobemode=0;iprobemode<NBPROBE;iprobemode++)
  {
    float2 p=0;
    for(int i=0;i<npsi;i++)
    {
      for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
      {
        const int ii = i+iobjmode*BLOCKSIZE;
        float2 ps = psi[ipsi + (i + BLOCKSIZE * (iprobemode + iobjmode * NBPROBE) ) * NXY];
        // psi * exp(i*tmp)
        ps=(float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);

        // conj(obj) * psi
        p += (float2) (o[ii].x*ps.x + o[ii].y*ps.y , o[ii].x*ps.y - o[ii].y*ps.x);
      }
    }
    if(firstpass)
      probegrad[iprobe + iprobemode * NXY] = -p ;
    else
      probegrad[iprobe + iprobemode * NXY] -= p ;
  }
}


// Sum the per-frame object gradient slabs into the final object gradient.
// For each object mode, adds the BLOCKSIZE slabs (NXYO elements apart) of objgradN at this
// object pixel and writes the total to objgrad.
//   objgrad:  output, object modes NXYO elements apart
//   objgradN: input, BLOCKSIZE slabs of NXYO elements per object mode
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void SumObjGradN(__global float2 *objgrad, __global float2 *objgradN)
{
  const int iobj = get_global_id(0);

  for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
  {
     float2 o=0;
     // Loop over all BLOCKSIZE slabs of this mode: the bound must match the stride used to
     // address the modes, not a fixed count
     for(int i=0;i<BLOCKSIZE;i++)
     {
       o += objgradN[iobj + (i + iobjmode*BLOCKSIZE)  * NXYO];
     }
     objgrad[iobj + iobjmode*NXYO]=o;
  }
}


// Sum the per-frame object gradient slabs into the final object gradient, and reset the slabs.
// For each object mode, adds the BLOCKSIZE slabs (NXYO elements apart) of objgradN at this
// object pixel, writes the total to objgrad, and sets every slab value read to zero.
//   objgrad:  output, object modes NXYO elements apart
//   objgradN: input, BLOCKSIZE slabs of NXYO elements per object mode; zeroed on return
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void SumObjGradNZero(__global float2 *objgrad, __global float2 *objgradN)
{
  const int iobj = get_global_id(0);

  for(int iobjmode=0;iobjmode<NBOBJ;iobjmode++)
  {
     float2 o=0;
     // Loop over all BLOCKSIZE slabs of this mode: the bound must match the stride used to
     // address the modes, not a fixed count
     for(int i=0;i<BLOCKSIZE;i++)
     {
       const int islab = iobj + (i + iobjmode*BLOCKSIZE)  * NXYO;
       o += objgradN[islab] ;
       objgradN[islab] = (float2)0 ;
     }
     objgrad[iobj + iobjmode*NXYO]=o;
  }
}


// In-place linear combination of two complex arrays: A = a*A + b*B, element by element.
// a and b are real scalars applied to both the real and imaginary parts.
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void CG_linear_complex(const float a, __global float2 *A, const float b, __global float2 *B)
{
  const int idx = get_global_id(0);
  const float2 va = A[idx];
  const float2 vb = B[idx];
  A[idx] = a*va + b*vb;
}

// In-place linear combination of two real arrays: A = a*A + b*B, element by element.
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void CG_linear_float(const float a, __global float *A, const float b, __global float *B)
{
  const int idx = get_global_id(0);
  const float va = A[idx];
  const float vb = B[idx];
  A[idx] = a*va + b*vb;
}


// Add a regularisation gradient to dv, computed from the complex array v.
// For each point: dv += 8*alpha*v - 2*alpha*(sum of the existing left/right/up/down neighbours).
// The array is treated as a stack of NX*NY frames: the row index is taken modulo NY, so that
// neighbours are never taken across the border between two stacked frames (modes).
// NOTE(review): border points keep the 8*alpha centre weight although they have fewer than four
// neighbours - confirm this matches the regularisation term being differentiated.
__kernel
void RegGrad(__global float2 *dv, __global float2 *v, const float alpha)
{
  const int idx = get_global_id(0);
  const int col = idx % NX;
  const int row = (idx / NX) % NY; // row inside the current frame

  const float w_center = 8*alpha;
  const float w_neigh = -2*alpha;

  const float2 center = v[idx];
  float2 grad = (float2)(w_center*center.x, w_center*center.y);

  // Neighbours are added in the order: left, right, previous row, next row
  if(col>0)      grad += w_neigh * v[idx-1];
  if(col<(NX-1)) grad += w_neigh * v[idx+1];
  if(row>0)      grad += w_neigh * v[idx-NX];
  if(row<(NY-1)) grad += w_neigh * v[idx+NX];

  dv[idx] += grad;
}
