/** Fourier-space term of the gradient calculation with Poisson noise, computed in place in psi.
* For every mode, psi is replaced by:
*     psi * g * (1 - Iobs / (Icalc + background))
* where:
* - Icalc is the sum over all modes of |psi|^2, clamped to at least 1e-12;
* - g is a separable cos^2 window equal to 1 at the array origin (zero frequency, ix=iy=0) and
*   close to 0 around nx/2, ny/2, used to dampen high frequencies.
* Masked pixels (Iobs < 0) are set to zero for all modes.
* \param i: index of the point in the observed intensity array for which the calculation is made
* \param iobs: the observed intensity array, shape=(stack_size, ny, nx); negative values flag masked pixels
* \param psi: the calculated complex amplitude, shape=(nb_obj, nb_probe, stack_size, ny, nx), modified in place
* \param background: the incoherent background, of shape (ny, nx)
* \param nbmode: number of modes = nb_probe * nb_obj
* \param nx: number of pixels along x in a single frame
* \param ny: number of pixels along y in a single frame
* \param nxystack: number of frames in stack multiplied by nx*ny (stride between two modes in psi)
*/
void GradPoissonFourier(const int i, __global float *iobs, __global float2 *psi,
                        __global float *background, const int nbmode,
                        const int nx, const int ny, const int nxystack)
{
  const float obs = iobs[i];

  if(obs < 0)
  {
    // Masked pixel: zero all modes and stop here. Going on would multiply the zeroed psi by a factor
    // computed with calc=1e-12, which can overflow to inf and turn 0 into NaN.
    for(int imode=0; imode < nbmode; imode++)
    {
      psi[i + imode * nxystack] = (float2) (0, 0);
    }
    return;
  }

  // Window multiplication to dampen high-frequencies in the object: frequencies are folded
  // so that q=0 at the origin and |q| ~ pi/2 around nx/2 (resp. ny/2), where cos^2 ~ 0.
  const int ix = i % nx;
  const int iy = (i % (nx * ny)) / nx;
  const float qx = (float)(ix - nx * (ix >= (nx / 2))) * 3.14159265f / (float)(nx-1);
  const float qy = (float)(iy - ny * (iy >= (ny / 2))) * 3.14159265f / (float)(ny-1);
  const float g = pown(native_cos(qx) * native_cos(qy), 2);

  // Calculated intensity, summed incoherently over all modes
  float calc = 0;
  for(int imode=0; imode < nbmode; imode++)
  {
    const float2 ps = psi[i + imode * nxystack];
    calc += dot(ps, ps);
  }

  // TODO: KLUDGE ? 1e-12f is arbitrary - keeps the denominator away from 0 when all modes vanish
  // (assuming the background is non-negative)
  calc = fmax(1e-12f, calc);

  const float f = g * (1 - obs / (calc + background[i % (nx * ny)]));

  for(int imode=0; imode < nbmode; imode++)
  {
    // psi is read a second time rather than cached, as nbmode is not known at compile time
    const float2 ps = psi[i + imode * nxystack];
    psi[i + imode * nxystack] = (float2) (f * ps.x, f * ps.y);
  }
}


/** Elementwise kernel to compute the object gradient from psi. Almost the same as the kernel to compute the
* updated object projection, except that no normalization array is retained.
* This kernel computes the gradient contribution:
* - for a single probe position (to avoid memory conflicts),
* - for all object modes
* - for a given (ix,iy) coordinate in the object, and all iz values.
* - points not inside the object support have a null gradient
*
* The value accumulated (subtracted) into objgrad is the conjugate of the gradient, i.e. for each
* slice inside the support: objgrad -= conj(probe) * (psi * exp(i*phase)).
* \param i: index of the pixel in the 2D (ny, nx) frame; (cx+i%nx, cy+i/nx) is the matching object pixel
* \param psi: complex amplitude, read for frame 0 of each mode at index
*             ipsi + stack_size * (iprobe + iobjmode * nbprobe) * nx*ny (fft-shifted)
* \param objgrad: object gradient, shape (nbobj, nzo, nyo, nxo), accumulated in place
* \param probe: probe, shape (nbprobe, nz, ny, nx)
* \param support: object support, same layout as one object mode (only values > 0 are updated)
* \param cx, cy, cz: offset of the probe array in the object array
* NOTE(review): the units of dsx/dsy/dsz, px/py/pz and f are not visible here - presumably
* phase-ramp and pixel-size factors matching the forward kernel; confirm against it.
*/
void GradObj(const int i, __global float2* psi, __global float2 *objgrad, __global float2* probe, __global char* support,
             const int cx,  const int cy, const int cz, const float dsx, const float dsy, const float dsz,
             const float px, const float py, const float pz,
             const float f, const int stack_size, const int nx, const int ny, const int nz,
             const int nxo, const int nyo, const int nzo, const int nbobj, const int nbprobe)
{
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny ;
  const int nxyz = nxy * nz;
  const int nxyo = nxo * nyo ;
  const int nxyzo = nxyo * nzo;

  const int ixo = cx + prx;
  const int iyo = cy + pry;
  if((ixo<0) || (ixo>=nxo) || (iyo<0) || (iyo>=nyo)) return; // Outside object array ?

  // Coordinates in Psi array, fft-shifted (origin at (0,0,0)). Assume nx ny nz are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Quadratic phase factor (as applied before far field propagation in the forward direction)
  // NOTE(review): x is scaled by py and y by px, while z uses pz. This looks like a swap, but it must
  // stay consistent with the forward (object*probe -> psi) kernel, which is not visible here - confirm.
  const float y = (pry - ny/2) * px;
  const float x = (prx - nx/2) * py;
  const float tmp = f*(x*x+y*y) + dsx * x + dsy * y;

  // Range of probe slices which fall inside the object array along z (independent of the modes)
  const int przmin = max((int)0, -cz);
  const int przmax = min(nz, nzo - cz);

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    const int iobjxy = ixo + nxo * iyo + iobjmode * nxyzo;
    for(int iprobe=0;iprobe<nbprobe;iprobe++)
    {
      // TODO: avoid reading psi nb_mode times ?
      const float2 ps = psi[ipsi + stack_size * (iprobe + iobjmode * nbprobe) * nxy];

      for(int prz=przmin;prz<przmax;prz++)
      {
        const int iobj = iobjxy + (cz + prz) * nxyo;
        if(support[iobj] > 0)
        {
          const float2 pr = probe[i + prz * nxy + iprobe * nxyz];

          // TODO: check the phase factor signs
          // Add the phase factor with the quadratic and multi-angle terms for this slice.
          // The rotation is applied to the original psi value: the phase includes the absolute
          // slice position prz, so it must not accumulate from one slice to the next.
          const float tmp2 = tmp + dsz * pz * prz;
          const float s=native_sin(tmp2);
          const float c=native_cos(tmp2);

          const float2 psz = (float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
          // objgrad -= conj(pr) * psz
          objgrad[iobj] -= (float2) (pr.x*psz.x + pr.y*psz.y , pr.x*psz.y - pr.y*psz.x );
        }
      }
    }
  }
}

// Gaussian convolution of a 3D array (the object gradient) along its 3rd dimension, with periodic boundary
// conditions along z. This elementwise kernel should be called for each of the 2D pixels in the XY base
// plane (0 <= i < nxyo), and will apply to all modes and pixels along z, in place.
// grad is indexed as grad[i + nxyo * (iz + nzo * imode)]. sigma is expressed in number of z planes.
// The Gaussian is truncated to 15 taps (offsets -7..+7) and scaled by 1/(sigma*sqrt(2*pi)); the discrete
// weights are not renormalised to sum exactly to 1.
void GaussConvolveZ(const int i, __global float2 *grad, const float sigma,
                    const int nxyo, const int nzo, const int nbobj)
{
   if(nzo <= 0) return;

   // Kernel weights for offsets j=-7..7 (15 = 2*7+1), computed once rather than for every output plane
   const float norm = 1 / (sigma *native_sqrt(2*3.141592653589f));
   float w[15];
   for(int j=-7;j<=7;j++) w[j+7] = native_exp(-j*j/(2*sigma*sigma)) * norm;

   for(int imode=0;imode<nbobj;imode++)
   {
     const int base = i + nxyo * nzo * imode;  // index of plane z=0 for this mode

     // Ring buffer of ORIGINAL (unconvolved) values: when computing plane iz,
     // g[(iz+j+7)%15] holds plane (iz+j) mod nzo, for j=-7..7.
     float2 g[15];
     for(int j=-7;j<=7;j++) g[j+7] = grad[base + nxyo * (((j % nzo) + nzo) % nzo)];

     // Original values of planes 0..7: they are overwritten at the start of the z loop but are needed
     // again when the window wraps around the end of the z axis.
     float2 head[8];
     for(int k=0;k<8;k++) head[k] = grad[base + nxyo * (k % nzo)];

     for(int iz=0; iz<nzo;iz++)
     {
        float2 v=0;
        // % could be replaced by a AND (& (2^n - 1)) if the kernel was a power of two-sized
        for(int j=-7;j<=7;j++) v += g[(iz+j+7)%15] * w[j+7];

        // Slide the window: the slot of plane iz-7 (used above with j=-7) receives plane iz+8,
        // taken from head[] once it wraps, since those planes already hold convolved values.
        const int inext = iz + 8;
        if(inext < nzo) g[iz%15] = grad[base + nxyo * inext];
        else            g[iz%15] = head[inext % nzo];

        grad[base + nxyo * iz] = v;
     }
   }
}
