/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/

/** Multiply the 3D object by the 3D probe, integrate the product along z, and apply a quadratic phase
* factor to obtain one 2D wavefront per (object mode, probe mode, frame).
*
* The function handles a single (x, y) pixel, identified by \a i, and loops internally over all object modes,
* probe modes, frames of the stack and z-layers of the probe.
*
* Array layouts, as implied by the indexing used below:
*  - psi:   (nbobj, nbprobe, stack_size, ny, nx), written with an fft-shift of half the frame size in x and y
*  - probe: (nbprobe, nz, ny, nx)
*  - obj:   (nbobj, nzo, nyo, nxo)
*
* Only the first \a npsi frames of each stack are written; frames npsi..stack_size-1 are left untouched.
*
* \param i: flat index of the pixel in a 2D frame, decomposed as prx = i % nx and pry = i / nx
*           (expected to be in [0, nx*ny) - to be guaranteed by the caller).
* \param psi: output array for the projected wavefronts.
* \param obj: 3D object array (read only).
* \param probe: 3D probe array (read only).
* \param cx, cy, cz: per-frame integer offsets of the probe corner inside the object, along x, y and z.
* \param pixel_size_x, pixel_size_y: pixel sizes used to compute the x and y coordinates of the phase factor.
* \param f: scale of the quadratic phase, the factor applied being exp(i * f * (x^2 + y^2)).
* \param npsi: number of frames actually computed in the stack.
* \param stack_size: number of frames allocated per (object mode, probe mode) in psi.
* \param nx, ny, nz: size of the probe (and of a psi frame for nx, ny). nx and ny are assumed to be even.
* \param nxo, nyo, nzo: size of the object.
* \param nbobj: number of object modes.
* \param nbprobe: number of probe modes.
*/
void ObjectProbeMultQuadPhase(const int i, __global float2* psi, __global float2 *obj, __global float2* probe,
                              __global int* cx, __global int* cy, __global int* cz,
                              const float pixel_size_x, const float pixel_size_y, const float f,
                              const int npsi, const int stack_size,
                              const int nx, const int ny, const int nz, const int nxo, const int nyo, const int nzo,
                              const int nbobj, const int nbprobe)
{
  // Coordinates in the 3D probe array - x, y only, we will loop to integrate over z
  const int prx = i % nx;
  const int pry = i / nx;
  const int nxy = nx * ny;
  const int nxyz = nx * ny * nz;
  const int nxyo = nxo * nyo;
  const int nxyzo = nxo * nyo * nzo;

  // Coordinates in Psi array, fft-shifted (origin at (0,0)). Assume nx ny are multiple of 2
  const int iy = pry - ny/2 + ny * (pry<(ny/2));
  const int ix = prx - nx/2 + nx * (prx<(nx/2));
  const int ipsi  = ix + iy * nx ;

  // Quadratic phase factor exp(i*f*(x^2+y^2)), applied before far field propagation.
  // It only depends on the (x, y) pixel, so it is computed once here.
  const float y = (pry - ny/2) * pixel_size_y;
  const float x = (prx - nx/2) * pixel_size_x;
  const float tmp = f*(x*x+y*y);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation), native sin and cos may be wrong.
  const float s=native_sin(tmp);
  const float c=native_cos(tmp);

  for(int iobjmode=0;iobjmode<nbobj;iobjmode++)
  {
    for(int iprobe=0;iprobe<nbprobe;iprobe++)
    {
      // Offset of this pixel in the first z-layer of the current probe mode
      const int iprobe0 = i + iprobe * nxyz;
      for(int j=0;j<npsi;j++)
      {
        // Offset in the object of this pixel for the first z-layer seen by frame j (independent of prz)
        const int iobj0 = cx[j] + prx + nxo * (cy[j] + pry + nyo * cz[j]) + iobjmode * nxyzo;

        // TODO: use a __local array for psi values to minimize memory transfers ? Or trust the cache.
        // Integrate the complex product object * probe along z
        float2 ps1 = (float2)0;
        for(int prz=0;prz<nz;prz++)
        {
            const float2 p = probe[iprobe0 + prz * nxy];
            const float2 o = obj[iobj0 + prz * nxyo];
            ps1 += (float2)(o.x * p.x - o.y * p.y , o.x * p.y + o.y * p.x);
        }
        // The phase factor does not depend on z: apply it once to the integrated value
        // rather than to each z-layer.
        psi[ipsi + (j + stack_size * (iprobe + iobjmode * nbprobe) ) * nxy] =
            (float2)(ps1.x * c - ps1.y * s , ps1.y * c + ps1.x * s);
      }
    }
  }
}
