#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2018-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

"""
This file includes tests for the CDI python API.
"""

import os
import shutil
import tempfile
import unittest

import numpy as np

from pynx.utils.pattern import siemens_star, fibonacci_urchin
from pynx.cdi import save_cdi_data_cxi


def make_cdi_data(shape=(128, 128, 128), obj_shape='rectangle', file_type='cxi', nb_photons=1e9, dir=None):
    """
    Create a simulated CDI data file from a simple object.

    The object is built in direct space. Its diffracted intensity is computed as the squared modulus of its
    discrete Fourier transform, fft-shifted so that the zero frequency is at the centre of the array. The
    intensity is then scaled so that it sums to nb_photons. No noise is added.

    :param shape: the shape of the data array, either 2D (ny, nx) or 3D (nz, ny, nx).
    :param obj_shape: the object shape:
                      - 'rectangle' (the default; any unrecognised value also gives a rectangle), with a
                        lateral size of ~1/4 of the array shape along each axis;
                      - 'circle' or 'sphere': radius 1/8 of the smallest array dimension (a disk in 2D, a
                        sphere in 3D);
                      - 'star': a Siemens star in 2D, a Fibonacci urchin in 3D.
    :param file_type: either 'cxi' or 'npz' (any other value is treated as 'npz'). For 'npz' the intensity
                      is stored under the key 'd'.
    :param nb_photons: the total number of photons in the data array
    :param dir: the directory where the file will be created (passed to tempfile.mkstemp; None means the
                default temporary directory)
    :return: the file name. The caller is responsible for deleting the file.
    """
    ndim = len(shape)
    assert (ndim in [2, 3])
    if ndim == 2:
        ny, nx = shape
        y, x = np.meshgrid(np.arange(ny) - ny // 2, np.arange(nx) - nx // 2, indexing='ij')
        # Constant z so that the 3D expressions below also work in 2D
        z = 0
        nz = 1
    else:
        nz, ny, nx = shape
        z, y, x = np.meshgrid(np.arange(nz) - nz // 2, np.arange(ny) - ny // 2, np.arange(nx) - nx // 2,
                              indexing='ij')

    if obj_shape == 'star':
        if ndim == 2:
            nxy = min(nx, ny)
            a = siemens_star(dsize=nxy, nb_rays=7, r_max=nxy / 4, nb_rings=3)
        else:
            nxy = min(nx, ny, nz)
            a = fibonacci_urchin(dsize=nxy, nb_rays=20, r_max=nxy / 4, nb_rings=8)
        d = np.zeros(x.shape)
        # Paste the pattern at the centre of the array. The slices are computed from the pattern's actual
        # shape, so an odd pattern size does not cause a shape mismatch on assignment.
        d[tuple(slice(n // 2 - s // 2, n // 2 - s // 2 + s) for n, s in zip(d.shape, a.shape))] = a
    elif obj_shape in ['circle', 'sphere']:
        # Radius is 1/8 of the smallest array dimension, i.e. a diameter of ~1/4 of the array size
        r = min(x.shape) / 8
        d = np.sqrt(x ** 2 + y ** 2 + z ** 2) <= r
    else:
        # 'rectangle'
        d = (abs(x) <= (nx // 8)) * (abs(y) <= (ny // 8)) * (abs(z) <= (nz // 8))

    # Diffracted intensity |FT(object)|^2, with the zero frequency moved to the array centre
    d = np.fft.fftshift(np.abs(np.fft.fftn(d.astype(np.complex64)))) ** 2
    d *= nb_photons / d.sum()

    suffix = '.cxi' if file_type == 'cxi' else '.npz'
    fd, path = tempfile.mkstemp(suffix=suffix, dir=dir)
    # mkstemp returns an open OS-level descriptor; the data is written by file name below,
    # so close the descriptor immediately to avoid leaking it.
    os.close(fd)

    if file_type == 'cxi':
        save_cdi_data_cxi(path, d, wavelength=1.5e-10, detector_distance=1, pixel_size_detector=55e-6, mask=None,
                          sample_name=None, experiment_id=None, instrument=None, note=None, iobs_is_fft_shifted=False)
    else:
        # npz
        np.savez_compressed(path, d=d)

    return path


class TestCDI(unittest.TestCase):
    """Tests for the creation of simulated CDI data files with make_cdi_data()."""

    @classmethod
    def setUpClass(cls):
        # One temporary directory holds every file created by the tests of this class
        cls.tmp_dir = tempfile.mkdtemp()

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary directory together with all files created in it
        shutil.rmtree(cls.tmp_dir)

    def test_make_cdi_cxi(self):
        path = make_cdi_data(file_type='cxi', dir=self.tmp_dir)
        self.assertTrue(path.endswith('.cxi'))
        self.assertTrue(os.path.isfile(path))

    def test_make_cdi_npz(self):
        path = make_cdi_data(file_type='npz', dir=self.tmp_dir)
        self.assertTrue(path.endswith('.npz'))
        self.assertTrue(os.path.isfile(path))
        # make_cdi_data stores the intensity under the key 'd', with the default (128, 128, 128) shape
        with np.load(path) as data:
            self.assertEqual(data['d'].shape, (128, 128, 128))


def suite():
    """Return a unittest.TestSuite wrapping all the tests of the TestCDI class."""
    loader = unittest.defaultTestLoader
    return unittest.TestSuite([loader.loadTestsFromTestCase(TestCDI)])


if __name__ == '__main__':
    # When executed as a script, run the tests gathered by suite() from this module
    unittest.main(module='__main__', defaultTest='suite')
