/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/

/// Return the object value at flat index i, optionally relaxed towards the mean of its
/// nearest neighbours.
///
/// When tv > 0 the returned value is
///     d[i] + tv * (mean(neighbours of i) - d[i])
/// using the 4 neighbours along x/y when nz == 1, or the 6 neighbours along x/y/z when nz > 1.
/// Neighbour indices wrap around periodically at the array edges.
/// When tv <= 0 (or tv is NaN) only d[i] is read and it is returned unchanged.
/// NOTE(review): despite the "TV regularisation" wording used by callers, this is a
/// neighbour-average (Laplacian-type) smoothing step, not a total-variation step.
///
/// Parameters:
///   i          flat index; x is the fastest-varying axis (ix = i % nx)
///   d          array that is only read here (never written)
///   nx, ny, nz array dimensions used for the periodic wrap
///   tv         relaxation weight; values <= 0 disable the smoothing
/// NOTE(review): all index arithmetic is done in int, so nx*ny*nz is assumed to fit in an int.
__device__ complexf get_di_tv(const int i, complexf *d, const int nx, const int ny, const int nz, const float tv)
{
  const complexf centre = d[i];
  // No smoothing requested: only the element itself is used.
  if(!(tv > 0)) return centre;

  const int nxy = nx * ny;
  const int ix = i % nx;
  const int iy = (i % nxy) / nx;

  // Signed offsets from i to the previous / next neighbour along x and y,
  // jumping to the opposite edge when i sits on a boundary.
  const int dxm = (ix == 0)      ? nx - 1   : -1;
  const int dxp = (ix == nx - 1) ? 1 - nx   :  1;
  const int dym = (iy == 0)      ? nxy - nx : -nx;
  const int dyp = (iy == ny - 1) ? nx - nxy :  nx;

  // Accumulate strictly left-to-right (x-, x+, y-, y+[, z-, z+]) so the floating-point
  // rounding matches a single chained sum.
  complexf sum = d[i + dxm];
  sum += d[i + dxp];
  sum += d[i + dym];
  sum += d[i + dyp];

  if(nz > 1)
  {
    const int nxyz = nxy * nz;
    const int iz = (i % nxyz) / nxy;
    const int dzm = (iz == 0)      ? nxyz - nxy : -nxy;
    const int dzp = (iz == nz - 1) ? nxy - nxyz :  nxy;
    sum += d[i + dzm];
    sum += d[i + dzp];
    return centre + tv * (sum / 6.0f - centre);
  }
  return centre + tv * (sum / 4.0f - centre);
}

/// HIO (hybrid input-output) update for element i.
///
/// Let v = get_di_tv(i, d, ...), the value of d[i], possibly smoothed when tv > 0.
/// Inside the support (support[i] != 0) the new value is v.
/// Outside the support (support[i] == 0) the new value is dold[i] - beta * v.
/// The new value is always stored in dold[i]. It is also written back to d[i] only when
/// tv == 0. For tv > 0, get_di_tv reads neighbouring d[] entries that other threads may
/// still be using, so d[] must not be modified in place.
__device__ void HIO(const int i, complexf *d, complexf* dold, signed char *support,
                    const float beta, const int nx, const int ny, const int nz, const float tv)
{
  const complexf v = get_di_tv(i, d, nx, ny, nz, tv);
  const complexf updated = (support[i] != 0) ? v : dold[i] - beta * v;
  dold[i] = updated;
  // In-place write is only safe when no neighbour of d[] was read.
  if(tv == 0) d[i] = updated;
}

/// HIO update for element i, which also treats a negative real part as a constraint violation.
///
/// Let v = get_di_tv(i, d, ...). The HIO feedback dold[i] - beta * v is applied when i is
/// outside the support (support[i] == 0) or when v.real() < 0. Otherwise v is kept.
/// The result is stored in dold[i]. It is copied back to d[i] only when tv == 0, because
/// for tv > 0 neighbouring d[] entries are still read by other threads.
__device__ void HIO_real_pos(const int i, complexf *d, complexf* dold, signed char *support,
                             const float beta, const int nx, const int ny, const int nz, const float tv)
{
  const complexf v = get_di_tv(i, d, nx, ny, nz, tv);
  const bool keep = (support[i] != 0) && !(v.real() < 0);
  const complexf updated = keep ? v : dold[i] - beta * v;
  dold[i] = updated;
  // In-place write is only safe when no neighbour of d[] was read.
  if(tv == 0) d[i] = updated;
}

/// Error-reduction update for element i.
///
/// dnew[i] receives get_di_tv(i, d, ...) inside the support (support[i] != 0)
/// and zero outside it. d[] itself is not modified.
__device__ void ER(const int i, complexf *d, complexf *dnew, signed char *support,
                   const int nx, const int ny, const int nz, const float tv)
{
  const complexf v = get_di_tv(i, d, nx, ny, nz, tv);
  dnew[i] = (support[i] == 0) ? complexf(0, 0) : v;
}

/// Error-reduction update for element i, which also zeroes values with a negative real part.
///
/// Let v = get_di_tv(i, d, ...). dnew[i] receives v only when i is inside the support
/// (support[i] != 0) and v.real() is not negative. Otherwise dnew[i] receives zero.
/// d[] itself is not modified.
__device__ void ER_real_pos(const int i, complexf *d, complexf *dnew, signed char *support,
                            const int nx, const int ny, const int nz, const float tv)
{
  const complexf v = get_di_tv(i, d, nx, ny, nz, tv);
  const bool keep = (support[i] != 0) && !(v.real() < 0);
  dnew[i] = keep ? v : complexf(0, 0);
}

/// Charge-flipping update for element i.
///
/// Let v = get_di_tv(i, d, ...). Outside the support (support[i] == 0) the sign of the
/// imaginary part of v is flipped. Inside the support v is kept as is. The result is
/// written to dnew[i]; d[] itself is not modified.
__device__ void CF(const int i, complexf *d, complexf *dnew, signed char *support,
                   const int nx, const int ny, const int nz, const float tv)
{
  const complexf v = get_di_tv(i, d, nx, ny, nz, tv);
  dnew[i] = (support[i] != 0) ? v : complexf(v.real(), -v.imag());
}

/// Charge-flipping update for element i, which also treats a negative real part as a
/// constraint violation.
///
/// Let di = get_di_tv(i, d, ...), the value of d[i], possibly smoothed when tv > 0.
/// The sign of di's imaginary part is flipped when i is outside the support
/// (support[i] == 0) or when di.real() < 0. The result is written to dnew[i];
/// d[] itself is not modified.
/// The positivity test uses the (possibly smoothed) di, consistent with HIO_real_pos,
/// ER_real_pos and RAAR_real_pos. It previously tested the raw d[i], which differs
/// whenever tv > 0.
__device__ void CF_real_pos(const int i, complexf *d, complexf *dnew, signed char *support,
                            const int nx, const int ny, const int nz, const float tv)
{
  complexf di = get_di_tv(i,d,nx,ny,nz,tv);
  if((support[i]==0)||(di.real()<0)) di.imag(-di.imag()) ;
  dnew[i] = di;
}

/// RAAR update for element i.
///
/// Let v = get_di_tv(i, d, ...). Inside the support (support[i] != 0) the new value is v.
/// Outside it the new value is (1 - 2*beta) * v + beta * dold[i].
/// The new value is always stored in dold[i]. It is also copied to d[i] only when
/// tv == 0, because for tv > 0 get_di_tv reads neighbouring d[] entries that other
/// threads may still need.
__device__ void RAAR(const int i, complexf *d, complexf* dold, signed char *support,
                     const float beta, const int nx, const int ny, const int nz, const float tv)
{
  const complexf v = get_di_tv(i, d, nx, ny, nz, tv);
  const complexf updated = (support[i] != 0) ? v : (1 - 2 * beta) * v + beta * dold[i];
  dold[i] = updated;
  // In-place write is only safe when no neighbour of d[] was read.
  if(tv == 0) d[i] = updated;
}

/// RAAR update for element i, which also treats a negative real part as a constraint violation.
///
/// Let v = get_di_tv(i, d, ...). The relaxed value (1 - 2*beta) * v + beta * dold[i] is
/// used when i is outside the support (support[i] == 0) or when v.real() < 0.
/// Otherwise v is kept. The result is stored in dold[i]. It is copied to d[i] only when
/// tv == 0, because for tv > 0 neighbouring d[] entries are still read by other threads.
__device__ void RAAR_real_pos(const int i, complexf *d, complexf* dold, signed char *support,
                              const float beta, const int nx, const int ny, const int nz, const float tv)
{
  const complexf v = get_di_tv(i, d, nx, ny, nz, tv);
  const bool keep = (support[i] != 0) && !(v.real() < 0);
  const complexf updated = keep ? v : (1 - 2 * beta) * v + beta * dold[i];
  dold[i] = updated;
  // In-place write is only safe when no neighbour of d[] was read.
  if(tv == 0) d[i] = updated;
}
