/** Elementwise function accumulating the 3D object gradient derived from Psi. It is used to update the object
* so that it fits the computed Psi value (during AP or DM algorithms).
*
* One call handles:
* - a single probe position (cx, cy, cz), to avoid memory conflicts,
* - every object mode and every probe mode,
* - one (x, y) pixel of the probe plane (decoded from i) and all nz layers along z.
*
* For each (object mode, probe mode) pair the residual
*     dpsi = Psi * exp(i * f * (x^2 + y^2)) - SUM_z (probe * obj)
* is evaluated, then conj(probe) * dpsi is subtracted from grad for every z layer where support > 0.
* Voxels where support <= 0 are skipped: their grad entry is left untouched by this function.
*
* Nothing is returned; the result is accumulated in grad, which holds the conjugate of the gradient.
*
* \param i: flat index in the (nx, ny) probe plane (x = i % nx, y = i / nx)
* \param psi: Psi array, fft-shifted in (x, y); only the first frame of each mode's block of stack_size frames is read
* \param obj, probe: object and probe arrays (read only)
* \param grad: gradient array, indexed like obj, updated in place
* \param support: object support, tested with the same index as obj
* \param cx, cy, cz: offset of the probe volume inside the object volume
* \param pixel_size_x, pixel_size_y, f: define the quadratic phase f * (x^2 + y^2) applied to Psi
* \param nx, ny, nz: probe size; nxo, nyo, nzo: object size; nbobj, nbprobe: number of object and probe modes
*/
void Psi2ObjGrad(const int i, __global float2* psi, __global float2 *obj, __global float2* probe,
                 __global float2 *grad, __global char* support,
                 const int cx, const int cy, const int cz,
                 const float pixel_size_x, const float pixel_size_y, const float f, const int stack_size,
                 const int nx, const int ny, const int nz, const int nxo, const int nyo, const int nzo,
                 const int nbobj, const int nbprobe)
{
  // Position of this element in the (nx, ny) probe plane; z is covered by the explicit loops below
  const int px = i % nx;
  const int py = i / nx;

  // Element counts: one probe layer, one probe mode, one object layer, one object mode
  const int probe_layer = nx * ny;
  const int probe_mode_size = nx * ny * nz;
  const int obj_layer = nxo * nyo;
  const int obj_mode_size = obj_layer * nzo;

  // Matching pixel in the Psi array, which is fft-shifted (origin at (0,0)). nx and ny are assumed to be even
  const int half_nx = nx / 2;
  const int half_ny = ny / 2;
  const int sy = (py < half_ny) ? (py - half_ny + ny) : (py - half_ny);
  const int sx = (px < half_nx) ? (px - half_nx + nx) : (px - half_nx);
  const int psi_pixel = sx + sy * nx;

  // Quadratic phase factor, evaluated at the centred pixel coordinates
  const float y = (py - half_ny) * pixel_size_y;
  const float x = (px - half_nx) * pixel_size_x;
  const float phase = f*(x*x+y*y);
  // WARNING: for large arguments (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float sin_phase = native_sin(phase);
  const float cos_phase = native_cos(phase);

  for(int obj_mode = 0; obj_mode < nbobj; obj_mode++)
  {
    // Index of the z=0 object voxel under this probe pixel, for the current object mode
    const int obj_column = cx + px + nxo * (cy + py) + obj_mode * obj_mode_size;

    for(int probe_mode = 0; probe_mode < nbprobe; probe_mode++)
    {
      // Index of the z=0 probe voxel for this pixel and probe mode
      const int probe_column = i + probe_mode * probe_mode_size;
      // Psi modes are ordered with the probe mode varying fastest
      const int psi_mode = probe_mode + obj_mode * nbprobe;

      // TODO: avoid reading the probe twice per z layer (once in each loop below)

      // Residual dpsi = Psi - SUM_z (P*O), starting from Psi multiplied by the quadratic phase factor
      const float2 psi_raw = psi[psi_pixel + stack_size * psi_mode * probe_layer];
      float2 dpsi = (float2)(psi_raw.x * cos_phase - psi_raw.y * sin_phase ,
                             psi_raw.y * cos_phase + psi_raw.x * sin_phase);
      for(int z = 0; z < nz; z++)
      {
        const float2 p = probe[probe_column + z * probe_layer];
        const float2 o = obj[obj_column + (cz + z) * obj_layer];
        // Complex product o * p
        dpsi -= (float2)(o.x * p.x - o.y * p.y , o.x * p.y + o.y * p.x);
      }

      // Accumulate the conjugate of the object gradient, layer by layer
      for(int z = 0; z < nz; z++)
      {
        const int iobj = obj_column + (cz + z) * obj_layer;
        // NOTE(review): iobj includes the object-mode offset, so support is read as if it held one
        // volume per object mode - confirm this matches how the support array is allocated.
        if(support[iobj] <= 0) continue;

        const float2 p = probe[probe_column + z * probe_layer];
        // probe.conj() * dpsi
        grad[iobj] -= (float2) (p.x*dpsi.x + p.y*dpsi.y , p.x*dpsi.y - p.y*dpsi.x );
      }
    }
  }
}

/** Compute the contributions to the optimal gamma (step size) used to fit an object to the current Psi values
* (during AP or DM), when the object is updated along the direction dobj.
*
* One call accumulates the contribution:
* - of the first npsi frames of a stack of probe positions (cx[j], cy[j], cz[j]),
* - of every object mode and every probe mode,
* - of one (x, y) pixel of the probe plane (decoded from i), summed over all nz layers along z.
*
* For each frame and mode pair, with
*     dpsi = Psi * exp(i * f * (x^2 + y^2)) - SUM_z (probe * obj)
*     pdo  = SUM_z (probe * dobj)
* the following sums are accumulated:
*     gamma_n += dpsi.x * pdo.x + dpsi.y * pdo.y
*     gamma_d += |pdo|^2
*     dpsi2   += |dpsi|^2   (for fitting statistics)
*
* Returns a float4 holding (gamma_n, gamma_d, dpsi2, 0), i.e. the numerator and denominator of the
* coefficient, followed by the squared norm of the residual.
*
* \param i: flat index in the (nx, ny) probe plane (x = i % nx, y = i / nx)
* \param psi: Psi array, fft-shifted in (x, y), with stack_size frames per mode
* \param obj, probe, dobj: object, probe and object search direction (all read only; dobj is indexed like obj)
* \param cx, cy, cz: per-frame offsets of the probe volume inside the object volume
* \param pixel_size_x, pixel_size_y, f: define the quadratic phase f * (x^2 + y^2) applied to Psi
* \param npsi: number of frames of the stack to process
* \param nx, ny, nz: probe size; nxo, nyo, nzo: object size; nbobj, nbprobe: number of object and probe modes
*/
float4 Psi2Obj_Gamma(const int i, __global float2* psi, __global float2 *obj, __global float2* probe,
                     __global float2 *dobj, __global int* cx, __global int* cy, __global int* cz,
                     const float pixel_size_x, const float pixel_size_y, const float f, const int npsi,
                     const int stack_size, const int nx, const int ny, const int nz,
                     const int nxo, const int nyo, const int nzo, const int nbobj, const int nbprobe)
{
  // Position of this element in the (nx, ny) probe plane; z is covered by the explicit loop below
  const int px = i % nx;
  const int py = i / nx;

  // Element counts: one probe layer, one probe mode, one object mode
  const int probe_layer = nx * ny;
  const int probe_mode_size = nx * ny * nz;
  const int obj_mode_size = nxo * nyo * nzo;

  // Matching pixel in the Psi array, which is fft-shifted (origin at (0,0)). nx and ny are assumed to be even
  const int half_nx = nx / 2;
  const int half_ny = ny / 2;
  const int sy = (py < half_ny) ? (py - half_ny + ny) : (py - half_ny);
  const int sx = (px < half_nx) ? (px - half_nx + nx) : (px - half_nx);
  const int psi_pixel = sx + sy * nx;

  // Quadratic phase factor for field propagation, evaluated at the centred pixel coordinates
  const float y = (py - half_ny) * pixel_size_y;
  const float x = (px - half_nx) * pixel_size_x;
  const float phase = f*(x*x+y*y);
  // WARNING: for large arguments (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float sin_phase = native_sin(phase);
  const float cos_phase = native_cos(phase);

  // Running sums: numerator, denominator, and squared residual norm
  float gamma_n = 0;
  float gamma_d = 0;
  float dpsi2 = 0;

  for(int obj_mode = 0; obj_mode < nbobj; obj_mode++)
  {
    for(int probe_mode = 0; probe_mode < nbprobe; probe_mode++)
    {
      // Index of the z=0 probe voxel for this pixel and probe mode
      const int probe_column = i + probe_mode * probe_mode_size;
      // Psi modes are ordered with the probe mode varying fastest
      const int psi_mode = probe_mode + obj_mode * nbprobe;

      for(int j = 0; j < npsi; j++)
      {
        // TODO: use a __local array for psi values to minimize memory transfers ? Or trust the cache.
        const float2 psi_raw = psi[psi_pixel + (j + stack_size * psi_mode ) * probe_layer];
        // Psi multiplied by the quadratic phase factor; the P*O sum is subtracted in the z loop
        float2 dpsi = (float2)(psi_raw.x * cos_phase - psi_raw.y * sin_phase ,
                               psi_raw.y * cos_phase + psi_raw.x * sin_phase);
        // SUM_z (probe * dobj)
        float2 pdo = (float2)(0.0f, 0.0f);

        // Object voxel under this probe pixel at the first probe layer, for frame j
        const int obj_column = cx[j] + px + nxo * (cy[j] + py) + obj_mode * obj_mode_size;
        const int z0 = cz[j];

        for(int z = 0; z < nz; z++)
        {
          const float2 p = probe[probe_column + z * probe_layer];
          const int iobj = obj_column + nxo * nyo * (z0 + z);
          const float2 o = obj[iobj];
          const float2 d = dobj[iobj];
          // Complex products o * p and d * p
          dpsi -= (float2)(o.x * p.x - o.y * p.y , o.x * p.y + o.y * p.x);
          pdo += (float2)(d.x * p.x - d.y * p.y , d.x * p.y + d.y * p.x);
        }

        gamma_n += dpsi.x * pdo.x + dpsi.y * pdo.y;
        gamma_d += dot(pdo, pdo);
        dpsi2 += dot(dpsi, dpsi); // For fitting statistics
      }
    }
  }
  return (float4)(gamma_n, gamma_d, dpsi2, 0);
}
