#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2018-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

"""
This module contains tests for the CDI command-line scripts (pynx-id01cdi.py).
"""

import os
import sys
import subprocess
import unittest
import tempfile
import shutil
# import warnings
# from functools import wraps
from pynx.cdi.test.test_cdi import make_cdi_data

# Honour the PYNX_PU environment variable: when it names OpenCL the CUDA tests
# are excluded, otherwise when it names CUDA the OpenCL tests are excluded.
# OpenCL is checked first, so at most one of the two flags is ever True.
_pynx_pu = os.environ.get('PYNX_PU', '').lower()
exclude_cuda = 'opencl' in _pynx_pu
exclude_opencl = not exclude_cuda and 'cuda' in _pynx_pu
del _pynx_pu


# def ignore_warnings(func):
#     @wraps(func)
#     def inner(self, *args, **kwargs):
#         with warnings.catch_warnings():
#             warnings.simplefilter("ignore")
#             res = func(self, *args, **kwargs)
#         return res
#     return inner


class TestCDIRunner(unittest.TestCase):
    """
    Class for tests of the CDI runner scripts.

    Each test generates a synthetic CXI dataset with ``make_cdi_data`` inside a
    temporary directory shared by the whole class, then runs the
    ``pynx-id01cdi.py`` command-line script on it with the ``PYNX_PU``
    environment variable forced to either 'opencl' or 'cuda'. A test passes
    when the script exits with status 0; otherwise its stderr is reported.
    """

    # Set to True to keep the generated data/results after the tests have run
    # (useful for debugging); the directory path is then printed instead.
    keep_tmp_dir = False

    # Maximum wall-clock time (seconds) allowed for a single runner invocation.
    runner_timeout = 200

    @classmethod
    def setUpClass(cls):
        cls.tmp_dir = tempfile.mkdtemp()

    @classmethod
    def tearDownClass(cls):
        if cls.keep_tmp_dir:
            print('Leaving test data in:', cls.tmp_dir)
        else:
            shutil.rmtree(cls.tmp_dir)

    def _run_id01cdi(self, pu, shape, live_plot=False):
        """
        Generate CXI test data, run pynx-id01cdi.py on it and assert success.

        :param pu: processing unit name exported as PYNX_PU ('opencl' or 'cuda')
        :param shape: shape of the synthetic data, passed to make_cdi_data
        :param live_plot: if True, also pass the 'live_plot' keyword to the script
        """
        my_env = os.environ.copy()
        my_env["PYNX_PU"] = pu
        path = make_cdi_data(shape=shape, file_type='cxi', dir=self.tmp_dir)
        args = ['pynx-id01cdi.py', 'data=%s' % path]
        if live_plot:
            args.append('live_plot')
        # subprocess.run() kills the child when the timeout expires, so a hung
        # runner makes the test error out with TimeoutExpired. With the former
        # `with Popen(...)` + communicate(timeout=...), Popen.__exit__ would
        # wait() on the still-running child without any timeout.
        p = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                           cwd=self.tmp_dir, env=my_env, timeout=self.runner_timeout)
        # errors='replace': non-UTF-8 bytes in the output must not mask the failure
        self.assertEqual(p.returncode, 0, msg=p.stderr.decode(errors='replace'))

    # @ignore_warnings
    @unittest.skipIf('cuda' in sys.argv or exclude_opencl, "OpenCL tests excluded")
    def test_cdi_runner_id01_3d_cxi_opencl(self):
        self._run_id01cdi('opencl', (128, 128, 128))

    @unittest.skipIf('cuda' in sys.argv or exclude_opencl, "OpenCL tests excluded")
    @unittest.skipUnless('live_plot' in sys.argv or 'liveplot' in sys.argv, "live plot tests skipped")
    def test_cdi_runner_id01_3d_cxi_liveplot_opencl(self):
        self._run_id01cdi('opencl', (128, 128, 128), live_plot=True)

    @unittest.skipIf('opencl' in sys.argv or exclude_cuda, "CUDA tests excluded")
    def test_cdi_runner_id01_3d_cxi_cuda(self):
        self._run_id01cdi('cuda', (128, 128, 128))

    @unittest.skipIf('opencl' in sys.argv or exclude_cuda, "CUDA tests excluded")
    @unittest.skipUnless('live_plot' in sys.argv or 'liveplot' in sys.argv, "live plot tests skipped")
    def test_cdi_runner_id01_3d_cxi_liveplot_cuda(self):
        self._run_id01cdi('cuda', (128, 128, 128), live_plot=True)

    @unittest.skipIf('cuda' in sys.argv or exclude_opencl, "OpenCL tests excluded")
    def test_cdi_runner_id01_2d_cxi_opencl(self):
        self._run_id01cdi('opencl', (128, 128))

    @unittest.skipIf('cuda' in sys.argv or exclude_opencl, "OpenCL tests excluded")
    @unittest.skipUnless('live_plot' in sys.argv or 'liveplot' in sys.argv, "live plot tests skipped")
    def test_cdi_runner_id01_2d_cxi_liveplot_opencl(self):
        self._run_id01cdi('opencl', (128, 128), live_plot=True)

    @unittest.skipIf('opencl' in sys.argv or exclude_cuda, "CUDA tests excluded")
    def test_cdi_runner_id01_2d_cxi_cuda(self):
        self._run_id01cdi('cuda', (128, 128))

    @unittest.skipIf('opencl' in sys.argv or exclude_cuda, "CUDA tests excluded")
    @unittest.skipUnless('live_plot' in sys.argv or 'liveplot' in sys.argv, "live plot tests skipped")
    def test_cdi_runner_id01_2d_cxi_liveplot_cuda(self):
        self._run_id01cdi('cuda', (128, 128), live_plot=True)


def suite():
    """Return a TestSuite gathering every test method of TestCDIRunner."""
    loader = unittest.defaultTestLoader
    runner_cases = loader.loadTestsFromTestCase(TestCDIRunner)
    return unittest.TestSuite([runner_cases])


if __name__ == '__main__':
    # Run the suite directly and report the outcome through the exit status,
    # so shell scripts / CI invoking this file notice failing tests.
    res = unittest.TextTestRunner(verbosity=2, descriptions=False).run(suite())
    sys.exit(0 if res.wasSuccessful() else 1)
