/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2023-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/

/** RAAR Psi update (with quadratic phase), for a single pixel i of the nx*ny frame.
*
* For every frame j < npsi and every mode imode < nbmode, updates psi in place:
*   Psi(n+1) = beta/2 [2PO-Psi(n)+ Psi_Fourier] +(1-beta)Psi_Fourier
*            = (1-beta/2) Psi_Fourier + beta/2 (2PO - Psi(n))
* where, judging from the coefficients used below, psi[] holds Psi_Fourier on input
* and psi_copy[] holds Psi(n) (NOTE(review): confirm against the calling kernel).
* PO is the object*probe product multiplied by the quadratic phase factor exp(i*f*(x^2+y^2)).
*
* i: pixel index in the probe frame, i = prx + pry * nx
* psi, psi_copy: indexed by (fft-shifted pixel) + (j + stack_size * imode) * nx * ny
* obj: object layers of size nxo*nyo, layer selected by obj_idx[imode] (passed to bilinear())
* probe: probe layers of size nx*ny, layer selected by probe_idx[imode]
* cx, cy: per-frame object position; cx[j] > 1e8 flags a direct-beam frame (probe only, no object)
* pixel_size, f: (x,y) coordinates are (pixel - n/2) * pixel_size, phase is f*(x^2+y^2)
* scale: per-frame scale factor, the object value is multiplied by sqrt(scale[j])
* beamx, beamy: per-mode offsets added to the object sampling position
* interp: forwarded unchanged to bilinear()
* beta: RAAR relaxation parameter
*/

void ObjectProbePsiRAAR(const int i, __global float2* psi, __global float2* psi_copy,
                       __global float2 *obj, __global float2* probe,
                       __global float* cx, __global float* cy,
                       const float pixel_size, const float f, const int npsi, const int stack_size,
                       const int nx, const int ny, const int nxo, const int nyo,
                       const int nbmode, __global float* scale, const char interp,
                       const float beta, __global int* obj_idx, __global int* probe_idx,
                       __global float* beamx, __global float* beamy)
{
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny;

  // Coordinates in Psi array (origin at (0,0), i.e. fft-shifted). Assumes nx and ny are even.
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Quadratic phase factor exp(i*tmp), applied before far field propagation.
  // (x,y) are measured from the centre of the frame; this only depends on the pixel.
  const float y = (pry - ny/2) * pixel_size;
  const float x = (prx - nx/2) * pixel_size;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  // RAAR weights, identical for all frames and modes
  const float w_fourier = 1.0f - 0.5f * beta;  // weight of Psi_Fourier (current psi value)
  const float w_reflect = 0.5f * beta;         // weight of (2PO - Psi(n))

  for(int j=0;j<npsi;j++)
  {
    // A cx value too large indicates a frame with the direct beam - no object.
    // Float literal (1e8 is exactly representable in float) to avoid double-precision
    // arithmetic, which may be slow or unsupported (no cl_khr_fp64) on the device.
    const float cxj = cx[j];
    const bool direct_beam = cxj > 1e8f;
    // Frame-dependent values shared by all modes; only read when the object is used.
    const float cyj = direct_beam ? 0.0f : cy[j];
    const float sqrt_scale = direct_beam ? 0.0f : native_sqrt(scale[j]);

    for(int imode=0;imode<nbmode;imode++)
    {
      const int iobj = obj_idx[imode];
      const int iprobe = probe_idx[imode];
      const float bx = beamx[imode];
      const float by = beamy[imode];

      // TODO: check if we need to cache probe values explicitly (p does not depend on j)
      const float2 p = probe[i + iprobe*nxy];

      float2 ps;
      if(direct_beam) ps = p;
      else
      {
        // Object value at the (sub-pixel) shifted position, times sqrt(scale)
        const float2 o = bilinear(obj, cxj+bx+prx, cyj+by+pry, iobj, nxo, nyo, interp, false) * sqrt_scale;
        // Complex product object * probe
        ps=(float2)(o.x*p.x - o.y*p.y , o.x*p.y + o.y*p.x);
      }
      const int ii = ipsi + (j + stack_size * imode) * nxy;
      // PO = ps * exp(i*tmp)
      const float2 po= (float2) (ps.x*c - ps.y*s , ps.y*c + ps.x*s);
      psi[ii] = w_fourier * psi[ii] + w_reflect * (2.0f * po - psi_copy[ii]);
    }
  }
}
