# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import warnings
import numpy as np

# Optional CUDA support: pycuda and scikit-cuda may not be installed. The flag `has_cuda` records the
# outcome, and is always defined so that it can be tested without risking a NameError.
try:
    import pycuda.driver as cu_drv

    # The CUDA driver API must be initialised before any other driver call
    cu_drv.init()
    import pycuda.gpuarray as gpuarray
    from pycuda.tools import make_default_context
    from pycuda.driver import Context
    import skcuda.fft as cu_fft

    has_cuda = True
except ImportError:
    # CUDA libraries are not available: keep the module importable and report it through the flag
    has_cuda = False

from . import ProcessingUnit

from pynx.utils.math import test_smaller_primes

# Registry of the CUDA contexts already created, kept to avoid creating new ones.
# Layout: {device: {context: reference count}}.
# The reference count is incremented by CUProcessingUnit.init_cuda() each time it re-uses a context, and
# decremented by CUProcessingUnit.__del__(), which removes and pops the context once the count reaches zero.
# TODO: keep a more sophisticated list of contexts for different tasks (processing, copy in/out, etc...)
cu_device_dict = {}

class CUProcessingUnit(ProcessingUnit):
    """
    Processing unit in CUDA space.

    Handles initializing the context and fft plan. Kernel initialization must be done in derived classes.

    CUDA contexts are shared between instances through the module-level registry `cu_device_dict`
    ({device: {context: reference count}}): init_cuda() re-uses a registered context when one exists for
    the selected device, and __del__() releases it when the last user goes away.
    """

    def __init__(self):
        super(CUProcessingUnit, self).__init__()
        self.cu_ctx = None  # CUDA context
        self.cu_options = None  # CUDA compile options
        self.cufft_plan = None  # cufft plan
        self.cufft_shape = None  # shape of the data the current cufft plan was created for
        self.cufft_batch = None  # True if the current cufft plan is a batched 2D FFT
        self.cu_arch = None

    def __del__(self):
        """
        Release this object's reference to its CUDA context. When the registry reference count
        reaches zero, the context is removed from the registry and popped.
        """
        # getattr(): __del__ may be called on an object for which __init__ did not complete.
        ctx = getattr(self, 'cu_ctx', None)
        if ctx is None:
            return
        self.cu_ctx = None

        # The registry may already be empty/cleared (e.g. during interpreter shutdown), and the context
        # may not be tracked at all: in both cases there is no reference count to update.
        contexts = cu_device_dict.get(getattr(self, 'cu_device', None)) if cu_device_dict else None
        if contexts is None or ctx not in contexts:
            return

        contexts[ctx] -= 1
        if contexts[ctx] <= 0:
            # Maybe should be using sys.getrefcount() instead of manual reference counting ?
            del contexts[ctx]
            # NOTE(review): pycuda's Context.pop() acts on the context at the top of the context stack -
            # confirm this is always the context being released here.
            ctx.pop()

    def init_cuda(self, cu_ctx=None, cu_device=None, fft_size=(1, 1024, 1024), batch=True, gpu_name=None, test_fft=True,
                  verbose=True):
        """
        Initialize the CUDA context, the compile options and the kernels.

        The device selection is delegated to use_cuda(). If a context is already registered in
        `cu_device_dict` for the selected device, it is re-used (and its reference count incremented),
        otherwise a new context is created on the device and registered.

        :param cu_ctx: pycuda.driver.Context. If none, a default context will be created
        :param cu_device: pycuda.driver.Device. If none, and no context is given, the fastest GPu will be used.
        :param fft_size: the fft size to be used, for benchmark purposes when selecting GPU. different fft sizes
                         can be used afterwards?
        :param batch: if True, will benchmark using a batch 2D FFT
        :param gpu_name: a (sub)string matching the name of the gpu to be used
        :param test_fft: if True, will benchmark the GPU(s)
        :param verbose: report the GPU found and their speed
        :return: nothing
        :raises AssertionError: if the last two dimensions of fft_size are not accepted by
                                test_smaller_primes() for this processing unit.
        """
        self.set_benchmark_fft_parameters(fft_size=fft_size, batch=batch)
        self.use_cuda(gpu_name=gpu_name, cu_ctx=cu_ctx, cu_device=cu_device, test_fft=test_fft, verbose=verbose)

        # Both FFT dimensions must pass the prime decomposition test and be multiples of 2
        assert test_smaller_primes(fft_size[-1], self.max_prime_fft, required_dividers=(2,)) \
               and test_smaller_primes(fft_size[-2], self.max_prime_fft, required_dividers=(2,)), \
            "init_cuda(): fft_size=%s is not supported (max_prime_fft=%s, required divider: 2)" \
            % (str(fft_size), str(self.max_prime_fft))

        # Contexts already created for this device (the registry entry is created if needed)
        contexts = cu_device_dict.setdefault(self.cu_device, {})
        if contexts:
            # Share the first registered context rather than creating a new one
            self.cu_ctx = next(iter(contexts))
            contexts[self.cu_ctx] += 1

        if self.cu_ctx is None:
            self.cu_ctx = self.cu_device.make_context()
            contexts[self.cu_ctx] = 1

        self.cu_options = ["-use_fast_math"]
        # TODO: KLUDGE. Add a workaround if the card (Pascal) is more recent than the driver...
        # if cu_drv.get_version() == (7,5,0) and self.cu_device.compute_capability() == (6,1):
        #    print("WARNING: using a Pascal card (compute capability 6.1) with driver .5: forcing arch=sm_52")
        #    self.cu_arch = 'sm_50'
        self.cu_init_kernels()

    def cu_init_kernels(self):
        """
        Initialize kernels. Virtual function, must be derived.

        :return: nothing
        """

    def cu_fft_set_plan(self, cu_data, batch=True):
        """
        Creates FFT plan, or updates it if the shape of the data or the batch mode have changed. Nothing
        is done if a plan already exists for the same data shape and batch mode.

        :param cu_data: an array from which the FFT shape will be extracted
        :param batch: if True, perform a 2D batch FFT on an n-dimensional array (n>2). Ignored if data is 2D.
                      The FFT is computed over the last two dimensions.
        :return: nothing
        """
        shape = cu_data.shape
        if len(shape) == 2:
            batch = False

        if self.cufft_plan is not None and shape == self.cufft_shape and batch == self.cufft_batch:
            # The existing plan already matches
            return

        if batch:
            # np.prod replaces np.product, which was removed in NumPy 2.0. int() is needed because
            # the product of an empty tuple is the float 1.0.
            n_batch = int(np.prod(shape[:-2]))
            self.cufft_plan = cu_fft.Plan(shape[-2:], np.complex64, np.complex64, batch=n_batch)
        else:
            self.cufft_plan = cu_fft.Plan(shape, np.complex64, np.complex64, batch=1)

        # Remember the parameters of the plan to avoid re-creating it needlessly
        self.cufft_shape = shape
        self.cufft_batch = batch

    def finish(self):
        """
        Calls synchronize() on the CUDA context. init_cuda() must have been called before, as the
        context is None until then.

        :return: nothing
        """
        self.cu_ctx.synchronize()
