/** Square of a single-precision value (v*v), without going through powf. */
inline __device__ float pow2(const float v)
{
  const float squared = v * v;
  return squared;
}

/** Calculate Psi * (1 - Iobs / Icalc) * w(q), for the gradient calculation with Poisson noise.
* Psi is modified in place, for all nbmode modes of pixel i.
* Masked pixels (iobs < 0) are set to zero, as they do not contribute to the log-likelihood.
* w(q) = cos^2(pi*fx/nx) * cos^2(pi*fy/ny) is a periodic Hann window centered on the zero frequency,
* with (fx, fy) the signed frequency index of the pixel in [-n/2, n/2). It dampens high frequencies in the object.
* \param i: the point in the 3D observed intensity array for which the gradient term is calculated
* \param iobs: the observed intensity array, shape=(stack_size, ny, nx), same (fft-shifted, origin at (0,0)) layout as psi
* \param psi: the calculated complex amplitude, shape=(nb_obj, nb_probe, stack_size, ny, nx)
* \param background: the incoherent background, of shape (ny, nx)
* \param nbmode: number of modes = nb_probe * nb_obj
* \param nx, ny: size of a single frame (assumed to be even)
* \param nxy: number of pixels in a single frame
* \param nxystack: number of frames in stack multiplied by nxy
* \return: nothing - psi is updated in place
*/
__device__ void GradPoissonFourier(const int i, float *iobs, complexf *psi, float *background, const int nbmode, const int nx, const int ny, const int nxy, const int nxystack)
{
  const float obs= iobs[i];

  if(obs < 0)
  {
    // Masked pixel: no contribution to the log-likelihood, hence a zero gradient.
    // Leaving psi untouched would back-propagate the raw calculated amplitude.
    for(int imode=0; imode < nbmode; imode++) psi[i + imode * nxystack] = complexf(0.0f, 0.0f);
    return;
  }

  // Hann window multiplication to dampen high-frequencies in the object.
  // i spans the whole stack, so first reduce it to an index within the current frame.
  const int ipix = i % nxy;
  const int prx = ipix % nx;
  const int pry = ipix / nx;
  // Signed frequency index in [-n/2, n/2): psi/iobs have the zero frequency at index 0.
  const int fx = prx - nx * (prx >= (nx / 2));
  const int fy = pry - ny * (pry >= (ny / 2));
  const float qx = 3.14159265f * (float)fx / (float)nx;
  const float qy = 3.14159265f * (float)fy / (float)ny;
  // 1 at the zero frequency, 0 at the Nyquist frequency
  const float g = pow2(cosf(qx) * cosf(qy));

  // Total calculated intensity, summed over all modes
  float calc = 0;
  for(int imode=0;imode<nbmode;imode++) calc += dot(psi[i + imode * nxystack],psi[i + imode* nxystack]);

  calc = fmaxf(1e-12f,calc);  // TODO: KLUDGE ? 1e-12f is arbitrary

  const float f = g * (1 - obs/ (calc + background[ipix]));

  for(int imode=0; imode < nbmode; imode++)
  {
    // TODO: store psi to avoid double-read. Or just assume it's cached.
    const complexf ps = psi[i + imode * nxystack];
    psi[i + imode * nxystack] = complexf(f*ps.real() , f*ps.imag());
  }
}


/** Elementwise kernel to compute the object gradient from psi. Almost the same as the kernel to compute the
* updated object projection, except that no normalization array is retained.
* For pixel i of the (centered) probe frame, and for each object mode, this accumulates
* sum_over_probe_modes( conj(probe) * psi * exp(1j * f * (x^2 + y^2)) ) and subtracts it from objgrad.
* \param i: pixel index in a single (ny, nx) frame
* \param psi: fft-shifted complex amplitude; indexed as ipsi + stack_size*(iprobe + iobjmode*nbprobe)*nx*ny,
*             presumably already offset to the current frame by the caller - TODO confirm
* \param objgrad: object gradient, indexed with a stride of stack_size*nxo*nyo between object modes
*                 (presumably the per-frame slab of the stack summed later by SumGradN - confirm with caller)
* \param probe: centered probe, shape (nbprobe, ny, nx)
* \param cx, cy: position of the frame's corner in the object array
* \param px: pixel size used for the quadratic phase factor
* \param f: quadratic phase factor coefficient
*/
__device__ void GradObj(const int i, complexf* psi, complexf *objgrad, complexf* probe,
             const int cx,  const int cy, const float px, const float f,
             const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
             const int nbobj, const int nbprobe)
{
  // Centered pixel coordinates within the frame
  const int ux = i % nx;
  const int uy = i / nx;

  // Matching index in the fft-shifted psi frame (origin at (0,0)). Assumes nx, ny are even.
  const int sx = ux - nx/2 + ((ux < (nx/2)) ? nx : 0);
  const int sy = uy - ny/2 + ((uy < (ny/2)) ? ny : 0);
  const int ipsi = sx + sy * nx;

  // Quadratic phase factor after far field propagation.
  // NOTE WARNING: fast-math __sincosf is inaccurate for large arguments (e.g. > 2^15, implementation-dependent).
  const float yq = (uy - ny/2) * px;
  const float xq = (ux - nx/2) * px;
  float sn, cs;
  __sincosf(f*(xq*xq+yq*yq), &sn, &cs);

  const int nxy = nx * ny;
  const int nxyo = nxo * nyo;
  const int obj_pix = cx + ux + nxo * (cy + uy);

  for (int m = 0; m < nbobj; m++)
  {
    complexf acc = 0;
    for (int k = 0; k < nbprobe; k++)
    {
      const complexf pr = probe[i + k * nxy];
      const complexf z = psi[ipsi + stack_size * (k + m * nbprobe) * nxy];
      // z * (cos + i*sin)
      const float zr = z.real()*cs - z.imag()*sn;
      const float zi = z.imag()*cs + z.real()*sn;
      // conj(probe) * z
      acc += complexf(pr.real()*zr + pr.imag()*zi , pr.real()*zi - pr.imag()*zr);
    }
    objgrad[obj_pix + m * stack_size * nxyo] -= acc;
  }
}

/** Elementwise kernel to compute the probe gradient from psi. Almost the same as the kernel to compute the
* updated probe projection, except that no normalization array is retained.
* For pixel i of the (centered) probe frame, and for each probe mode, this accumulates over the npsi frames
* and all object modes: conj(obj) * psi * exp(1j * f * (x^2 + y^2)).
* If firstpass is non-zero, probegrad is overwritten with minus that sum, otherwise the sum is subtracted from it.
* \param psi: fft-shifted complex amplitude, shape (nbobj, nbprobe, stack_size, ny, nx)
* \param obj: centered object, shape (nbobj, nyo, nxo)
* \param cx, cy: per-frame corner positions in the object array (npsi entries are read)
* \param npsi: number of frames actually used in the stack (<= stack_size)
*/
__device__ void GradProbe(const int i, complexf* psi, complexf* probegrad, complexf *obj,
                          int* cx,  int* cy,
                          const float px, const float f, const char firstpass, const int npsi, const int stack_size,
                          const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe)
{
  // Centered pixel coordinates within the frame
  const int ux = i % nx;
  const int uy = i / nx;

  // obj and probe are centered arrays, Psi is fft-shifted.
  // Matching index in the psi frame (origin at (0,0)). Assumes nx, ny are even.
  const int sx = ux - nx/2 + ((ux < (nx/2)) ? nx : 0);
  const int sy = uy - ny/2 + ((uy < (ny/2)) ? ny : 0);
  const int ipsi = sx + sy * nx;

  // Quadratic phase factor after far field propagation.
  // NOTE WARNING: fast-math __sincosf is inaccurate for large arguments (e.g. > 2^15, implementation-dependent).
  const float yq = (uy - ny/2) * px;
  const float xq = (ux - nx/2) * px;
  float sn, cs;
  __sincosf(f*(xq*xq+yq*yq), &sn, &cs);

  const int nxy = nx * ny;
  const int nxyo = nxo * nyo;

  for (int k = 0; k < nbprobe; k++)
  {
    complexf acc = 0;
    for (int j = 0; j < npsi; j++)
    {
      const int obj_pix = cx[j] + ux + nxo * (cy[j] + uy);
      for (int m = 0; m < nbobj; m++)
      {
        const complexf z = psi[ipsi + (j + stack_size * (k + m * nbprobe)) * nxy];
        // z * (cos + i*sin)
        const float zr = z.real()*cs - z.imag()*sn;
        const float zi = z.imag()*cs + z.real()*sn;

        const complexf o = obj[obj_pix + m * nxyo];
        // conj(obj) * z
        acc += complexf(o.real()*zr + o.imag()*zi , o.real()*zi - o.imag()*zr);
      }
    }
    complexf &dst = probegrad[i + k * nxy];
    if (firstpass) dst = -acc;
    else dst -= acc;
  }
}


/** Sum the stack of N object gradient arrays (this must be done in this step to avoid memory access conflicts).
* objN is indexed as (nbobj, stack_size, nxyo); obj[i + iobjmode*nxyo] receives the sum over the stack
* (it is overwritten, not accumulated).
*/
__device__ void SumGradN(const int i, complexf *objN, complexf *obj, const int stack_size, const int nxyo, const int nbobj)
{
  for (int m = 0; m < nbobj; m++)
  {
    // First stacked gradient for this object mode and pixel
    const complexf *src = objN + i + m * stack_size * nxyo;
    complexf total = 0;
    for (int j = 0; j < stack_size; j++) total += src[j * nxyo];
    obj[i + m * nxyo] = total;
  }
}
