/** Elementwise kernel: back-propagate Psi into the 3D object, accumulating the updated object and its normalisation.
* Psi is replicated along every z-layer of the object. Each object voxel inside the support receives
* conj(probe) * Psi (summed over probe modes, after removal of the quadratic and multi-angle phase factors).
* The normalisation added to that voxel is the squared probe modulus at its layer (summed over probe modes),
* multiplied by the number of in-support layers of the (x,y) column. The back-propagated amplitude is thus
* spread evenly over the support along z.
*
* This kernel handles:
* - a single probe position (to avoid memory access conflicts),
* - all object and probe modes,
* - one (ix,iy) coordinate and all the iz values where the probe overlaps the object.
*
* It should be called with Psi[0,0,i], i.e. with a nx*ny sized array as first argument. The probe and object
* modes are looped over inside the kernel.
* Notes: the 'obj' argument is not read. obj_norm is only updated at the object-mode-0 offset, because all
* object modes share the same normalisation.
*/
void Psi2ObjRepNorm1(const int i, __global float2* psi, __global float2 *obj, __global float2* probe, __global char* support,
             __global float2 *obj_new, __global float* obj_norm,
             const int cx, const int cy, const int cz, const float dsx, const float dsy, const float dsz,
             const float px, const float py, const float pz, const float f, const int stack_size,
             const int nx, const int ny, const int nz, const int nxo, const int nyo, const int nzo,
             const int nbobj, const int nbprobe)
{
  // (x,y) position of this work item in the probe frame - z is covered by the loops below
  const int prx = i % nx;
  const int pry = i / nx;

  // Corresponding column in the object: nothing to do if it falls outside the object array
  const int ixo = cx + prx;
  const int iyo = cy + pry;
  if(ixo < 0 || iyo < 0 || ixo >= nxo || iyo >= nyo) return;

  const int frame_p = nx * ny;            // size of one 2D probe / psi frame
  const int mode_p = nx * ny * nz;        // size of one 3D probe mode
  const int layer_o = nxo * nyo;          // size of one object z-layer
  const int mode_o = layer_o * nzo;       // size of one 3D object mode

  // Psi is fft-shifted (origin at (0,0)); nx and ny are assumed to be even
  const int sx = prx - nx/2 + nx * (prx<(nx/2));
  const int sy = pry - ny/2 + ny * (pry<(ny/2));
  const int ipsi = sx + sy * nx;

  // Real-space coordinates used by the quadratic and multi-angle phase factors
  const float y = (pry - ny/2) * py;
  const float x = (prx - nx/2) * px;
  const float tmp_dsxy = dsx * x + dsy * y;
  const float dszj = dsz * pz;

  // Range of probe layers which overlap the object along z
  const int zlo = max((int)0, -cz);
  const int zhi = min(nz, nzo - cz);

  // Number of in-support layers in this column: the amplitude is shared evenly between them
  int nsup = 0;
  for(int prz = zlo; prz < zhi; prz++)
    nsup += support[ixo + nxo * (iyo + nyo * (cz + prz))] > 0;

  for(int prz = zlo; prz < zhi; prz++)
  {
    const int isup = ixo + nxo * (iyo + nyo * (cz + prz));
    if(support[isup] <= 0) continue; // outside the support: no contribution

    // Phase factor with the quadratic and multi-angle terms
    const float phase = f*(x*x+y*y) - tmp_dsxy - dszj * prz;  // Psi->Obj, so - sign (and f is already <0)
    const float s = native_sin(phase);
    const float c = native_cos(phase);

    for(int imode = 0; imode < nbobj; imode++)
    {
      float2 acc = 0;
      float p2 = 0;
      for(int ip = 0; ip < nbprobe; ip++)
      {
        // Remove the phase factor from Psi
        float2 ps = psi[ipsi + stack_size * (ip + imode * nbprobe) * frame_p];
        ps = (float2)(ps.x * c - ps.y * s , ps.y * c + ps.x * s);
        const float2 pr = probe[i + prz * frame_p + ip * mode_p];
        if(imode == 0) p2 += dot(pr, pr);
        // conj(probe) * psi
        acc += (float2) (pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x);
      }
      const int iobj = isup + imode * mode_o;
      obj_new[iobj] += acc;
      if(imode == 0) obj_norm[iobj] += p2 * nsup; // shared by all object modes
    }
  }
}

/** Elementwise kernel: back-propagate Psi into the 3D object, accumulating the updated object and its normalisation.
* Psi is replicated along every z-layer of the object. Each in-support voxel receives conj(probe) * Psi, summed
* over probe modes, after removal of the quadratic and multi-angle phase factors.
* Every in-support voxel of the (x,y) column receives the same normalisation: the squared probe modulus summed
* over probe modes and over all in-support layers of that column.
*
* This kernel handles:
* - a single probe position (to avoid memory access conflicts),
* - all object and probe modes,
* - one (ix,iy) coordinate and all the iz values where the probe overlaps the object.
*
* It should be called with Psi[0,0,i], i.e. with a nx*ny sized array as first argument. The probe and object
* modes are looped over inside the kernel.
* Notes: the 'obj' argument is not read. obj_norm is only updated at the object-mode-0 offset, because all
* object modes share the same normalisation.
*/
void Psi2ObjRepNormN(const int i, __global float2* psi, __global float2 *obj, __global float2* probe, __global char* support,
             __global float2 *obj_new, __global float* obj_norm,
             const int cx, const int cy, const int cz, const float dsx, const float dsy, const float dsz,
             const float px, const float py, const float pz, const float f, const int stack_size,
             const int nx, const int ny, const int nz, const int nxo, const int nyo, const int nzo,
             const int nbobj, const int nbprobe)
{
  // (x,y) position of this work item in the probe frame - z is covered by the loops below
  const int prx = i % nx;
  const int pry = i / nx;

  // Corresponding column in the object: nothing to do if it falls outside the object array
  const int ixo = cx + prx;
  const int iyo = cy + pry;
  if(ixo < 0 || iyo < 0 || ixo >= nxo || iyo >= nyo) return;

  const int frame_p = nx * ny;            // size of one 2D probe / psi frame
  const int mode_p = nx * ny * nz;        // size of one 3D probe mode
  const int layer_o = nxo * nyo;          // size of one object z-layer
  const int mode_o = layer_o * nzo;       // size of one 3D object mode

  // Psi is fft-shifted (origin at (0,0)); nx and ny are assumed to be even
  const int sx = prx - nx/2 + nx * (prx<(nx/2));
  const int sy = pry - ny/2 + ny * (pry<(ny/2));
  const int ipsi = sx + sy * nx;

  // Real-space coordinates used by the quadratic and multi-angle phase factors
  const float y = (pry - ny/2) * py;
  const float x = (prx - nx/2) * px;
  const float tmp_dsxy = dsx * x + dsy * y;
  const float dszj = dsz * pz;

  // Range of probe layers which overlap the object along z
  const int zlo = max((int)0, -cz);
  const int zhi = min(nz, nzo - cz);

  // Column-wide probe intensity, gathered during the object mode 0 pass only
  float column_norm = 0;

  for(int prz = zlo; prz < zhi; prz++)
  {
    const int isup = ixo + nxo * (iyo + nyo * (cz + prz));
    if(support[isup] <= 0) continue; // outside the support: no contribution

    // Phase factor with the quadratic and multi-angle terms
    const float phase = f*(x*x+y*y) - tmp_dsxy - dszj * prz;  // Psi->Obj, so - sign (and f is already <0)
    const float s = native_sin(phase);
    const float c = native_cos(phase);

    for(int imode = 0; imode < nbobj; imode++)
    {
      float2 acc = 0;
      for(int ip = 0; ip < nbprobe; ip++)
      {
        // Remove the phase factor from Psi
        float2 ps = psi[ipsi + stack_size * (ip + imode * nbprobe) * frame_p];
        ps = (float2)(ps.x * c - ps.y * s , ps.y * c + ps.x * s);
        const float2 pr = probe[i + prz * frame_p + ip * mode_p];
        if(imode == 0) column_norm += dot(pr, pr);
        // conj(probe) * psi
        acc += (float2) (pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x);
      }
      obj_new[isup + imode * mode_o] += acc;
    }
  }

  // Same normalisation for every in-support layer of the column (object mode 0 offset only)
  for(int prz = zlo; prz < zhi; prz++)
  {
    const int isup = ixo + nxo * (iyo + nyo * (cz + prz));
    if(support[isup] > 0) obj_norm[isup] += column_norm;
  }
}

/** Elementwise kernel to compute the 3D updated object and its normalisation from psi.
* This back-propagation uses a replication of Psi along all z-layers, normalised by the sum of the norm of the probe
* along the entire z stack, and weighted by the probe norm at each layer. For each in-support voxel at layer z:
*   obj_new  += [SUM_probe_modes conj(P) * Psi] * |P_z|^2 / (prn + 1e-6)
*   obj_norm += |P_z|^2            (at the object-mode-0 offset only: all object modes share it)
* Here |P_z|^2 is the squared probe modulus at layer z, summed over probe modes. prn is the sum of |P_z|^2 over
* all in-support layers of the (x,y) column. Psi is first corrected for the quadratic and multi-angle phase factors.
*
* This kernel computes the object:
* - for a single probe position (to avoid memory access conflicts).
* - for all object and probe modes
* - for a given (ix,iy) coordinate in the object, and all iz values.
*
* This should be called with a Psi[0,0,i], i.e. with a nx*ny sized array as first argument. The probe and object modes
* will be looped over. The 'obj' argument is not read.
*/
void Psi2ObjRepNormNw(const int i, __global float2* psi, __global float2 *obj, __global float2* probe, __global char* support,
             __global float2 *obj_new, __global float* obj_norm,
             const int cx, const int cy, const int cz, const float dsx, const float dsy, const float dsz,
             const float px, const float py, const float pz, const float f, const int stack_size,
             const int nx, const int ny, const int nz, const int nxo, const int nyo, const int nzo,
             const int nbobj, const int nbprobe)
{
  // Coordinates in the 3D probe array - x, y only, we will loop to integrate over z
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny;
  const int nxyz = nx * ny * nz;
  const int nxyo = nxo * nyo ;
  const int nxyzo = nxyo * nzo;

  const int ixo = cx + prx;
  const int iyo = cy + pry;
  if((ixo<0) || (ixo>=nxo) || (iyo<0) || (iyo>=nyo)) return; // Outside object array ?

  // Coordinates in Psi array, fft-shifted (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Coordinates for the quadratic and multi-angle phase factors
  const float y = (pry - ny/2) * py;
  const float x = (prx - nx/2) * px;
  const float tmp_dsxy = dsx * x + dsy * y;
  const float dszj = dsz * pz;

  int przmin = max((int)0, -cz);
  int przmax = min(nz, nzo - cz);

  // The probe normalisation, which is the same for all object modes and all z layers:
  // sum of |P|^2 over all probe modes and all in-support layers of this column
  float prn=0;
  for(int prz=przmin;prz<przmax;prz++)
  {
    if(support[ixo + nxo*(iyo + nyo * (cz + prz))] > 0)
    {
      for(int iprobe=0;iprobe<nbprobe;iprobe++)
      {
        const float2 pr = probe[i + prz * nxy + iprobe * nxyz];
        prn += dot(pr,pr);
      }
    }
  }

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    const int iobjxy = ixo + nxo * iyo + iobjmode * nxyzo;
    for(int prz=przmin;prz<przmax;prz++)
    {
      if(support[ixo + nxo*(iyo + nyo * (cz + prz))] > 0)
      {
        const int iobj = iobjxy + (cz + prz) * nxyo;

        // Phase factor with the quadratic and multi-angle terms
        const float tmp = f*(x*x+y*y) - tmp_dsxy - dszj * prz;  // Psi->Obj, so - sign (and f is already <0)
        const float s=native_sin(tmp);
        const float c=native_cos(tmp);

        float2 o=0;
        float pr2 = 0;  // |P|^2 at this layer, summed over probe modes
        for(int iprobe=0;iprobe<nbprobe;iprobe++)
        {
          // Correct Psi for phase factor
          float2 ps = psi[ipsi + stack_size * (iprobe + iobjmode * nbprobe) * nxy];
          ps = (float2)(ps.x * c - ps.y * s , ps.y * c + ps.x * s);
          const float2 pr = probe[i + prz * nxy + iprobe * nxyz];
          // conj(probe) * psi
          o += (float2) (pr.x*ps.x + pr.y*ps.y , pr.x*ps.y - pr.y*ps.x);
          pr2 += dot(pr,pr);
        }
        // Add the normalisation once, after all probe modes have been summed. Adding the running sum
        // inside the probe-mode loop would over-count the first modes.
        if(iobjmode==0) obj_norm[iobj] += pr2;
        obj_new[iobj] += o / (prn + 1e-6f) * pr2;  // TODO: 1e-6 is arbitrary
      }
    }
  }
}

// Normalise the object from its accumulated (unnormalised) value and its normalisation array.
// A regularisation term proportional to the current object, with weight reg = normmax[0] * inertia, is used as in:
// Marchesini et al, Inverse problems 29 (2013), 115009, eq (14):
//   obj = (obj_unnorm + reg * obj) / max(objnorm + reg, 1e-12)
// The single normalisation value objnorm[i] is applied to all nbobj object modes, which are stored nxyo elements apart.
void ObjNorm(const int i, __global float2 *obj_unnorm, __global float* objnorm, __global float2 *obj, __global float *normmax, const float inertia, const int nxyo, const int nbobj)
{
  const float reg = normmax[0] * inertia;
  const float denom = fmax(objnorm[i] + reg, 1e-12f);
  for(int k = i, mode = 0; mode < nbobj; mode++, k += nxyo)
  {
    obj[k] = (obj_unnorm[k] + reg * obj[k]) / denom;
  }
}


/** Elementwise kernel to compute the 3D object gradient from psi. This is used to update the object value so that
* it fits the computed Psi value (during AP or DM algorithms).
* This kernel computes the object gradient:
* - for a single probe position (to avoid memory conflicts),
* - for all object and probe modes
* - for a given (ix,iy) coordinate in the object, and all iz values.
* - points not inside the object support have a null gradient (grad is left untouched there)
*
* For each object mode and probe mode, the residual is first computed:
*   dpsi = Psi * exp(i*f*(x^2+y^2)) - SUM_z O * P * exp(i*(dsx*x + dsy*y + dsz*pz*z))
* Then conj(P * exp(i*(dsx*x + dsy*y + dsz*pz*z))) * dpsi is subtracted from grad for every in-support layer.
* The accumulated value is the conjugate of the gradient.
*
* The support array has no object-mode dimension. It is indexed with the (x,y,z) object coordinates only,
* as elsewhere in this file, and never with the object-mode offset.
*/
void Psi2ObjGrad(const int i, __global float2* psi, __global float2 *obj, __global float2* probe,
                 __global float2 *grad, __global char* support,
                 const int cx, const int cy, const int cz, const float dsx,  const float dsy, const float dsz,
                 const float px, const float py, const float pz, const float f, const int stack_size,
                 const int nx, const int ny, const int nz, const int nxo, const int nyo, const int nzo,
                 const int nbobj, const int nbprobe)
{
  // Coordinates in the 3D probe array - x, y only, we will loop to integrate over z
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny;
  const int nxyz = nx * ny * nz;
  const int nxyo = nxo * nyo ;
  const int nxyzo = nxyo * nzo;

  const int ixo = cx + prx;
  const int iyo = cy + pry;
  if((ixo<0) || (ixo>=nxo) || (iyo<0) || (iyo>=nyo)) return; // Outside object array ?

  // Coordinates in Psi array, fft-shifted (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Quadratic phase factor
  const float y = (pry - ny/2) * py;
  const float x = (prx - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  // Phase factor for multi-angle
  const float tmp_dsxy = dsx * x + dsy * y;
  const float dszj = dsz * pz;

  int przmin = max((int)0, -cz);
  int przmax = min(nz, nzo - cz);

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    const int iobjxy = ixo + nxo * iyo + iobjmode * nxyzo;
    for(int iprobe=0;iprobe<nbprobe;iprobe++)
    {
      // TODO: avoid multiple read of arrays

      // First calculate dpsi = Psi - SUM_z (P*O)
      float2 dpsi = psi[ipsi + stack_size * (iprobe + iobjmode * nbprobe) * nxy];
      // Correct Psi for quadratic phase factor
      dpsi = (float2)(dpsi.x * c - dpsi.y * s , dpsi.y * c + dpsi.x * s);
      for(int prz=przmin;prz<przmax;prz++)
      {
        float2 p = probe[i + prz * nxy + iprobe * nxyz];
        const float2 o = obj[iobjxy + (cz + prz) * nxyo];

        // Correct PO for multi-angle phase factor
        const float tmp2 = tmp_dsxy + dszj * prz;
        const float s2=native_sin(tmp2);
        const float c2=native_cos(tmp2);
        p = (float2)(p.x * c2 - p.y * s2 , p.y * c2 + p.x * s2);

        dpsi -=(float2)(o.x * p.x - o.y * p.y , o.x * p.y + o.y * p.x);
      }
      // Now the object gradient conjugate for each z layer
      for(int prz=przmin;prz<przmax; prz++)
      {
        // The support has no object-mode dimension: do not include the iobjmode offset here
        // (using the object index would read past the support array for object modes > 0)
        const int isup = ixo + nxo * (iyo + nyo * (cz + prz));
        if(support[isup] > 0)
        {
          const int iobj = iobjxy + (cz + prz) * nxyo;
          float2 p = probe[i + prz * nxy + iprobe * nxyz];
          // Correct probe for multi-angle phase factor
          const float tmp2 = tmp_dsxy + dszj * prz;
          const float s2=native_sin(tmp2);
          const float c2=native_cos(tmp2);
          p = (float2)(p.x * c2 - p.y * s2 , p.y * c2 + p.x * s2);

          // Subtract probe.conj() * dpsi, giving the conjugate of the gradient
          grad[iobj] -= (float2) (p.x*dpsi.x + p.y*dpsi.y , p.x*dpsi.y - p.y*dpsi.x);
        }
      }
    }
  }
}

/** Compute the contributions to the optimal gamma value used to fit an object to the current Psi values (during AP or DM).
* This kernel computes the gamma contribution:
* - for a stack of npsi probe positions (shifts cx[j], cy[j], cz[j] and multi-angle terms dsx[j], dsy[j], dsz[j]),
* - for all object and probe modes,
* - for a given (ix,iy) coordinate in the probe frame, and all iz values overlapping the object.
* Positions for which this (ix,iy) column falls outside the object array are skipped.
*
* For each position and mode, with P the probe corrected by the multi-angle phase factor:
*   dpsi = Psi * exp(i*f*(x^2+y^2)) - SUM_z P * O
*   pdo  = SUM_z P * dobj
*
* Returns a float4 (gamma_n, gamma_d, dpsi2, 0), each summed over positions and modes:
*   gamma_n = Re(dpsi * conj(pdo))   - numerator of the gamma coefficient
*   gamma_d = |pdo|^2                - denominator of the gamma coefficient
*   dpsi2   = |dpsi|^2               - residual, for fitting statistics
*/
float4 Psi2Obj_Gamma(const int i, __global float2* psi, __global float2 *obj, __global float2* probe,
                     __global float2 *dobj, __global int* cx, __global int* cy, __global int* cz,
                     __global float* dsx, __global float* dsy, __global float* dsz,
                     const float px, const float py, const float pz, const float f, const int npsi,
                     const int stack_size, const int nx, const int ny, const int nz,
                     const int nxo, const int nyo, const int nzo, const int nbobj, const int nbprobe)
{
  // Coordinates in the 3D probe array - x, y only, we will loop to integrate over z
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny;
  const int nxyz = nx * ny * nz;
  const int nxyzo = nxo * nyo * nzo;

  // Coordinates in Psi array, fft-shifted (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Quadratic phase factor for field propagation
  const float y = (pry - ny/2) * py;
  const float x = (prx - nx/2) * px;
  const float tmp = f*(x*x+y*y);
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  float gamma_d = 0;
  float gamma_n = 0;
  float dpsi2 = 0;

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    for(int iprobe=0;iprobe<nbprobe;iprobe++)
    {
      for(int j=0;j<npsi;j++)
      {
        const int ixo = cx[j] + prx;
        const int iyo = cy[j] + pry;
        if((ixo<0) || (ixo>=nxo) || (iyo<0) || (iyo>=nyo)) continue; // Outside object array ?

        // TODO: use a __local array for psi values to minimize memory transfers ? Or trust the cache.
        float2 pdo = (float2)0;
        float2 dpsi = psi[ipsi + (j + stack_size * (iprobe + iobjmode * nbprobe) ) * nxy];

        // Correct quadratic phase factor in Psi
        dpsi = (float2)(dpsi.x * c - dpsi.y * s , dpsi.y * c + dpsi.x * s);
        const int iobj0 = ixo + nxo * iyo + iobjmode * nxyzo;

        // Phase factor for multi-angle
        const float tmp_dsxy = dsx[j] * x + dsy[j] * y;
        const float dszj = dsz[j] * pz;

        int przmin = max((int)0, -cz[j]);
        int przmax = min(nz, nzo - cz[j]);
        for(int prz=przmin;prz<przmax;prz++)
        {
            float2 p = probe[i + prz * nxy + iprobe * nxyz];
            // Correct probe for multi-angle phase factor
            const float tmp2 = tmp_dsxy + dszj * prz;
            const float s2=native_sin(tmp2);
            const float c2=native_cos(tmp2);
            p = (float2)(p.x * c2 - p.y * s2 , p.y * c2 + p.x * s2);

            const int iobj = iobj0 + nxo * nyo * (cz[j] + prz);
            const float2 o = obj[iobj];
            const float2 d = dobj[iobj];
            // dpsi -= P*O ; pdo += P*dobj
            dpsi -= (float2)(o.x * p.x - o.y * p.y , o.x * p.y + o.y * p.x);
            pdo += (float2)(d.x * p.x - d.y * p.y , d.x * p.y + d.y * p.x);
        }
        gamma_n += dpsi.x * pdo.x + dpsi.y * pdo.y;
        gamma_d += dot(pdo, pdo);
        dpsi2 += dot(dpsi, dpsi); // For fitting statistics
      }
    }
  }
  return (float4)(gamma_n, gamma_d, dpsi2, 0);
}
