/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/

/** Device helper accumulating the object update from Psi for one probe pixel, for a single
* frame position (cx, cy), with a quadratic phase factor applied to Psi before the update.
*
* For pixel i of an nx*ny frame this:
*  - adds the sum over probe modes of |probe|^2 to objnorm at the shifted object position,
*  - for each object mode, adds sum_over_probe_modes( conj(probe) * Psi * exp(i*f*(x^2+y^2)) ) to objnew.
*
* \param i: flat pixel index in the frame (column = i % nx, row = i / nx).
* \param psi: Psi array, read at the fft-shifted pixel position; consecutive (probe, object) modes are
*   stack_size*nx*ny elements apart. No frame offset is added here, so the caller presumably passes a
*   pointer already offset to the desired frame - TODO confirm against the caller.
* \param objnew: un-normalised object update (nbobj modes of nxo*nyo elements), accumulated with +=.
* \param probe: probe array (nbprobe modes of nx*ny elements), read only.
* \param objnorm: object normalisation array, accumulated with +=.
* \param cx, cy: integer shift of the frame inside the object array.
* \param px: scale applied to the centered pixel coordinates (presumably the pixel size - TODO confirm).
* \param f: factor multiplying (x^2+y^2) to give the phase, in radians.
*
* NOTE: the updates are plain read-modify-write operations (no atomics), so no two concurrent
* invocations should target the same object pixel.
*/

__device__ void UpdateObjQuadPhase(const int i, complexf* psi, complexf *objnew, complexf* probe,
                        float* objnorm, const int cx,  const int cy, const float px, const float f,
                        const int stack_size, const int nx, const int ny, const int nxo, const int nyo,
                        const int nbobj, const int nbprobe)
{
  // Sizes of one probe/Psi frame and of one object mode
  const int nxy = nx * ny;
  const int nxyo = nxo * nyo;
  const int halfx = nx / 2;
  const int halfy = ny / 2;

  // Column and row of this pixel in the (centered) probe array
  const int col = i % nx;
  const int row = i / nx;

  // Same pixel in the fft-shifted Psi array (origin at (0,0)). Assumes nx, ny are multiples of 2
  const int psi_row = (row < halfy) ? (row - halfy + ny) : (row - halfy);
  const int psi_col = (col < halfx) ? (col - halfx + nx) : (col - halfx);
  const int ipsi = psi_col + psi_row * nx;

  // Probe normalisation: sum of the squared modulus over all probe modes
  float probe_norm = 0.0f;
  for(int pmode = 0, k = i; pmode < nbprobe; ++pmode, k += nxy)
  {
    complexf p = probe[k];
    probe_norm += dot(p, p);
  }

  // Position of this pixel in the first object mode, taking the frame shift into account
  const int iobj0 = cx + col + nxo * (cy + row);
  objnorm[iobj0] += probe_norm; // All the object modes share the same probe normalisation.

  // Quadratic phase factor exp(i*f*(x^2+y^2)), with (x,y) the centered pixel coordinates
  const float yc = (row - halfy) * px;
  const float xc = (col - halfx) * px;
  const float phase = f * (xc * xc + yc * yc);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation),
  // the native sin and cos may be wrong.
  float s, c;
  __sincosf(phase, &s, &c);

  for(int omode = 0; omode < nbobj; ++omode)
  {
    complexf acc(0.0f, 0.0f);
    for(int pmode = 0; pmode < nbprobe; ++pmode)
    {
      // TODO: avoid reading the probe value a second time (maybe cached ?)
      complexf p = probe[i + pmode * nxy];
      complexf raw = psi[ipsi + stack_size * (pmode + omode * nbprobe) * nxy];

      // Psi * exp(i*phase)
      const float rot_re = raw.real() * c - raw.imag() * s;
      const float rot_im = raw.imag() * c + raw.real() * s;

      // conj(probe) * rotated Psi
      acc += complexf(p.real() * rot_re + p.imag() * rot_im, p.real() * rot_im - p.imag() * rot_re);
    }
    objnew[iobj0 + omode * nxyo] += acc;
  }
}

// Probe-update counterpart of the object update: for probe pixel i, loops over npsi frames
// (shifts read from cx[j], cy[j]) and all object modes to compute
//   - the probe normalisation, sum of |obj|^2, written to probenorm[i],
//   - for each probe mode, sum( conj(obj) * Psi * exp(i*f*(x^2+y^2)) ), written to probe.
// When firstpass is non-zero, probenorm and probe are overwritten; otherwise the new values
// are added to the existing ones.
// Psi is indexed as [object mode][probe mode][frame in stack][pixel], with stack_size frames per mode.
// px scales the centered pixel coordinates (presumably the pixel size - TODO confirm) and
// f multiplies (x^2+y^2) to give the phase, in radians.
__device__ void UpdateProbeQuadPhase(const int i, complexf *obj, complexf* probe, complexf* psi,
                          float* probenorm, int* cx,  int* cy,
                          const float px, const float f, const char firstpass, const int npsi, const int stack_size,
                          const int nx, const int ny, const int nxo, const int nyo, const int nbobj, const int nbprobe)
{
  // Sizes of one probe/Psi frame and of one object mode
  const int nxy = nx * ny;
  const int nxyo = nxo * nyo;
  const int halfx = nx / 2;
  const int halfy = ny / 2;

  // Column and row of this pixel. obj and probe are centered arrays, Psi is fft-shifted
  const int col = i % nx;
  const int row = i / nx;

  // Same pixel in the Psi array (origin at (0,0)). Assumes nx, ny are multiples of 2
  const int psi_row = (row < halfy) ? (row - halfy + ny) : (row - halfy);
  const int psi_col = (col < halfx) ? (col - halfx + nx) : (col - halfx);
  const int ipsi = psi_col + psi_row * nx;

  // Normalisation: squared modulus of the object summed over all frames and object modes
  float obj_norm = 0.0f;
  for(int frame = 0; frame < npsi; ++frame)
  {
    const int base = cx[frame] + col + nxo * (cy[frame] + row);
    for(int omode = 0, k = base; omode < nbobj; ++omode, k += nxyo)
    {
      complexf ob = obj[k]; // TODO 1: store object values to avoid repeated memory read
      obj_norm += dot(ob, ob);
    }
  }

  // All the probe modes share the same normalisation
  if(firstpass)
  {
    probenorm[i] = obj_norm;
  }
  else
  {
    probenorm[i] += obj_norm;
  }

  // Quadratic phase factor exp(i*f*(x^2+y^2)), with (x,y) the centered pixel coordinates
  const float yc = (row - halfy) * px;
  const float xc = (col - halfx) * px;
  const float phase = f * (xc * xc + yc * yc);
  // NOTE WARNING: if the argument becomes large (e.g. > 2^15, depending on implementation),
  // the native sin and cos may be wrong.
  float s, c;
  __sincosf(phase, &s, &c);

  for(int pmode = 0; pmode < nbprobe; ++pmode)
  {
    complexf acc(0.0f, 0.0f);
    for(int frame = 0; frame < npsi; ++frame)
    {
      const int base = cx[frame] + col + nxo * (cy[frame] + row); // TODO 1
      for(int omode = 0; omode < nbobj; ++omode)
      {
        complexf raw = psi[ipsi + (frame + stack_size * (pmode + omode * nbprobe)) * nxy];

        // Psi * exp(i*phase)
        const float rot_re = raw.real() * c - raw.imag() * s;
        const float rot_im = raw.imag() * c + raw.real() * s;

        complexf ob = obj[base + omode * nxyo]; // TODO 1

        // conj(obj) * rotated Psi
        acc += complexf(ob.real() * rot_re + ob.imag() * rot_im, ob.real() * rot_im - ob.imag() * rot_re);
      }
    }

    const int iprobe = i + pmode * nxy;
    if(firstpass)
    {
      probe[iprobe] = acc;
    }
    else
    {
      probe[iprobe] += acc;
    }
  }
}

// Normalise the object for pixel i, for all nbobj object modes (stored nxyo elements apart):
//   obj = (obj_unnorm + reg * obj) / max(objnorm + reg, 1e-12)
// The lower bound on the denominator avoids a division by zero.
// The regularization term is used as in: Marchesini et al, Inverse problems 29 (2013), 115009, eq (14)
__device__ void ObjNorm(const int i, complexf *obj_unnorm, float* objnorm, complexf *obj, const float reg, const int nxyo, const int nbobj)
{
  // The same (regularised) normalisation is shared by all the object modes
  const float denom = fmaxf(objnorm[i] + reg, 1e-12f);

  for(int mode = 0, k = i; mode < nbobj; ++mode, k += nxyo)
  {
    complexf numer = obj_unnorm[k] + reg * obj[k];
    obj[k] = numer / denom;
  }
}

