# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2018-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr


import socket
import time
import timeit
import sys
import sqlite3
import numpy as np
import matplotlib.pyplot as plt
from pynx.processing_unit import has_cuda, has_opencl, default_processing_unit

if has_opencl:
    from pynx.processing_unit import opencl_device
if has_cuda:
    from pynx.processing_unit import cuda_device
from pynx.scattering.test import mrats


class SpeedTest(object):
    """
    Class for speed tests using either GPU or CPU.

    Results of the individual tests are accumulated in the ``self.results`` dictionary
    (one key per measured quantity). They can be exported to, or imported from, an
    sqlite3 database (table ``pynx_speed_test``, one column per result key).
    """

    def __init__(self, gpu_name, language, cl_platform=None):
        """
        Select the device to test and initialise the result dictionary.

        :param gpu_name: the gpu name to be tested, or 'CPU'. Stored lower-cased ('' if None).
        :param language: either 'cuda', 'opencl' or 'CPU'. Stored lower-cased ('' if None).
        :param cl_platform: the opencl platform, when using OpenCL. Optional.
        """
        self.gpu_name = gpu_name.lower() if gpu_name is not None else ''
        self.language = language.lower() if language is not None else ''
        self.cl_platform = cl_platform
        self.results = {}
        self.results['hostname'] = socket.gethostname()
        self.results['epoch'] = time.time()
        self.results['language'] = self.language
        if self.language == 'cuda':
            default_processing_unit.use_cuda(gpu_name=gpu_name, test_fft=False, verbose=False)
            self.results['GPU'] = default_processing_unit.cu_device.name()

        if self.language == 'opencl':
            default_processing_unit.use_opencl(gpu_name=gpu_name, platform=cl_platform, test_fft=False, verbose=False)
            # The OpenCL platform name is appended in brackets to the device name
            self.results['GPU'] = '%s [%s]' % (default_processing_unit.cl_device.name,
                                               default_processing_unit.cl_device.platform.name)
        # TODO: handle case where language is not given, CPU
        self.db_results = {}
        self.db_conn = None
        self.db_curs = None

    @staticmethod
    def _quote_identifier(name):
        """Quote a column name for use in an SQL statement (SQL double-quote escaping)."""
        return '"%s"' % str(name).replace('"', '""')

    @staticmethod
    def _to_sql_value(v):
        """Convert numpy scalars to native python values so that sqlite3 can bind them."""
        return v.item() if isinstance(v, np.generic) else v

    def prepare_db(self, db_name="pynx_speed.db"):
        """
        Open (creating it if necessary) the database file, and make sure the ``pynx_speed_test``
        table has one column for each key of self.results.

        Any connection previously opened by this object is closed first.

        :param db_name: the name of the sqlite3 database file
        :return: nothing. The connection and cursor are stored in self.db_conn and self.db_curs
        """
        if getattr(self, 'db_conn', None) is not None:
            # Avoid leaking the previous connection when called repeatedly (e.g. run_all then import_db)
            self.db_conn.close()
        self.db_conn = sqlite3.connect(db_name)
        self.db_curs = self.db_conn.cursor()

        self.db_curs.execute('''CREATE TABLE IF NOT EXISTS pynx_speed_test
                     (epoch real, hostname text, language text, GPU text)''')

        # SQLite column names are case-insensitive, so compare lower-cased names
        existing = {row[1].lower() for row in self.db_curs.execute('PRAGMA table_info(pynx_speed_test)')}
        for k in self.results:
            if k.lower() not in existing:
                self.db_curs.execute('ALTER TABLE pynx_speed_test ADD COLUMN %s;' % self._quote_identifier(k))
                existing.add(k.lower())
        self.db_conn.commit()

    def test_scattering(self, size):
        """
        Test using pynx.scattering.test.mrats.

        :param size: the number of atoms = number of reflections
        :return: the elapsed time dt (0 if the test failed). The Gflop/s and dt values are added
                 to self.results (Gflop/s is -1 if the test failed).
        """
        try:
            gflops, dt = mrats(size, size, gpu_name=self.gpu_name, verbose=True, language=self.language,
                               cl_platform=self.cl_platform, timing=True)
            # NOTE(review): scale factor applied to the value returned by mrats; its meaning is not
            # visible in this file - confirm against mrats' definition of its returned rate.
            gflops *= 8e-3
        except Exception:
            # Best-effort: a failing test is recorded rather than aborting the whole run
            gflops, dt = -1, 0
        self.results['scattering_%d_Gflops' % size] = gflops
        self.results['scattering_%d_dt' % size] = dt
        return dt

    def _fft_speed(self, fft_shape, axes, batch):
        """
        Run the GPU FFT speed test for the selected language and return the best result.

        :param fft_shape: the shape of the array to transform
        :param axes: the FFT axes (only passed to the OpenCL test)
        :param batch: whether the FFT is batched (only passed to the CUDA test)
        :return: (gflops, dt) for the fastest device reported, or (-1, 0) if the test failed
                 or the language is neither OpenCL nor CUDA.
        """
        try:
            if 'opencl' in self.language:
                res = opencl_device.available_gpu_speed(gpu_name=self.gpu_name, cl_platform=self.cl_platform,
                                                        fft_shape=fft_shape, axes=axes, verbose=False,
                                                        return_dict=True)
            elif 'cuda' in self.language:
                res = cuda_device.available_gpu_speed(gpu_name=self.gpu_name, fft_shape=fft_shape, batch=batch,
                                                      verbose=False, return_dict=True)
            else:
                # TODO: test CPU speed
                return -1, 0
            # Get the best results (should be just 1, unless multiple GPU are present)
            gflops, dt = 0, 0
            for v in res.values():
                if v['Gflop/s'] > gflops:
                    gflops = v['Gflop/s']
                    dt = v['dt']
        except Exception:
            # Best-effort: any failure (e.g. the backend module was not imported) is recorded as -1 Gflop/s
            return -1, 0
        return gflops, dt

    def test_fft_2d(self, size):
        """
        Test a stacked (16) 2D FFT.

        :param size: the size N of the 2D FFT, which will be computed using a 16xNxN array
        :return: the elapsed time dt (0 if the test failed). The Gflop/s and dt values are added
                 to self.results (Gflop/s is -1 if the test failed).
        """
        gflops, dt = self._fft_speed(fft_shape=(16, size, size), axes=(-1, -2), batch=True)
        self.results['fft_2Dx16_%d_Gflops' % size] = gflops
        self.results['fft_2Dx16_%d_dt' % size] = dt
        print('fft_2Dx16_%d: %8.2f Gflop/s, dt =%6.4fs' % (size, gflops, dt))
        return dt

    def test_fft_3d(self, size):
        """
        Test a 3D FFT.

        :param size: the size N of the 3D FFT, which will be computed using a NxNxN array
        :return: the elapsed time dt (0 if the test failed). The Gflop/s and dt values are added
                 to self.results (Gflop/s is -1 if the test failed).
        """
        gflops, dt = self._fft_speed(fft_shape=(size, size, size), axes=None, batch=False)
        self.results['fft_3D_%d_Gflops' % size] = gflops
        self.results['fft_3D_%d_dt' % size] = dt
        print('fft_3D_%d: %8.2f Gflop/s, dt =%6.4fs' % (size, gflops, dt))
        return dt

    def test_ptycho(self, nb_frame, frame_size, nb_cycle=20, algo="AP"):
        """
        Run a 2D ptychography speed test on simulated data.

        :param nb_frame: the number of frames (spiral scan positions) to simulate
        :param frame_size: the size N of the NxN frames (and probe)
        :param nb_cycle: the number of timed cycles. 5 untimed cycles are run first.
        :param algo: the algorithm: "AP" (default), "DM" or "ML" (case-insensitive).
                     Any other value falls back to "AP".
        :return: the total execution time of the nb_cycle timed cycles (not counting initialisation).
                 The Gflop/s estimate and the time per cycle are added to self.results.
        """
        # TODO: move this elsewhere, to avoid imports inside the function
        from pynx.ptycho import simulation, shape
        if 'opencl' in self.language:
            from pynx.ptycho import cl_operator as ops
            ops.default_processing_unit.select_gpu(gpu_name=self.gpu_name, language='opencl', verbose=False)
        elif 'cuda' in self.language:
            from pynx.ptycho import cu_operator as ops
            ops.default_processing_unit.select_gpu(gpu_name=self.gpu_name, language='cuda', verbose=False)
        elif 'cpu' in self.language:
            from pynx.ptycho import cpu_operator as ops
        else:
            from pynx.ptycho import operator as ops
            ops.default_processing_unit.select_gpu(gpu_name=self.gpu_name, verbose=False)
        from pynx.ptycho.ptycho import Ptycho, PtychoData

        n = frame_size
        pixel_size_detector = 55e-6
        wavelength = 1.5e-10
        detector_distance = 1
        obj_info = {'type': 'phase_ampl', 'phase_stretch': np.pi / 2, 'alpha_win': .2}
        probe_info = {'type': 'gauss', 'sigma_pix': (40, 40), 'shape': (n, n)}

        # 50 scan positions correspond to 4 turns, 78 to 5 turns, 113 to 6 turns
        scan_info = {'type': 'spiral', 'scan_step_pix': 30, 'n_scans': nb_frame}
        data_info = {'num_phot_max': 1e9, 'bg': 0, 'wavelength': wavelength, 'detector_distance': detector_distance,
                     'detector_pixel_size': pixel_size_detector, 'noise': 'poisson'}

        # Initialisation of the simulation with specified parameters
        s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info, data_info=data_info,
                                  verbose=False)
        s.make_data()

        # Positions from simulation are given in pixels
        posy, posx = s.scan.values

        ampl = s.amplitude.values  # square root of the measured diffraction pattern intensity
        pixel_size_object = wavelength * detector_distance / pixel_size_detector / n
        data = PtychoData(iobs=ampl ** 2, positions=(posy * pixel_size_object, posx * pixel_size_object),
                          detector_distance=detector_distance, mask=None, pixel_size_detector=pixel_size_detector,
                          wavelength=wavelength)

        p = Ptycho(probe=s.probe.values, obj=s.obj.values, data=data, background=None)

        if algo.lower() == "dm":
            op = ops.DM()
        elif algo.lower() == "ml":
            op = ops.ML()
        else:
            algo = "AP"
            op = ops.AP()

        # Warm-up cycles, not timed
        p = op ** 5 * p
        ops.default_processing_unit.finish()

        t0 = timeit.default_timer()
        p = op ** nb_cycle * p
        ops.default_processing_unit.finish()
        dt = timeit.default_timer() - t0
        # 5 * n * n is a (very) conservative estimate of the number of elementwise operations.
        gflops = nb_cycle * nb_frame * (2 * 5 * n * n * np.log2(n * n) + n * n * 5) / 1e9 / dt

        self.results['ptycho_%dx%dx%d_%s_Gflops' % (nb_frame, n, n, algo)] = gflops
        self.results['ptycho_%dx%dx%d_%s_dt' % (nb_frame, n, n, algo)] = dt / nb_cycle
        print('ptycho_%dx%dx%d_%s: %8.2f Gflop/s, dt =%6.4fs/cycle' % (nb_frame, n, n, algo, gflops, dt / nb_cycle))

        return dt

    def test_cdi3d(self, size=128, nb_cycle=20, algo="ER"):
        """
        Run a 3D CDI speed test on a simulated parallelepiped object.

        :param size: the size of the object along one dimension
        :param nb_cycle: the number of timed cycles to execute. 5 untimed cycles are run first.
        :param algo: the algorithm to use (ER by default, can also be HIO or RAAR, case-insensitive).
                     Any other value falls back to "ER".
        :return: the total execution time of the nb_cycle timed cycles (not counting initialisation).
                 The Gflop/s estimate and the time per cycle are added to self.results.
        """
        from pynx.cdi.cdi import CDI
        if 'opencl' in self.language:
            from pynx.cdi import cl_operator as ops
            ops.default_processing_unit.select_gpu(gpu_name=self.gpu_name, language='opencl', verbose=False)
        elif 'cuda' in self.language:
            from pynx.cdi import cu_operator as ops
            ops.default_processing_unit.select_gpu(gpu_name=self.gpu_name, language='cuda', verbose=False)
        elif 'cpu' in self.language:
            from pynx.cdi import cpu_operator as ops
        else:
            from pynx.cdi import operator as ops
            ops.default_processing_unit.select_gpu(gpu_name=self.gpu_name, verbose=False)

        n = size

        # Object coordinates
        tmp = np.arange(-n // 2, n // 2, dtype=np.float32)
        z, y, x = np.meshgrid(tmp, tmp, tmp, indexing='ij')

        # Parallelepiped object
        obj0 = (abs(x) < 12) * (abs(y) < 10) * (abs(z) < 16)
        # Start from a slightly loose support
        support = (abs(x) < 20) * (abs(y) < 20) * (abs(z) < 25)

        cdi = CDI(np.zeros_like(obj0), obj=obj0, support=np.fft.fftshift(support), mask=None, wavelength=1e-10,
                  pixel_size_detector=55e-6)

        # Compute the simulated observed intensity from the object
        cdi = ops.Calc2Obs() * cdi

        if algo.lower() == "hio":
            op = ops.HIO()
        elif algo.lower() == "raar":
            op = ops.RAAR()
        else:
            algo = "ER"
            op = ops.ER(calc_llk=False)

        # Warm-up cycles, not timed
        cdi = op ** 5 * cdi
        ops.default_processing_unit.finish()

        t0 = timeit.default_timer()
        cdi = op ** nb_cycle * cdi
        ops.default_processing_unit.finish()
        dt = timeit.default_timer() - t0
        # 5 * n**3 is a (very) conservative estimate of the number of elementwise operations.
        gflops = nb_cycle * (2 * 5 * n ** 3 * np.log2(n ** 3) + n * n * n * 5) / 1e9 / dt

        # Store the time per cycle, consistent with test_ptycho and with the printed value
        self.results['cdi3d_%dx%dx%d_%s_Gflops' % (n, n, n, algo)] = gflops
        self.results['cdi3d_%dx%dx%d_%s_dt' % (n, n, n, algo)] = dt / nb_cycle
        print('cdi3d_%dx%dx%d_%s: %8.2f Gflop/s, dt =%6.4fs/cycle' % (n, n, n, algo, gflops, dt / nb_cycle))

        return dt

    def run_all(self, export_db=None, verbose=True):
        """
        Run the scattering and FFT speed tests, and optionally save the results to a database (sqlite3 file).

        For each test, the largest sizes are only run if the previous size took less than 1 s.

        :param export_db: the name of the database to save the speed tests to. The file will be created if
                          necessary. If a true value which is not a string is given, "pynx_speed.db" is used.
        :param verbose: currently unused (kept for API compatibility)
        :return: nothing
        """
        self.test_scattering(int(2 ** 10))
        dt = self.test_scattering(int(2 ** 14))
        if dt < 1:
            dt = self.test_scattering(int(2 ** 18))
        if dt < 1:
            self.test_scattering(int(2 ** 20))
        self.test_fft_2d(size=256)
        dt = self.test_fft_2d(size=1024)
        if dt < 1:
            self.test_fft_2d(size=4096)
        self.test_fft_3d(size=128)
        dt = self.test_fft_3d(size=256)
        if dt < 1:
            self.test_fft_3d(size=512)
        if export_db:
            self.prepare_db(db_name=export_db if isinstance(export_db, str) else "pynx_speed.db")
            keys = list(self.results.keys())
            # Parameterized insert: values (e.g. host or GPU names) may contain quotes
            com = "INSERT INTO pynx_speed_test (%s) VALUES (%s);" % (
                ','.join(self._quote_identifier(k) for k in keys), ','.join('?' * len(keys)))
            self.db_curs.execute(com, [self._to_sql_value(self.results[k]) for k in keys])
            self.db_conn.commit()
            self.db_curs.close()
            self.db_conn.close()
            self.db_curs = None
            self.db_conn = None

    def import_db(self, unique=True, gpu_name=None, language=None, cl_platform=None, db_name="pynx_speed.db"):
        """
        Extract the results stored in the database into self.db_results.

        :param unique: if True, only the latest result will be kept for a given combination
                       of hostname + GPU + language. Otherwise the date is added to each entry name.
        :param gpu_name: name or partial name for the GPU which should be listed
        :param language: 'cuda' or 'opencl' or 'CPU' to filter results
        :param cl_platform: the opencl platform to plot. It is matched against a 'cl_platform' column
                            if present, otherwise against the GPU name (which for OpenCL results
                            includes the platform name in brackets).
        :param db_name: the name of the sqlite3 database file
        :return: nothing. self.db_results maps an entry name to a {column: value} dictionary.
        """
        self.db_results = {}
        self.prepare_db(db_name=db_name)
        self.db_curs.execute('select * from pynx_speed_test order by epoch')
        columns = [x[0] for x in self.db_curs.description]
        for row in self.db_curs.fetchall():
            d = dict(zip(columns, row))
            gpu = d.get('GPU') or ''
            lang = d.get('language') or ''
            name = '%s[%s]\n[%s]' % (d['GPU'], d['language'], d['hostname'])
            if not unique:
                name += '\n[%s]' % (time.strftime('%Y/%m/%d %H:%M:%S', time.gmtime(d['epoch'])))
            if gpu_name is not None and gpu_name.lower() not in gpu.lower():
                continue
            if language is not None and language.lower() not in lang.lower():
                continue
            if cl_platform is not None:
                platform = d.get('cl_platform') or gpu
                if cl_platform.lower() not in platform.lower():
                    continue
            # Rows are ordered by epoch, so with unique names the latest result overwrites older ones
            self.db_results[name] = d

    def plot(self, unique=True, gpu_name=None, language=None, cl_platform=None, db_name="pynx_speed.db"):
        """
        Plot all the results stored in the database, one bar chart per Gflops column.

        :param unique: if True, only the latest result will be plotted for a given combination
                       of hostname + GPU + language.
        :param gpu_name: name or partial name for the GPU which should be listed
        :param language: 'cuda' or 'opencl' or 'CPU' to filter results
        :param cl_platform: the opencl platform to plot
        :param db_name: the name of the sqlite3 database file
        :return: nothing
        """
        self.import_db(unique=unique, gpu_name=gpu_name, language=language, cl_platform=cl_platform,
                       db_name=db_name)
        if not self.db_results:
            print('No speed test results to plot')
            return
        names = list(self.db_results.keys())
        records = list(self.db_results.values())
        for t in list(records[0].keys()):
            if 'Gflops' not in t:
                continue
            t2 = t.split('Gflops')[0] + 'dt'
            # Missing (NULL) values are plotted as 0
            gflops = np.array([0 if v[t] is None else v[t] for v in records])
            dt = np.array([0 if v.get(t2) is None else v[t2] for v in records])
            print('Plotting: %s' % t)
            x = range(len(gflops))
            plt.figure(figsize=(12, 6))
            plt.bar(x, gflops)
            plt.xticks(x, names, rotation=90, horizontalalignment='center', verticalalignment='bottom')
            plt.ylabel(t)
            plt.ylim(10, gflops.max() * 1.05)
            ax1 = plt.gca()
            ax2 = ax1.twinx()
            ax2.set_yscale('log')
            imax = gflops.argmax()
            # Right (inverted) axis: time, taken as inversely proportional to Gflop/s and anchored
            # on the best result, so the bottom of the left axis (10 Gflop/s) maps to dt*Gflops/10
            dtmin = dt[imax] / 1.05
            dtmax = dt[imax] * gflops[imax] / 10
            ax2.set_ylim(dtmax, dtmin)
            ax2.set_ylabel('dt(s)')
            plt.show()


if __name__ == '__main__':
    # Command-line usage: [gpu=NAME] [language=cuda|opencl] [cl_platform=NAME] [plot]
    # Speed tests are only run (and saved to pynx_speed.db) when a GPU name is given.
    gpu_name = None
    language = None
    cl_platform = None
    do_plot = False
    for a in sys.argv[1:]:
        # Split on the first '=' only, so that values containing '=' are kept intact
        key, sep, value = a.partition('=')
        if sep:
            if key == 'gpu':
                gpu_name = value
            elif key == 'language':
                language = value
            elif key == 'cl_platform':
                cl_platform = value
        elif a.lower() == 'plot':
            do_plot = True
    s = SpeedTest(gpu_name, language=language, cl_platform=cl_platform)
    if gpu_name is not None:
        s.run_all(export_db='pynx_speed.db')
    if do_plot:
        s.plot(unique=True)
