/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/

/** Compute Psi <- 2 * P * O - Psi (difference-map step), including the quadratic phase
* factor and the per-frame linear (multi-angle) phase ramp.
*
* Each call handles a single (x,y) pixel of the frame. For that pixel, the product P*O is
* multiplied by the phase factor and summed along z (prz), for every object mode, probe mode
* and frame.
* If the pixel falls outside the object in x or y for a given frame, the sum is 0 and the
* result is simply -Psi.
*
* i: pixel index in the 2D frame: prx = i % nx, pry = i / nx
*    (presumably 0 <= i < nx*ny - TODO confirm against the calling kernel)
* psi: in/out, element [ipsi + (j + stack_size * (iprobe + iobjmode * nbprobe)) * nx * ny],
*      where ipsi is the fft-shifted (origin at (0,0)) xy position of the pixel
* obj: object, element [ixo + nxo * (iyo + nyo * izo) + iobjmode * nxo * nyo * nzo]
* probe: probe, element [i + prz * nx * ny + iprobe * nx * ny * nz]
* cx, cy, cz: per-frame offsets added to the probe coordinates to index the object
* dsx, dsy, dsz: per-frame coefficients of the linear phase terms along x, y and z
* px, py, pz: scale factors applied to the centred x, y coordinates and to prz
*             (presumably pixel sizes - TODO confirm units with the caller)
* f: coefficient of the quadratic phase term f * (x^2 + y^2)
* npsi: number of frames processed; stack_size: frame stride between modes in psi
* nx, ny, nz: frame (probe) size; nxo, nyo, nzo: object size
* nbobj, nbprobe: number of object and probe modes
*/
void ObjectProbePsiDM1(const int i, __global float2* psi, __global float2 *obj, __global float2* probe,
                       __global int* cx, __global int* cy, __global int* cz,
                       __global float* dsx, __global float* dsy, __global float* dsz,
                       const float px, const float py, const float pz, const float f,
                       const int npsi, const int stack_size,
                       const int nx, const int ny, const int nz, const int nxo, const int nyo, const int nzo,
                       const int nbobj, const int nbprobe)
{
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny;
  const int nxyz = nx * ny * nz;
  const int nxyzo = nxo * nyo * nzo;

  // Coordinates in Psi array, fft-shifted (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Centred coordinates: x is scaled by px and y by py (same convention as ObjectProbePsiDM2)
  const float y = (pry - ny/2) * py;
  const float x = (prx - nx/2) * px;

  // Quadratic phase factor applied before far field propagation - identical for all frames/modes
  const float quad = f * (x*x + y*y);

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    for(int iprobe=0;iprobe<nbprobe;iprobe++)
    {
      for(int j=0;j<npsi;j++)
      {
        float2 ps1 = (float2)0;
        const int ixo = cx[j] + prx;
        const int iyo = cy[j] + pry;
        if((ixo>=0) && (ixo<nxo) && (iyo>=0) && (iyo<nyo))
        {
          // TODO: use a __local array for psi values to minimize memory transfers ? Or trust the cache.
          const float tmp_dsxy = quad + dsx[j] * x + dsy[j] * y;
          const float dszj = dsz[j] * pz;

          // Restrict prz so that the object z index cz[j] + prz stays within [0, nzo)
          const int przmin = max((int)0, -cz[j]);
          const int przmax = min(nz, nzo - cz[j]);
          for(int prz=przmin;prz<przmax;prz++)
          {
            const float2 p = probe[i + prz * nxy + iprobe * nxyz];
            const float2 o = obj[ixo + nxo * (iyo + nyo * (cz[j] + prz)) + iobjmode * nxyzo];
            // Complex product O * P
            const float2 ps = (float2)(o.x*p.x - o.y*p.y , o.x*p.y + o.y*p.x);

            // Add the phase factor with the quadratic phase and multi-angle terms
            const float phase = tmp_dsxy + dszj * prz;
            const float s = native_sin(phase);
            const float c = native_cos(phase);

            // Accumulate (sum along z) O * P * exp(i * phase)
            ps1 += (float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
          }
        }
        const int ii = ipsi + (j + stack_size * (iprobe + iobjmode * nbprobe) ) * nxy;
        psi[ii] = 2 * ps1 - psi[ii];
      }
    }
  }
}

/** Update Psi (with quadratic phase):
* Psi(n+1) = Psi(n) - P*O + Psi_fourier
* Psi_fourier is (2*P*O - Psi(n)) after the Fourier constraints have been applied.
*
* Each call handles a single (x,y) pixel of the frame. P*O is computed with the same phase
* factor (quadratic + per-frame linear terms) and the same summation along z as the forward
* computation. If the pixel falls outside the object in x or y for a given frame, P*O is taken
* as 0 and the update reduces to Psi += Psi_fourier.
*
* i: pixel index in the 2D frame: prx = i % nx, pry = i / nx
*    (presumably 0 <= i < nx*ny - TODO confirm against the calling kernel)
* psi: in/out, element [ipsi + (j + stack_size * (iprobe + iobjmode * nbprobe)) * nx * ny],
*      where ipsi is the fft-shifted (origin at (0,0)) xy position of the pixel
* psi_fourier: read-only here, same layout as psi
* obj: object, element [ixo + nxo * (iyo + nyo * izo) + iobjmode * nxo * nyo * nzo]
* probe: probe, element [i + prz * nx * ny + iprobe * nx * ny * nz]
* cx, cy, cz: per-frame offsets added to the probe coordinates to index the object
* dsx, dsy, dsz: per-frame coefficients of the linear phase terms along x, y and z
* px, py, pz: scale factors applied to the centred x, y coordinates and to prz
*             (presumably pixel sizes - TODO confirm units with the caller)
* f: coefficient of the quadratic phase term f * (x^2 + y^2)
* npsi: number of frames processed; stack_size: frame stride between modes in psi
* nx, ny, nz: frame (probe) size; nxo, nyo, nzo: object size
* nbobj, nbprobe: number of object and probe modes
*/
void ObjectProbePsiDM2(const int i, __global float2* psi, __global float2* psi_fourier, __global float2 *obj, __global float2* probe,
                       __global int* cx, __global int* cy, __global int* cz,
                       __global float* dsx, __global float* dsy, __global float* dsz,
                       const float px, const float py, const float pz, const float f,
                       const int npsi, const int stack_size,
                       const int nx, const int ny, const int nz, const int nxo, const int nyo, const int nzo,
                       const int nbobj, const int nbprobe)
{
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny;
  const int nxyz = nx * ny * nz;
  const int nxyzo = nxo * nyo * nzo;

  // Coordinates in Psi array, fft-shifted (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Centred coordinates: x is scaled by px and y by py
  const float y = (pry - ny/2) * py;
  const float x = (prx - nx/2) * px;

  // Quadratic phase factor applied before far field propagation - identical for all frames/modes
  const float quad = f * (x*x + y*y);

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    for(int iprobe=0;iprobe<nbprobe;iprobe++)
    {
      for(int j=0;j<npsi;j++)
      {
        float2 ps1 = (float2)0;
        const int ixo = cx[j] + prx;
        const int iyo = cy[j] + pry;
        if((ixo>=0) && (ixo<nxo) && (iyo>=0) && (iyo<nyo))
        {
          // TODO: use a __local array for psi values to minimize memory transfers ? Or trust the cache.
          const float tmp_dsxy = quad + dsx[j] * x + dsy[j] * y;
          const float dszj = dsz[j] * pz;

          // Restrict prz so that the object z index cz[j] + prz stays within [0, nzo)
          const int przmin = max((int)0, -cz[j]);
          const int przmax = min(nz, nzo - cz[j]);
          for(int prz=przmin;prz<przmax;prz++)
          {
            const float2 p = probe[i + prz * nxy + iprobe * nxyz];
            const float2 o = obj[ixo + nxo * (iyo + nyo * (cz[j] + prz)) + iobjmode * nxyzo];
            // Complex product O * P
            const float2 ps = (float2)(o.x*p.x - o.y*p.y , o.x*p.y + o.y*p.x);

            // Add the phase factor with the quadratic phase and multi-angle terms
            const float phase = tmp_dsxy + dszj * prz;
            const float s = native_sin(phase);
            const float c = native_cos(phase);

            // Accumulate (sum along z) O * P * exp(i * phase)
            ps1 += (float2)(ps.x*c - ps.y*s , ps.y*c + ps.x*s);
          }
        }
        const int ii = ipsi + (j + stack_size * (iprobe + iobjmode * nbprobe) ) * nxy;
        psi[ii] += psi_fourier[ii] - ps1;
      }
    }
  }
}
