/** Scale Psi in place by g * (1 - Iobs / (Icalc + background)), for the gradient calculation with
* Poisson noise. Icalc is the sum over all modes of |Psi|^2 at this pixel, and g is a window factor
* computed from the pixel coordinates inside its frame. Psi is multiplied by a real factor only (it is
* not conjugated here).
* Masked pixels (negative observed intensity) are set to zero for every mode.
* \param i: the point in the 3D observed intensity array for which the gradient factor is applied
* \param iobs: the observed intensity array, shape=(stack_size, ny, nx). Negative values flag masked pixels.
* \param psi: the calculated complex amplitude, shape=(nb_obj, nb_probe, stack_size, ny, nx), updated in place
* \param background: the incoherent background, of shape (ny, nx)
* \param nbmode: number of modes = nb_probe * nb_obj
* \param nx: number of pixels along a row of a frame
* \param ny: number of rows in a frame
* \param nxy: number of pixels in a single frame
* \param nxystack: number of frames in stack multiplied by nxy
* \return: nothing, psi is modified in place
*/
void GradPoissonFourier(const int i, __global float *iobs, __global float2 *psi, __global float *background, const int nbmode, const int nx, const int ny, const int nxy, const int nxystack)
{
  const float obs= iobs[i];

  if(obs < 0)
  {
    // Masked pixel: it must not contribute to the gradient, so zero Psi for all modes
    for(int imode=0; imode < nbmode; imode++) psi[i + imode * nxystack] = (float2) (0 , 0);
    return;
  }

  // Use a Hann window multiplication to dampen high-frequencies in the object
  // i runs over the whole stack, so the row index must be taken inside the current frame
  const int prx = i % nx;
  const int pry = (i % nxy) / nx;
  // NOTE(review): for even nx, ny the values of qx and qy stay within [0, 1) and are used as radians
  // without a pi factor, so g never reaches zero. Confirm this is the intended window profile.
  const float qx = (float)(prx - nx/2 + nx * (prx<(nx/2))) / (float)nx;
  const float qy = (float)(pry - ny/2 + ny * (pry<(ny/2))) / (float)ny;
  const float g = pown(native_cos(qx) * native_cos(qy),2);

  // Calculated intensity: incoherent sum of |Psi|^2 over all modes
  float calc = 0;
  for(int imode=0;imode<nbmode;imode++) calc += dot(psi[i + imode * nxystack],psi[i + imode* nxystack]);

  // Lower bound to avoid a division by zero below
  calc = fmax(1e-12f,calc);  // TODO: KLUDGE ? 1e-12f is arbitrary

  // background holds a single frame, hence the index modulo nxy
  const float f = g * (1 - obs/ (calc + background[i%nxy]));

  for(int imode=0; imode < nbmode; imode++)
  {
    // TODO: store psi to avoid double-read. Or just assume it's cached.
    const float2 ps = psi[i + imode * nxystack];
    psi[i + imode * nxystack] = (float2) (f*ps.x , f*ps.y);
  }
}


/** Elementwise kernel accumulating the object gradient from psi.
* For every object mode, the sum over the probe modes of conj(probe) * psi * exp(i*f*(x^2+y^2)) is
* subtracted from objgrad at the object pixel facing probe pixel i. No normalization array is kept.
* \param i: pixel index in a (ny, nx) frame
* \param psi: complex amplitudes, read with a stride of stack_size*nx*ny between modes
*   (presumably the pointer is already offset to the frame being processed - confirm against caller)
* \param objgrad: stack of object gradient arrays, with a stride of stack_size*nxo*nyo between object modes
* \param probe: probe modes, with a stride of nx*ny between modes
* \param cx, cy: shift of the frame inside the object array, in pixels
* \param px: pixel size used for the quadratic phase coordinates
* \param f: coefficient of the quadratic phase factor
* \param nxo, nyo: size of the object array
* \param nbobj, nbprobe: number of object and probe modes
*/
void GradObj(const int i, __global float2* psi, __global float2 *objgrad, __global float2* probe,
             const int cx,  const int cy, const float px, const float f,
             const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
             const int nbobj, const int nbprobe)
{
  // Column and row of this pixel in the probe frame
  const int col = i % nx;
  const int row = i / nx;
  const int half_nx = nx/2;
  const int half_ny = ny/2;

  // Matching location in the Psi array, whose origin is at (0,0). nx and ny are assumed to be even.
  const int row_psi = row - half_ny + ((row < half_ny) ? ny : 0);
  const int col_psi = col - half_nx + ((col < half_nx) ? nx : 0);
  const int psi_pixel = col_psi + row_psi * nx;

  // Quadratic phase factor applied after far field propagation
  const float y = (row - half_ny) * px;
  const float x = (col - half_nx) * px;
  const float phase = f*(x*x+y*y);
  // WARNING: native_sin and native_cos may lose accuracy for large arguments
  // (e.g. > 2^15, depending on the implementation).
  const float sin_ph = native_sin(phase);
  const float cos_ph = native_cos(phase);

  // Object pixel facing this probe pixel, for the first object mode
  const int obj_pixel = cx + col + nxo * (cy + row);

  for(int mode_o = 0; mode_o < nbobj; mode_o++)
  {
    float2 acc = (float2)(0.0f, 0.0f);
    for(int mode_p = 0; mode_p < nbprobe; mode_p++)
    {
      const float2 pr = probe[i + mode_p * nx * ny];
      const float2 raw = psi[psi_pixel + stack_size * (mode_p + mode_o * nbprobe) * nx * ny];
      // raw * exp(i*phase)
      const float2 ps = (float2)(raw.x*cos_ph - raw.y*sin_ph , raw.y*cos_ph + raw.x*sin_ph);
      // conj(pr) * ps
      acc += (float2)(pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x);
    }
    objgrad[obj_pixel + mode_o * stack_size * nxo * nyo] -= acc;
  }
}

/** Elementwise kernel accumulating the probe gradient from psi.
* For every probe mode, the sum over the first npsi frames and over the object modes of
* conj(obj) * psi * exp(i*f*(x^2+y^2)) is computed; the probe gradient is set to minus that sum when
* firstpass is non-zero, and decreased by that sum otherwise. No normalization array is kept.
* obj and probe are centered arrays, Psi is fft-shifted.
* \param i: pixel index in a (ny, nx) probe frame
* \param psi: complex amplitudes, indexed as (object mode, probe mode, frame, y, x) with stack_size frames per mode
* \param probegrad: probe gradient, with a stride of nx*ny between probe modes
* \param obj: object modes, with a stride of nxo*nyo between modes
* \param cx, cy: per-frame shifts of the frame inside the object array, in pixels
* \param px: pixel size used for the quadratic phase coordinates
* \param f: coefficient of the quadratic phase factor
* \param firstpass: if non-zero, probegrad is overwritten instead of accumulated
* \param npsi: number of frames of the stack which are summed
* \param stack_size: number of frames stored per mode in psi
*/
void GradProbe(const int i, __global float2* psi, __global float2* probegrad, __global float2 *obj,
                          __global int* cx,  __global int* cy,
                          const float px, const float f, const char firstpass, const int npsi, const int stack_size,
                          const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe)
{
  // Column and row of this pixel in the probe frame
  const int col = i % nx;
  const int row = i / nx;
  const int half_nx = nx/2;
  const int half_ny = ny/2;

  // Matching location in the Psi array, whose origin is at (0,0). nx and ny are assumed to be even.
  const int row_psi = row - half_ny + ((row < half_ny) ? ny : 0);
  const int col_psi = col - half_nx + ((col < half_nx) ? nx : 0);
  const int psi_pixel = col_psi + row_psi * nx;

  // Quadratic phase factor applied after far field propagation
  const float y = (row - half_ny) * px;
  const float x = (col - half_nx) * px;
  const float phase = f*(x*x+y*y);
  // WARNING: native_sin and native_cos may lose accuracy for large arguments
  // (e.g. > 2^15, depending on the implementation).
  const float sin_ph = native_sin(phase);
  const float cos_ph = native_cos(phase);

  for(int mode_p = 0; mode_p < nbprobe; mode_p++)
  {
    float2 acc = (float2)(0.0f, 0.0f);
    for(int frame = 0; frame < npsi; frame++)
    {
      // Object pixel facing this probe pixel for this frame, first object mode
      const int obj_pixel = cx[frame] + col + nxo * (cy[frame] + row);
      for(int mode_o = 0; mode_o < nbobj; mode_o++)
      {
        const float2 raw = psi[psi_pixel + (frame + stack_size * (mode_p + mode_o * nbprobe) ) * nx * ny];
        // raw * exp(i*phase)
        const float2 ps = (float2)(raw.x*cos_ph - raw.y*sin_ph , raw.y*cos_ph + raw.x*sin_ph);

        const float2 o = obj[obj_pixel + mode_o * nxo * nyo];

        // conj(o) * ps
        acc += (float2)(o.x*ps.x + o.y*ps.y , o.x*ps.y - o.y*ps.x);
      }
    }

    const int out = i + mode_p * nx * ny;
    if(firstpass)
    {
      probegrad[out] = -acc;
    }
    else
    {
      probegrad[out] -= acc;
    }
  }
}


// Collapse the stack of N per-frame object gradient arrays into a single gradient per object mode.
// The reduction is done in this separate step to avoid memory access conflicts.
//   i: pixel index in one object array of nxyo pixels
//   objN: input stack, indexed as (object mode, frame, pixel) with stack_size frames per mode
//   obj: output, indexed as (object mode, pixel); overwritten with the sum over the frames
void SumGradN(const int i, __global float2 *objN, __global float2 *obj, const int stack_size, const int nxyo, const int nbobj)
{
  for(int mode = 0; mode < nbobj; mode++)
  {
    float2 total = (float2)(0.0f, 0.0f);
    // Index of this pixel in the first frame of the current mode, then step frame by frame
    int src = i + (mode * stack_size) * nxyo;
    for(int frame = 0; frame < stack_size; frame++)
    {
      total += objN[src];
      src += nxyo;
    }
    obj[i + mode * nxyo] = total;
  }
}
