/* -*- coding: utf-8 -*-
*
* PyNX - Python tools for Nano-structures Crystallography
*   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
*       authors:
*         Vincent Favre-Nicolin, favre@esrf.fr
*/

/// Return the complex value d[i], optionally relaxed towards the mean of its
/// nearest neighbours.
///
/// Although named "TV regularisation", the operation applied is a simple
/// smoothing step:  di <- di + tv * (mean(neighbours) - di),
/// i.e. a discrete-Laplacian relaxation rather than a true total-variation step.
/// Neighbours wrap around periodically along every axis. A 4-neighbour (x, y)
/// stencil is used when nz <= 1 and a 6-neighbour (x, y, z) stencil when nz > 1.
/// Indices are decomposed with x varying fastest: i = ix + nx*(iy + ny*iz),
/// taken modulo the frame size (so i may presumably index a stack of frames —
/// TODO confirm against the host code).
///
/// \param i   linear index of the element
/// \param d   complex array (real part in .x, imaginary part in .y)
/// \param nx  size along x (fastest axis)
/// \param ny  size along y
/// \param nz  size along z; the 3D stencil is only used when nz > 1
/// \param tv  smoothing weight; neighbours are read only when tv > 0
/// \return the (possibly smoothed) value of d[i]
float2 get_di_tv(const int i, __global float2 *d, const int nx, const int ny, const int nz, const float tv)
{
  const float2 centre = d[i];

  // No smoothing requested: return the raw value without touching neighbours,
  // which is what makes an in-place update of d[] safe for other work-items.
  if(!(tv > 0)) return centre;

  const int plane = nx * ny;
  const int x = i % nx;
  const int y = (i % plane) / nx;

  // Periodic neighbour indices along x and y
  const int xm = i - 1 + nx * (x == 0);
  const int xp = i + 1 - nx * (x == nx - 1);
  const int ym = i - nx + plane * (y == 0);
  const int yp = i + nx - plane * (y == ny - 1);

  // Accumulate in the order x-, x+, y-, y+ (, z-, z+)
  float2 acc = d[xm] + d[xp];
  acc += d[ym];
  acc += d[yp];

  float2 mean;
  if(nz > 1)
  {
    const int volume = plane * nz;
    const int z = (i % volume) / plane;
    acc += d[i - plane + volume * (z == 0)];
    acc += d[i + plane - volume * (z == nz - 1)];
    mean = acc / 6.0f;
  }
  else
  {
    mean = acc / 4.0f;
  }

  return centre + tv * (mean - centre);
}

/// HIO (hybrid input-output) update for element i.
///
/// Inside the support the new value is d[i], optionally smoothed by get_di_tv.
/// Outside the support it is dold[i] - beta * d[i]. Presumably d holds the
/// estimate after the Fourier-space constraint and dold the previous iterate;
/// confirm this against the host code.
///
/// The result is always stored in dold[i]. It is also written back to d[i]
/// whenever get_di_tv read no neighbours (i.e. unless tv > 0), since overwriting
/// d[i] then cannot disturb other work-items. When tv > 0 other work-items read
/// d[i] as a neighbour, so d is left untouched. The caller presumably copies
/// dold into d afterwards (TODO confirm).
///
/// \param i        linear index of the element
/// \param d        current complex estimate
/// \param dold     previous iterate; overwritten with the new value
/// \param support  non-zero inside the support, 0 outside
/// \param beta     HIO feedback parameter
/// \param nx,ny,nz array dimensions (see get_di_tv)
/// \param tv       smoothing weight passed to get_di_tv
void HIO(const int i, __global float2 *d, __global float2* dold, __global char *support,
         const float beta, const int nx, const int ny, const int nz, const float tv)
{
  float2 di = get_di_tv(i,d,nx,ny,nz,tv);
  if(support[i]==0) di = dold[i] - beta * di ;
  dold[i] = di;
  // Use exactly the complement of get_di_tv's neighbour-read condition (tv > 0).
  // This way a negative or NaN tv, which also skips the neighbour reads, still
  // keeps d up to date.
  if(!(tv > 0)) d[i] = di;
}


/// HIO update for element i, additionally biasing the real part to be positive.
///
/// Same as HIO, except that the outside-support feedback
/// dold[i] - beta * d[i] is also applied inside the support wherever the
/// (possibly smoothed) value has a negative real part.
///
/// The result is always stored in dold[i]. It is written back to d[i] only when
/// get_di_tv read no neighbours (i.e. unless tv > 0). When tv > 0 other
/// work-items still need the current d[] values. The caller presumably copies
/// dold into d afterwards (TODO confirm).
///
/// \param i        linear index of the element
/// \param d        current complex estimate
/// \param dold     previous iterate; overwritten with the new value
/// \param support  non-zero inside the support, 0 outside
/// \param beta     HIO feedback parameter
/// \param nx,ny,nz array dimensions (see get_di_tv)
/// \param tv       smoothing weight passed to get_di_tv
void HIO_real_pos(const int i, __global float2 *d, __global float2* dold, __global char *support,
                  const float beta, const int nx, const int ny, const int nz, const float tv)
{
  float2 di = get_di_tv(i,d,nx,ny,nz,tv);
  if((support[i]==0)||(di.x<0)) di = dold[i] - beta * di ;
  dold[i] = di;
  // Complement of get_di_tv's neighbour-read condition (tv > 0). Negative or
  // NaN tv reads no neighbours, so d can safely be updated in place.
  if(!(tv > 0)) d[i] = di;
}


/// Error reduction (ER) update for element i.
///
/// The value d[i], optionally smoothed by get_di_tv, is kept inside the support
/// and set to zero outside it. The result goes to dnew[i] and d is never
/// written. With tv > 0 other work-items read neighbouring d[] values, so dnew
/// presumably must not alias d in that case.
///
/// \param i        linear index of the element
/// \param d        input complex array (read only)
/// \param dnew     output complex array
/// \param support  non-zero inside the support, 0 outside
/// \param nx,ny,nz array dimensions (see get_di_tv)
/// \param tv       smoothing weight passed to get_di_tv
void ER(const int i, __global float2 *d, __global float2 *dnew, __global char *support,
        const int nx, const int ny, const int nz, const float tv)
{
  const float2 v = get_di_tv(i,d,nx,ny,nz,tv);
  if(support[i] != 0)
  {
    dnew[i] = v;
  }
  else
  {
    dnew[i] = (float2)(0.0f, 0.0f);
  }
}

/// Error reduction update for element i, also forcing the real part to be
/// non-negative.
///
/// The value d[i], optionally smoothed by get_di_tv, is kept only when it lies
/// inside the support and its real part is not negative. Otherwise it is set
/// to zero. The result goes to dnew[i] and d is never written.
///
/// \param i        linear index of the element
/// \param d        input complex array (read only)
/// \param dnew     output complex array
/// \param support  non-zero inside the support, 0 outside
/// \param nx,ny,nz array dimensions (see get_di_tv)
/// \param tv       smoothing weight passed to get_di_tv
void ER_real_pos(const int i, __global float2 *d, __global float2 *dnew, __global char *support,
                 const int nx, const int ny, const int nz, const float tv)
{
  const float2 v = get_di_tv(i,d,nx,ny,nz,tv);
  // Written as !(v.x < 0) rather than v.x >= 0 so that a NaN real part is kept.
  const int keep = (support[i] != 0) && !(v.x < 0);
  if(keep)
  {
    dnew[i] = v;
  }
  else
  {
    dnew[i] = (float2)(0.0f, 0.0f);
  }
}

/// Charge flipping (CF) update for element i.
///
/// The value d[i], optionally smoothed by get_di_tv, is copied unchanged inside
/// the support. Outside the support the sign of its imaginary part is flipped.
/// The result is written to a separate array dnew, because with tv > 0 other
/// work-items still read neighbouring d[] values.
///
/// \param i        linear index of the element
/// \param d        input complex array (read only)
/// \param dnew     output complex array
/// \param support  non-zero inside the support, 0 outside
/// \param nx,ny,nz array dimensions (see get_di_tv)
/// \param tv       smoothing weight passed to get_di_tv
void CF(const int i, __global float2 *d, __global float2 *dnew, __global char *support,
        const int nx, const int ny, const int nz, const float tv)
{
  const float2 v = get_di_tv(i,d,nx,ny,nz,tv);
  const float im = (support[i] != 0) ? v.y : -v.y;
  dnew[i] = (float2)(v.x, im);
}

/// Charge flipping update for element i, additionally biasing the real part to
/// be positive.
///
/// The sign of the imaginary part of the (possibly smoothed) value is flipped
/// outside the support, and also inside it wherever the real part is negative.
/// The result is written to a separate array dnew, because with tv > 0 other
/// work-items still read neighbouring d[] values.
///
/// \param i        linear index of the element
/// \param d        input complex array (read only)
/// \param dnew     output complex array
/// \param support  non-zero inside the support, 0 outside
/// \param nx,ny,nz array dimensions (see get_di_tv)
/// \param tv       smoothing weight passed to get_di_tv
void CF_real_pos(const int i, __global float2 *d, __global float2 *dnew, __global char *support,
                 const int nx, const int ny, const int nz, const float tv)
{
  const float2 v = get_di_tv(i,d,nx,ny,nz,tv);
  const int flip = (support[i] == 0) || (v.x < 0);
  dnew[i] = (float2)(v.x, flip ? -v.y : v.y);
}

/// RAAR (relaxed averaged alternating reflections) update for element i.
///
/// Inside the support the new value is d[i], optionally smoothed by get_di_tv.
/// Outside the support it is (1 - 2*beta) * d[i] + beta * dold[i].
///
/// The result is always stored in dold[i]. It is written back to d[i] only when
/// get_di_tv read no neighbours (i.e. unless tv > 0). When tv > 0 other
/// work-items still need the current d[] values. The caller presumably copies
/// dold into d afterwards (TODO confirm).
///
/// \param i        linear index of the element
/// \param d        current complex estimate
/// \param dold     previous iterate; overwritten with the new value
/// \param support  non-zero inside the support, 0 outside
/// \param beta     RAAR relaxation parameter
/// \param nx,ny,nz array dimensions (see get_di_tv)
/// \param tv       smoothing weight passed to get_di_tv
void RAAR(const int i, __global float2 *d, __global float2* dold, __global char *support,
          const float beta, const int nx, const int ny, const int nz, const float tv)
{
  float2 di = get_di_tv(i,d,nx,ny,nz,tv);
  if(support[i]==0) di = (1 - 2 * beta) * di + beta * dold[i];
  dold[i] = di;
  // Complement of get_di_tv's neighbour-read condition (tv > 0). Negative or
  // NaN tv reads no neighbours, so d can safely be updated in place.
  if(!(tv > 0)) d[i] = di;
}

/// RAAR update for element i, additionally biasing the real part to be positive.
///
/// Same as RAAR, except that the outside-support update
/// (1 - 2*beta) * d[i] + beta * dold[i] is also applied inside the support
/// wherever the (possibly smoothed) value has a negative real part.
///
/// The result is always stored in dold[i]. It is written back to d[i] only when
/// get_di_tv read no neighbours (i.e. unless tv > 0). When tv > 0 other
/// work-items still need the current d[] values. The caller presumably copies
/// dold into d afterwards (TODO confirm).
///
/// \param i        linear index of the element
/// \param d        current complex estimate
/// \param dold     previous iterate; overwritten with the new value
/// \param support  non-zero inside the support, 0 outside
/// \param beta     RAAR relaxation parameter
/// \param nx,ny,nz array dimensions (see get_di_tv)
/// \param tv       smoothing weight passed to get_di_tv
void RAAR_real_pos(const int i, __global float2 *d, __global float2* dold, __global char *support,
                   const float beta, const int nx, const int ny, const int nz, const float tv)
{
  float2 di = get_di_tv(i,d,nx,ny,nz,tv);
  if((support[i]==0)||(di.x<0)) di = (1 - 2 * beta) * di + beta * dold[i];
  dold[i] = di;
  // Complement of get_di_tv's neighbour-read condition (tv > 0). Negative or
  // NaN tv reads no neighbours, so d can safely be updated in place.
  if(!(tv > 0)) d[i] = di;
}

/*
/// DM1
void DM1(const int i, __global float2 *d, __global float2* dold, __global char *support)
{
  const float2 v = d[i];
  dold[i] = v;
  if(support[i]==0) d[i] = -v;
  else d[i] = v;
}

/// DM1, biasing real part to be positive (needs to be checked)
void DM1_real_pos(const int i, __global float2 *d, __global float2* dold, __global char *support)
{
  const float2 v = d[i];
  dold[i] = v;
  if((support[i]==0)||(v.x<0)) d[i] = -v;
  else d[i] = v;
}

/// DM2
void DM2(const int i, __global float2 *d, __global float2* dold, __global char *support)
{
  const float2 vold = dold[i];
  if(support[i]==0) d[i] = vold - d[i];
  else d[i] = 2 * vold - d[i];
}

/// DM2, biasing real part to be positive (needs to be checked)
void DM2_real_pos(const int i, __global float2 *d, __global float2* dold, __global char *support)
{
  const float2 vold = dold[i];
  if((support[i]==0)||(vold.x<0)) d[i] = vold - d[i];
  else d[i] = 2 * vold - d[i];
}
*/