#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2018-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

"""
This file includes tests for the Ptycho python API.
"""

import io
import os
import shutil
import sys
import tempfile
import unittest
import warnings


def make_ptycho_data(dsize=256, nb_frame=100, nb_photons=1e9, dir_name=None):
    """
    Simulate a ptychography dataset and save it to a new temporary CXI file.

    The simulation uses a phase/amplitude object, a defocused gaussian probe,
    a spiral scan and poisson noise.

    :param dsize: size in pixels of the (square) frames, also used as the probe shape
    :param nb_frame: number of scan positions along the spiral scan
    :param nb_photons: passed to the simulation as the maximum number of photons ('num_phot_max')
    :param dir_name: directory in which the CXI file is created. If None, the default
        temporary directory is used (tempfile.mkstemp behaviour).
    :return: the path to the CXI file. The caller is responsible for deleting it.
    """
    from pynx.ptycho import simulation
    from pynx.ptycho import save_ptycho_data_cxi
    import numpy as np
    pixel_size_detector = 55e-6
    wavelength = 1.5e-10
    detector_distance = 1
    obj_info = {'type': 'phase_ampl', 'phase_stretch': np.pi / 2, 'alpha_win': .2}
    probe_info = {'type': 'gauss', 'sigma_pix': (40, 40), 'defocus': 100e-6, 'shape': (dsize, dsize)}

    # 50 scan positions correspond to 4 turns, 78 to 5 turns, 113 to 6 turns
    scan_info = {'type': 'spiral', 'scan_step_pix': 30, 'n_scans': nb_frame}
    data_info = {'num_phot_max': nb_photons, 'bg': 0, 'wavelength': wavelength, 'detector_distance': detector_distance,
                 'detector_pixel_size': pixel_size_detector, 'noise': 'poisson'}

    # Initialisation of the simulation with specified parameters
    s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info, data_info=data_info,
                              verbose=False)
    s.make_data()

    # Positions from simulation are given in pixels. Convert them using the
    # object-space pixel size wavelength * distance / (detector_pixel * dsize).
    x, y = s.scan.values
    px = wavelength * detector_distance / pixel_size_detector / dsize

    # The simulation stores amplitudes; the CXI file stores intensities
    iobs = s.amplitude.values ** 2

    # mkstemp returns an already-open OS-level file descriptor. Only the unique
    # path is needed here, so close the descriptor immediately. This avoids
    # leaking it and avoids holding the file open while it is rewritten below.
    fd, path = tempfile.mkstemp(suffix='.cxi', dir=dir_name)
    os.close(fd)
    save_ptycho_data_cxi(path, iobs, pixel_size_detector, wavelength, detector_distance, x * px, y * px, z=None,
                         monitor=None, mask=None, instrument="Simulation", overwrite=True)
    return path


class TestPtycho(unittest.TestCase):
    """Tests for simulated ptychography data generation and CXI export."""

    @classmethod
    def setUpClass(cls):
        # One scratch directory shared by all tests of this class, for generated files
        cls.tmp_dir = tempfile.mkdtemp()

    @classmethod
    def tearDownClass(cls):
        # Remove the scratch directory together with every file written into it
        shutil.rmtree(cls.tmp_dir)

    def test_make_ptycho_cxi(self):
        """Simulated data must be written to an existing, non-empty CXI file."""
        path = make_ptycho_data(dir_name=self.tmp_dir)
        self.assertTrue(os.path.isfile(path), "CXI file was not created: %s" % path)
        self.assertGreater(os.path.getsize(path), 0, "CXI file is empty: %s" % path)


def suite():
    """Return a unittest.TestSuite holding all test cases of TestPtycho."""
    loader = unittest.defaultTestLoader
    collected = unittest.TestSuite()
    ptycho_tests = loader.loadTestsFromTestCase(TestPtycho)
    collected.addTest(ptycho_tests)
    return collected


if __name__ == '__main__':
    # Discard anything printed on stdout during the tests. TextTestRunner writes
    # its report to sys.stderr by default, so results remain visible.
    sys.stdout = io.StringIO()
    warnings.simplefilter('ignore')
    res = unittest.TextTestRunner(verbosity=2, descriptions=False).run(suite())
    # Report the outcome through the exit status so callers (shell, CI) can detect failures
    sys.exit(0 if res.wasSuccessful() else 1)
