/** Multiply psi in place by g * (1 - Iobs / (Icalc + background)), for the gradient calculation with Poisson noise.
* Icalc is the sum over all modes of |psi|^2, and g is a separable Hann window used to dampen high frequencies.
* Masked pixels (Iobs < 0) are set to zero for all modes.
* \param i: index of the point in the observed intensity array, in [0, stack_size * ny * nx)
* \param iobs: the observed intensity array, shape=(stack_size, ny, nx). Negative values flag masked pixels.
* \param psi: the calculated complex amplitude, shape=(nb_obj, nb_probe, stack_size, ny, nx), modified in place
* \param background: the incoherent background, of shape (ny, nx)
* \param nbmode: number of modes = nb_probe * nb_obj
* \param nx: number of pixels along x in a single frame
* \param ny: number of pixels along y in a single frame
* \param nxystack: number of frames in stack multiplied by nx * ny (stride between consecutive modes in psi)
* \return: nothing - psi is updated in place
*/
void GradPoissonFourier(const int i, __global float *iobs, __global float2 *psi,
                        __global float *background, const int nbmode,
                        const int nx, const int ny, const int nxystack)
{
  const float obs = iobs[i];

  if(obs < 0)
  {
    // Masked pixel: zero all modes and stop. Falling through would multiply the zeroed psi by a factor
    // computed from a negative Iobs, which can be non-finite (0 * inf = NaN).
    for(int imode = 0; imode < nbmode; imode++)
    {
      psi[i + imode * nxystack] = (float2) (0, 0);
    }
    return;
  }

  // i spans the whole stack: reduce it to the position within its (ny, nx) frame before deriving
  // the (x, y) coordinates and indexing the per-frame background.
  const int ixy = i % (nx * ny);
  const int prx = ixy % nx;
  const int pry = ixy / nx;

  // Signed spatial frequency in cycles/pixel, in [-0.5, 0.5), using the FFT ordering (zero frequency at
  // index 0 of each axis, same convention as numpy.fft.fftfreq).
  // NOTE(review): assumes the iobs/psi frames are stored with the origin at (0,0) - confirm with the caller.
  const float qx = (float)(prx - nx * (prx >= (nx + 1) / 2)) / (float)nx;
  const float qy = (float)(pry - ny * (pry >= (ny + 1) / 2)) / (float)ny;
  // Separable Hann window cos^2(pi*qx) * cos^2(pi*qy): 1 at zero frequency, 0 at the Nyquist frequency
  const float g = pown(native_cos(M_PI_F * qx) * native_cos(M_PI_F * qy), 2);

  // Calculated intensity, summed incoherently over all modes
  float calc = 0;
  for(int imode = 0; imode < nbmode; imode++)
  {
    const float2 ps = psi[i + imode * nxystack];
    calc += dot(ps, ps);
  }

  // Floor avoids a division by zero where no intensity is calculated. TODO: KLUDGE ? 1e-12f is arbitrary
  calc = fmax(1e-12f, calc);

  const float f = g * (1 - obs / (calc + background[ixy]));

  for(int imode = 0; imode < nbmode; imode++)
  {
    psi[i + imode * nxystack] = f * psi[i + imode * nxystack];
  }
}


/** Elementwise kernel to compute the object gradient from psi. Almost the same as the kernel to compute the
* updated object projection, except that no normalization array is retained.
* This kernel computes the gradient contribution:
* - for a single probe position (to avoid memory conflicts),
* - for all object modes
* - for a given (ix,iy) coordinate in the object, and all iz values.
* - points not inside the object support (support <= 0) have a null gradient
*
* For each object mode and probe mode, psi is multiplied by the quadratic phase factor exp(i*f*(x^2+y^2)).
* Then conj(probe) * psi is subtracted from objgrad.
* The returned value is the conjugate of the gradient.
*
* \param i: index of the pixel in the (ny, nx) probe frame, with centered coordinates (origin at (nx/2, ny/2))
* \param psi: complex amplitude, indexed as (nb_obj, nb_probe, stack_size, ny, nx) with the frame origin at (0,0);
*             presumably already offset to the current probe position - TODO confirm with caller
* \param objgrad: object gradient, shape=(nbobj, nzo, nyo, nxo), accumulated in place
* \param probe: probe array, shape=(nbprobe, nz, ny, nx)
* \param support: object support, shape=(nzo, nyo, nxo); points with support <= 0 are skipped
* \param cx, cy, cz: position of the probe frame origin inside the object array
* \param pixel_size_x, pixel_size_y: pixel sizes used for the quadratic phase factor
* \param f: quadratic phase factor coefficient
* \param stack_size: number of frames in the psi stack
* \param nx, ny, nz: probe dimensions
* \param nxo, nyo, nzo: object dimensions
* \param nbobj, nbprobe: number of object and probe modes
*/
void GradObj(const int i, __global float2* psi, __global float2 *objgrad, __global float2* probe, __global char* support,
             const int cx,  const int cy, const int cz, const float pixel_size_x, const float pixel_size_y,
             const float f, const int stack_size, const int nx, const int ny, const int nz,
             const int nxo, const int nyo, const int nzo, const int nbobj, const int nbprobe)
{
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny ;
  const int nxyz = nxy * nz;
  const int nxyo = nxo * nyo ;
  const int nxyzo = nxyo * nzo;

  // Coordinates in Psi array, fft-shifted (origin at (0,0,0)). Assume nx ny nz are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Quadratic phase factor exp(i*f*(x^2+y^2)), using centered coordinates
  const float y = (pry - ny/2) * pixel_size_y;
  const float x = (prx - nx/2) * pixel_size_x;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    const int iobjxy = cx+prx + nxo*(cy+pry) + iobjmode * nxyzo;
    for(int iprobe=0;iprobe<nbprobe;iprobe++)
    {
      // TODO: avoid reading psi nb_mode times ?
      const float2 ps0 = psi[ipsi + stack_size * (iprobe + iobjmode * nbprobe) * nxy];
      // The phase factor does not depend on z: apply it exactly once per mode, outside the z loop.
      // Applying it inside the loop would rotate psi again for every in-support slice.
      const float2 ps = (float2)(ps0.x*c - ps0.y*s , ps0.y*c + ps0.x*s);
      for(int prz=0;prz<nz;prz++)
      {
        const int iobj = iobjxy + (cz + prz) * nxyo;
        if(support[iobj] > 0)
        {
          const float2 pr = probe[i + prz * nxy + iprobe * nxyz];
          // objgrad -= conj(pr) * ps
          objgrad[iobj] -= (float2) (pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x );
        }
      }
    }
  }
}
