# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import numpy as np
import matplotlib.pyplot as plt
from scipy.misc import face, ascent
from pynx.pci import *
from pynx.pci.operator import *
from pynx.wavefront import Wavefront
from pynx.wavefront import operator as wop

# Create object: a stack of nb_view test images (scipy's `ascent`, then
# centred crops of the colour channels of `face`), used as thickness maps
# to build a complex transmission function for each view.
ny, nx = 512, 512
nb_view = 4


def _center_crop(img, ny, nx):
    """Return the central ny x nx region of the first two axes of img."""
    y0 = img.shape[0] // 2 - ny // 2
    x0 = img.shape[1] // 2 - nx // 2
    return img[y0:y0 + ny, x0:x0 + nx]


# Real-valued images: they are only used as thickness maps, so there is no
# reason to store them as complex numbers.
views = np.empty((nb_view, ny, nx), dtype=np.float32)
views[0] = _center_crop(ascent(), ny, nx)
im = face()
for i in range(1, nb_view):
    # Cycle through the colour channels of the face image
    views[i] = _center_crop(im, ny, nx)[..., i % 3]

delta = 1e-7
beta = 1e-9
wavelength = 1e-10
pixel_size = 1e-6

# Thickness: 1e-9 per grey level (presumably metres, like wavelength and
# pixel_size; assumes 8-bit images, i.e. up to ~255 nm).
thickness = 1e-9 * views
# Try less thick objects
thickness[2] *= 0.1
thickness[3] *= .01

mu = 4 * np.pi * beta / wavelength  # linear attenuation coefficient
k = 2 * np.pi / wavelength
# Transmission T = exp(i k (-delta + i beta) t) = exp(-i k delta t) * exp(-k beta t).
# Only the thickness map t enters the exponent (not thickness * image).
d = np.exp(1j * k * (-delta + 1j * beta) * thickness).astype(np.complex64)

# Create probe: uniform illumination with unit amplitude
probe = np.ones((1, 1, ny, nx), dtype=np.complex64)

# Create data: propagate the object exit wave to each detector distance
nb_dist = 5
iobs = np.empty((nb_dist, nb_view, ny, nx), dtype=np.float32)
dz = 0.1
# Detector distances dz, 2*dz, ..., nb_dist*dz. The parentheses are needed:
# `np.arange(nb_dist) + 1 * dz` would add dz only once (0.1, 1.1, 2.1, ...).
vz = (np.arange(nb_dist) + 1) * dz
for i, z in enumerate(vz):
    # Build a fresh wavefront for every distance, so that frame i is the object
    # exit wave propagated by exactly vz[i], regardless of whether the operator
    # modifies `w` in place.
    # Use the same wavelength that is later given to PCIData; a different
    # value would make the simulated patterns inconsistent with the model.
    # NOTE(review): d is fft-shifted before being wrapped in a Wavefront;
    # confirm that w.get() returns data in the frame PCIData expects.
    w = Wavefront(d=np.fft.fftshift(d, axes=(-2, -1)), pixel_size=pixel_size, wavelength=wavelength)
    w = wop.PropagateNearField(float(z)) * w
    # Detectors record intensity, i.e. the squared modulus of the wavefield
    iobs[i] = np.abs(w.get()) ** 2

# Optionally display the last propagated wavefront:
# w = wop.ImshowAbs() * w

# Create PCIData from the simulated frames, one per detector distance in vz
data = PCIData(iobs, pixel_size_detector=pixel_size, wavelength=wavelength, detector_distance=vz, mask=None)

# Create PCI object. The object array gets a leading axis of size nb_mode
# (presumably the mode axis), matching the 4D (1, 1, ny, nx) probe layout.
nb_mode = 1
d = np.reshape(d, (nb_mode, nb_view, ny, nx))
pci = PCI(data=data, obj=d, probe=probe)

# Display the starting object (ShowObj with default arguments).
pci = ShowObj() * pci

# Open a new figure, then apply the operator chain below.
# NOTE(review): chains like `A * B * pci` are presumably applied right to left
# (ObjProbe2Psi first, then PropagateNearField, then ShowPsi); confirm against
# the PyNX operator documentation. fig_num=-1 presumably reuses the figure just
# created instead of opening another one.
plt.figure(figsize=(10, 8))
pci = ShowPsi(i=None, iz=None, fig_num=-1) * PropagateNearField() * ObjProbe2Psi() * pci

# Presumably replaces the observed intensities with the calculated ones.
# TODO confirm Calc2Obs semantics in pynx.pci.operator.
pci = Calc2Obs() * pci

# Paganin phase retrieval from the first detector distance (iz=0), using the
# same delta/beta as the simulated material, then display the object phase.
pci = BackPropagatePaganin(iz=0, delta=delta, beta=beta, normalize_empty_beam=False) * pci
pci = ShowObj(i=None, type='phase') * pci
