/** Poisson-noise gradient helper: scales Psi in place by (1 - Iobs / Icalc) (times the output weights).
* The index i addresses one pixel of the first frame of the stack; the function itself loops over the
* npsi frames of the stack for that pixel, and over all modes.
* Pixels with a negative observed intensity are treated as masked: Psi is set to zero for all modes.
* \param i: pixel index in the first frame of the observed intensity stack
* \param iobs: the observed intensity array, shape=(stack_size, ny, nx)
* \param psi: the calculated complex amplitude, shape=(nb_obj, nb_probe, stack_size, ny, nx), overwritten in place
* \param background: the incoherent background, of shape (ny, nx)
* \param background_grad: the gradient for the incoherent background, of shape (ny, nx), always overwritten
* \param nbmode: number of modes = nb_probe * nb_obj
* \param nx, ny: frame size along x and y
* \param nxystack: number of frames in stack multiplied by nx * ny (used as the stride between modes in psi)
* \param npsi: number of frames over which the loop is performed
* \param hann_filter: if >0, the output is multiplied by a squared-cosine window depending on the pixel position
* \param scale_in: scale factor for the input psi; its square multiplies |psi|^2 when computing Icalc
* \param scale_out: scale factor applied to the output psi
* \param scale: per-frame weight, indexed by the frame number in the stack
* \return: nothing. psi and background_grad are modified.
*/
void GradPoissonFourier(const int i, __global float *iobs, __global float2 *psi, __global float *background,
                        __global float *background_grad, const int nbmode, const int nx, const int ny,
                        const int nxystack, const int npsi, const char hann_filter,
                        const float scale_in, const float scale_out, __global float* scale)
{
  // Pixel position inside a single 2D frame
  const int ipix = i % (nx * ny);
  const float bkg = background[ipix];
  const float scale_in2 = scale_in * scale_in;

  // Window factor, equal to 1 unless the Hann filter is requested
  float window=1;
  if(hann_filter>0)
  {
    // Use a Hann window multiplication to dampen high-frequencies in the object
    const int ix = i % nx;
    const int iy = ipix / nx;
    // Indices in the upper half of each axis are wrapped to negative values
    const float qx = (float)(ix - nx * (ix >= (nx / 2))) * 3.14159265f / (float)(nx-1);
    const float qy = (float)(iy - ny * (iy >= (ny / 2))) * 3.14159265f / (float)(ny-1);
    window = pown(native_cos(qx) * native_cos(qy), 2);
  }

  // Accumulated background gradient for this pixel, summed over frames
  float bkg_grad = 0.0f;

  for(int iframe=0; iframe<npsi; iframe++)
  {
    const int ipsi0 = i + nx * ny * iframe;
    const float obs= iobs[ipsi0];

    if(obs < 0)
    {
      // Masked pixel: no contribution to the gradient
      for(int imode=0; imode < nbmode; imode++)
      {
        psi[ipsi0 + imode * nxystack] = (float2) (0.0f,0.0f);
      }
      continue;
    }

    // Calculated intensity = background + scaled incoherent sum over the modes
    float icalc = bkg;
    for(int imode=0; imode < nbmode; imode++)
    {
      icalc += scale_in2 * dot(psi[ipsi0 + imode * nxystack],psi[ipsi0 + imode * nxystack]);
    }

    // TODO: KLUDGE ? 1e-12f is an arbitrary floor to avoid a division by zero
    const float ratio = 1 - obs / fmax(1e-12f, icalc);

    // is the multiplication by scale here adequate ? Effectively it's a weight
    bkg_grad += ratio * scale[iframe];

    // For the gradient multiply by scale and not sqrt(scale)
    const float factor = scale_out * window * ratio * scale[iframe];

    for(int imode=0; imode < nbmode; imode++)
    {
      // TODO: store psi to avoid double-read. Or just assume it's cached.
      const int ipsi = ipsi0 + imode * nxystack;
      const float2 ps = psi[ipsi];
      psi[ipsi] = (float2) (factor*ps.x , factor*ps.y);
    }
  }
  background_grad[ipix] = -bkg_grad;  // check sign
}

/** Elementwise kernel to compute the object gradient from psi, for a single frame. Almost the same as the
* kernel computing the updated object projection, except that no normalization array is retained.
*
* For the probe pixel i, the psi value of each mode is multiplied by exp(1j * f * (x^2 + y^2)), then by the
* conjugate of the probe. Contributions from successive modes which share the same object index and beam
* offsets are summed before being subtracted from objgrad through bilinear_atomic_add_c().
* \param i: pixel index in the (ny, nx) probe array
* \param psi: complex amplitudes, fft-shifted; the mode imode is read at an offset stack_size * imode * nx * ny
* \param objgrad: the object gradient array, updated with atomic additions
* \param probe: the probe array, read through bilinear()
* \param cx, cy: shift of the frame relative to the object. cx > 1e8 flags a frame with the direct
*   beam (no object), in which case nothing is done.
* \param px: factor converting pixel offsets from the array centre into the coordinates used for the quadratic phase
* \param f: coefficient of the quadratic phase factor
* \param stack_size: number of frames in a stack
* \param nx, ny: probe size (assumed to be multiples of 2)
* \param nxo, nyo: object size
* \param nbmode: number of modes
* \param interp: interpolation flag forwarded to bilinear() and bilinear_atomic_add_c()
* \param obj_idx, probe_idx: object and probe index for each mode
* \param beamx, beamy: additional shift for each mode
*/
void GradObj(const int i, __global float2* psi, __global float2 *objgrad, __global float2* probe,
             const float cx,  const float cy, const float px, const float f,
             const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
             const int nbmode, const char interp,
             __global int* obj_idx, __global int* probe_idx,
             __global float* beamx, __global float* beamy)
{
  // cx too large indicates a frame with the direct beam - no object
  if(cx>1e8) return;

  // Pixel coordinates in the (centered) probe array
  const int prx = i % nx;
  const int pry = i / nx;

  // Matching coordinates in the fft-shifted Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Quadratic phase factor applied after far field propagation
  const float y = (pry - ny/2) * px;
  const float x = (prx - nx/2) * px;
  const float phase = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(phase);
  const float c=native_cos(phase);

  // Contributions to the same object mode and beam offset are summed before a single write.
  // More efficient if probe modes are contiguous (they should be)
  int prev_obj = -1;
  float prev_bx=-1, prev_by=-1;

  float2 acc=0;

  for(int imode=0 ; imode < nbmode ; imode++)
  {
    const int iobj = obj_idx[imode];
    const float bx = beamx[imode];
    const float by = beamy[imode];

    // Flush the accumulator when the destination (object mode or beam offset) changes
    const int same_target = (iobj == prev_obj) && (bx == prev_bx) && (by == prev_by);
    if((prev_obj>=0) && !same_target)
    {
      bilinear_atomic_add_c(objgrad, -acc, cx + prev_bx + prx, cy + prev_by + pry, prev_obj, nxo, nyo, interp);
      acc = 0;
    }
    prev_obj = iobj;
    prev_bx = bx;
    prev_by = by;

    const int iprobe = probe_idx[imode];

    const float2 pr = bilinear(probe, prx, pry, iprobe, nx, ny, interp, false);
    float2 ps = psi[ipsi + (stack_size * imode) * nx * ny];

    // psi * exp(1j * phase)
    ps=(float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
    // probe.conj() * psi
    acc += (float2) (pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x );
  }
  // Write the contribution of the last group of modes
  bilinear_atomic_add_c(objgrad, -acc, cx + prev_bx + prx, cy + prev_by + pry, prev_obj, nxo, nyo, interp);
}

/** Elementwise kernel to compute the probe gradient from psi. Almost the same as the kernel computing the
* updated probe projection, except that no normalization array is retained.
*
* For the probe pixel i, the psi values of all npsi frames and all modes are multiplied by
* exp(1j * f * (x^2 + y^2)), then by the conjugate of the interpolated object (or left as is for frames
* flagged with cx > 1e8), and summed for each probe mode. The sum is then subtracted from probegrad,
* or stored negated when firstpass is set.
* \param i: pixel index in the (ny, nx) probe array
* \param psi: complex amplitudes, fft-shifted; frame j of mode imode is read at (j + stack_size * imode) * nx * ny
* \param probegrad: the probe gradient, one (ny, nx) array per probe mode
* \param obj: the object array, read through bilinear()
* \param cx, cy: shift of each frame relative to the object. cx[j] > 1e8 flags a frame without object.
* \param px: factor converting pixel offsets from the array centre into the coordinates used for the quadratic phase
* \param f: coefficient of the quadratic phase factor
* \param firstpass: if non-zero, probegrad is overwritten, otherwise the contribution is subtracted from it
* \param npsi: number of frames over which the sum is performed
* \param stack_size: number of frames in a stack
* \param nx, ny: probe size (assumed to be multiples of 2)
* \param nxo, nyo: object size
* \param nbprobe: number of probe modes
* \param nbmode: number of modes
* \param interp: interpolation flag forwarded to bilinear()
* \param obj_idx, probe_idx: object and probe index for each mode
* \param beamx, beamy: additional shift for each mode
*/
void GradProbe(const int i, __global float2* psi, __global float2* probegrad, __global float2 *obj,
                          __global float* cx,  __global float* cy, const float px, const float f, const char firstpass,
                          const int npsi, const int stack_size, const int nx, const int ny, const int nxo,
                          const int nyo, const int nbprobe, const int nbmode, const char interp,
                          __global int* obj_idx, __global int* probe_idx,
                          __global float* beamx, __global float* beamy)
{
  // Pixel coordinates in the probe array
  const int prx = i % nx;
  const int pry = i / nx;

  // obj and probe are centered arrays, Psi is fft-shifted

  // Matching coordinates in the Psi array (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Quadratic phase factor applied after far field propagation
  const float y = (pry - ny/2) * px;
  const float x = (prx - nx/2) * px;
  const float phase = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(phase);
  const float c=native_cos(phase);

  // TODO: find a way to use __local memory - pyopencl does not support LocalMemory arguments for ElementwiseKernel..
  //   ... this caching mode is ugly...
  // The sums are kept in a small private array and only written to probegrad once per batch
  #define NB_PROBE_CACHE 8
  float2 acc[NB_PROBE_CACHE];

  // Probe modes are handled by batches of NB_PROBE_CACHE, each batch looping again over all psi values
  for(int batch0=0; batch0<nbprobe ; batch0+=NB_PROBE_CACHE)
  {
    // Reset the accumulators
    for(int k=0; k<NB_PROBE_CACHE; k++)
    {
      acc[k] = 0;
    }

    for(int j=0;j<npsi;j++)
    {
      for(int imode=0 ; imode < nbmode ; imode++)
      {
        // Slot of this mode's probe in the current batch - only handle modes which fall inside
        const int slot = probe_idx[imode] - batch0;
        if((slot>=0) && (slot<NB_PROBE_CACHE))
        {
          const int iobj = obj_idx[imode];

          float2 ps = psi[ipsi + (j + stack_size * imode) * nx * ny];
          // psi * exp(1j * phase)
          ps=(float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);

          if(cx[j]>1e8)
          {
            // Frame flagged as having no object: psi is accumulated directly
            acc[slot] += ps;
          }
          else
          {
            const float bx = beamx[imode];
            const float by = beamy[imode];
            const float2 o = bilinear(obj, cx[j]+bx+prx, cy[j]+by+pry, iobj, nxo, nyo, interp, false);
            // obj.conj() * psi
            acc[slot] += (float2) (o.x*ps.x + o.y*ps.y , o.x*ps.y - o.y*ps.x);
          }
        }
      }
    }

    // Number of valid probe modes in this batch (the last one may be incomplete)
    int nstore = nbprobe - batch0;
    if(nstore > NB_PROBE_CACHE) nstore = NB_PROBE_CACHE;

    // Store values
    for(int k=0; k<nstore; k++)
    {
      const int iout = i + (k+batch0) * nx * ny;
      if(firstpass)
        probegrad[iout] = -acc[k];
      else
        probegrad[iout] -= acc[k];
    }
  }
}


/** Sum the stack of N object gradient arrays (this must be done in this step to avoid memory access conflicts).
* For the pixel i and each object mode, the stack_size layers of objN are added and the result overwrites obj.
* \param i: pixel index in a single object layer of nxyo elements
* \param objN: stacked gradient arrays, the layer j of the object mode iobjmode starts at (j + iobjmode*stack_size) * nxyo
* \param obj: output array, the object mode iobjmode starts at iobjmode * nxyo
* \param stack_size: number of layers to sum for each object mode
* \param nxyo: number of elements in one object layer
* \param nbobj: number of object modes
*/
void SumGradN(const int i, __global float2 *objN, __global float2 *obj, const int stack_size, const int nxyo, const int nbobj)
{
  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    // First layer belonging to this object mode in objN
    const int layer0 = iobjmode*stack_size;
    float2 total=0;
    for(int layer=layer0; layer<layer0+stack_size; layer++)
    {
      total += objN[i + layer * nxyo];
    }
    obj[i + iobjmode*nxyo]=total;
  }
}

/** Regularisation gradient, to penalise local variations in the object or probe array.
* For the element i, the differences with the (up to 4) nearest neighbours inside the same (ny, nx) layer
* are summed, and 2 * alpha * sum is added to dv[i]. Neighbours falling outside the layer are skipped.
* \param i: index of the element in v - the position in the layer is taken from i modulo (nx * ny)
* \param dv: gradient array, updated in place
* \param v: array of complex values being regularised
* \param alpha: regularisation weight
* \param nx, ny: size of a 2D layer
*/
void GradReg(const int i, __global float2 *dv, __global float2 *v, const float alpha, const int nx, const int ny)
{
  // Column and row of the element inside its 2D layer
  const int col = i % nx;
  const int row = (i % (nx * ny)) / nx;

  const float2 centre=v[i];
  float2 diff = (float2)(0, 0);

  // Left, right, upper and lower neighbours, in that order
  if(col>0) diff += centre - v[i-1];
  if(col<(nx-1)) diff += centre - v[i+1];
  if(row>0) diff += centre - v[i-nx];
  if(row<(ny-1)) diff += centre - v[i+nx];

  dv[i] += 2 * alpha * diff;
}
