# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2017-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

import numpy as np
import matplotlib.pyplot as plt
from .holotomo import HoloTomo, HoloTomoData, OperatorHoloTomo

__all__ = ['ShowObj', 'ShowPsi']

import numpy as np
import matplotlib.pyplot as plt

from ..utils.plot_utils import complex2rgbalin, complex2rgbalog, insertColorwheel, cm_phase
from .holotomo import OperatorHoloTomo


class ShowObj(OperatorHoloTomo):
    """
    Class to display a phase contrast object.
    """

    def __init__(self, fig_num=None, istack=0, i=None, type='Phase',
                 title=None, figsize=None, probe=False):
        """

        :param fig_num: the matplotlib figure number. If None, a new figure is created on the first call
                        and its number is stored, so that subsequent calls draw in the same figure.
                        If -1, the current figure is used.
        :param istack: index of the stack to display
        :param i: the index of the object to be displayed. If this is a list or array, all the listed views
                  will be shown. If None, all objects are shown.
        :param type: what to show (case-insensitive). Can be 'phase' (the default), 'amplitude', 'rgba'.
        :param title: the title for the view. If None, a default title will be used.
        :param figsize: if a new figure is created, this parameter will be passed to matplotlib
        :param probe: if True, also plot the probe
        """
        super(ShowObj, self).__init__()
        self.fig_num = fig_num
        self.istack = istack
        self.i = i
        self.type = type
        self.title = title
        self.figsize = figsize
        # Indices of the displayed images, inferred from istack and i.
        # Probe images (if probe=True) are coded with negative values -1, -2, ...
        self.idx = None
        self.probe = probe

    def pre_imshow(self, pci: HoloTomo):
        """
        Gather the array(s) to display and select the figure.

        :param pci: the HoloTomo object
        :return: (d, x, y, unit_name) where d is either a single 2D image or a 3D stack of images
                 (first axis = image index), x and y the coordinates scaled to unit_name,
                 and unit_name one of 'nm', 'µm', 'mm', 'm'.
        """
        pci._from_pu()
        n = pci.data.stack_size
        idx = pci.data.idx[n * self.istack: n * (self.istack + 1)]
        if self.type.lower() == 'phase':
            d = pci.get_obj_phase_unwrapped(idx=idx)[1]
        else:
            d = pci.data.stack_v[self.istack].obj
            d = d.reshape((d.shape[0], d.shape[-2], d.shape[-1]))

        # Selected views - a local variable so that self.i keeps its user-given value
        sel = list(range(len(d))) if self.i is None else self.i
        if isinstance(sel, (int, np.integer)):
            d = d[sel]
            self.idx = [idx[sel]]
        else:
            d = d.take(sel, axis=0)
            self.idx = idx.take(sel, axis=0)

        if self.probe:
            pr = pci.get_probe(crop_padding=True)[:, 0]
            if self.type.lower() == 'phase':
                pr = np.angle(pr)
            if d.ndim == 2:
                # A single view was selected: restore the stack axis so that
                # the object can be concatenated with the probe images
                d = d[np.newaxis]
            d = np.append(d, pr, axis=0)
            # Code probe with <0 indices
            self.idx = np.append(self.idx, np.arange(-1, -len(pr) - 1, -1))

        if self.fig_num != -1:
            fig = plt.figure(self.fig_num, figsize=self.figsize)
            self.fig_num = fig.number

        x, y = pci.get_x_y()
        s = np.log10(max(abs(x).max(), abs(y).max()))
        if s < -6:
            unit_name = "nm"
            s = 1e9
        elif s < -3:
            unit_name = u"µm"
            s = 1e6
        elif s < 0:
            unit_name = "mm"
            s = 1e3
        else:
            unit_name = "m"
            s = 1
        return d, x * s, y * s, unit_name

    def post_imshow(self, pci: HoloTomo, x, y, unit_name):
        """Label the X and Y axes of the current matplotlib axes using the given unit name."""
        plt.xlabel("X (%s)" % (unit_name))
        plt.ylabel("Y (%s)" % (unit_name))

    @staticmethod
    def _update_image(im, data, vmin, vmax):
        """
        Replace the data of an existing matplotlib image. The colour limits are only changed
        if either bound moved by more than 5% of the larger of the old/new ranges, to avoid
        flickering colour scales between successive updates.
        """
        im.set_data(data)
        vmin0, vmax0 = im.get_clim()
        m = max(vmax0 - vmin0, vmax - vmin) * 0.05  # min change to trigger rescale
        if max(abs(vmax0 - vmax), abs(vmin0 - vmin)) > m:
            im.set_clim(vmin, vmax)

    def op(self, pci: HoloTomo):
        """
        Display the object (and optionally the probe). Existing axes in the figure are re-used
        when they carry the expected labels, so that repeated calls only update the images.

        :param pci: the HoloTomo object
        :return: pci, unchanged
        """
        d, x, y, unit_name = self.pre_imshow(pci)
        show_type = self.type.lower()
        extent = (x.min(), x.max(), y.min(), y.max())
        nd = 1
        if d.ndim == 3:
            nd = len(d)
            max_cols = 3
            ncols = max(1, min(nd, max_cols))
            nrows = (nd + ncols - 1) // ncols

        # Check if there were already plots which can be re-used,
        # otherwise clear the figure. fig_num == -1 means the current
        # figure is used (consistent with pre_imshow).
        fig = plt.gcf() if self.fig_num == -1 else plt.figure(self.fig_num)
        # Current labels in figure
        vlabels = [ax.get_label() for ax in fig.axes]
        # Expected labels
        axes_labels = [f"ShowObj {self.type.lower()} #{self.idx[i]}" for i in range(nd)]

        reuse_axes = all(label in vlabels for label in axes_labels)

        if not reuse_axes:
            plt.clf()

        if self.title is not None:
            plt.suptitle(self.title)
        else:
            plt.suptitle("Phase contrast object [%s]" % self.type)

        for i in range(nd):
            if d.ndim == 3:
                if not reuse_axes:
                    plt.subplot(nrows, ncols, i + 1)
                di = d[i]
            else:
                di = d
            if reuse_axes:
                ax = fig.axes[vlabels.index(axes_labels[i])]

            if show_type == 'rgba':
                rgba = complex2rgbalin(di, percentile=(1, 99))
                if reuse_axes:
                    ax.get_images()[0].set_data(rgba)
                else:
                    ax = plt.imshow(rgba, extent=extent).axes
                    insertColorwheel(left=.02, bottom=.0, width=.1, height=.1, text_col='black', fs=10)
            else:
                if show_type == 'amplitude':
                    di = abs(di)
                vmin, vmax = np.percentile(di, (1, 99))
                if reuse_axes:
                    self._update_image(ax.get_images()[0], di, vmin, vmax)
                else:
                    ax = plt.imshow(di, extent=extent, cmap='gray', vmin=vmin, vmax=vmax).axes
                    plt.colorbar()

            if not reuse_axes:
                # insertColorwheel may create its own axes and make it current:
                # make sure title & labels go to the image axes
                plt.sca(ax)
                # Projection indices are >= 0 (0 is a valid projection), probes are coded < 0.
                # Note: the probe title numbering starts at 1 (-idx).
                if self.idx[i] >= 0:
                    plt.title(f"Obj proj #{self.idx[i]}")
                else:
                    plt.title(f"Probe iz={-self.idx[i]}")

                # Add a label to the axes so we know what was plotted
                ax.set_label(axes_labels[i])

                self.post_imshow(pci, x, y, unit_name)
        if not reuse_axes:
            plt.tight_layout()

            if show_type == 'rgba':
                insertColorwheel(left=.02, bottom=.0, width=.1, height=.1, text_col='black', fs=10)

        # Drawing is best-effort: some backends (e.g. non-interactive) may not support it
        try:
            plt.draw()
            plt.gcf().canvas.draw()
            if plt.get_backend() not in ['ipympl', 'widget']:  # This outputs the graphs again
                plt.pause(.001)
        except Exception:
            pass
        return pci

    def timestamp_increment(self, pci):
        """Displaying does not modify pci, so the timestamp is not incremented."""
        pass


class ShowPsi(OperatorHoloTomo):
    """
    Class to display a Psi array.
    """

    def __init__(self, fig_num=None, iproj=0, iz=0, type='phase', title=None, figsize=None):
        """

        :param fig_num: the matplotlib figure number. if None, a new figure will be created each time.
                        If -1, the current figure is used.
        :param iproj: the index of the projection to be displayed. If this is a list or array, all listed
                   projections are shown. If None, all are shown. This can only be used if there is more
                   than 1 projection in the Psi stack.
        :param iz: the index of the distance to be displayed. If this is a list or array, all listed distances
                   are shown. If None, all are shown.
        :param type: what to show (case-insensitive). Can be 'phase' (the default), 'amplitude', 'rgba'.
        :param title: the title for the view. If None, a default title will be used.
        :param figsize: if a new figure is created, this will be passed to matplotlib
        """
        super(ShowPsi, self).__init__()
        self.fig_num = fig_num
        self.iproj = iproj
        self.iz = iz
        self.type = type
        self.title = title
        self.figsize = figsize

    def pre_imshow(self, pci: HoloTomo):
        """
        Extract the selected Psi views and prepare (select and clear) the figure.

        :param pci: the HoloTomo object
        :return: (d, z, x, y, unit_name): d the fft-shifted Psi array restricted to the selected
                 projections (axis 0) and distances (axis 1), z the corresponding detector distances,
                 x and y the coordinates scaled to unit_name ('nm', 'µm', 'mm' or 'm').
        """
        pci._from_pu(psi=True)
        d = np.fft.fftshift(pci._psi, axes=(-2, -1))

        # Local selections so the user-given self.iproj / self.iz are left untouched
        iproj = list(range(d.shape[0])) if self.iproj is None else self.iproj
        iz = list(range(d.shape[1])) if self.iz is None else self.iz
        # A single index is wrapped in a tuple so that the selected axis is kept
        if isinstance(iproj, (int, np.integer)):
            iproj = (iproj,)
        if isinstance(iz, (int, np.integer)):
            iz = (iz,)

        d = d.take(iproj, axis=0).take(iz, axis=1)
        z = pci.data.detector_distance.take(iz)

        if self.fig_num != -1:
            plt.figure(self.fig_num, figsize=self.figsize)
        plt.clf()

        x, y = pci.get_x_y()
        s = np.log10(max(abs(x).max(), abs(y).max()))
        if s < -6:
            unit_name = "nm"
            s = 1e9
        elif s < -3:
            unit_name = u"µm"
            s = 1e6
        elif s < 0:
            unit_name = "mm"
            s = 1e3
        else:
            unit_name = "m"
            s = 1
        return d, z, x * s, y * s, unit_name

    def op(self, pci: HoloTomo):
        """
        Display the selected Psi views as a grid: one row per projection, one column per distance.

        :param pci: the HoloTomo object
        :return: pci, unchanged
        """
        d, z, x, y, unit_name = self.pre_imshow(pci)
        show_type = self.type.lower()
        extent = (x.min(), x.max(), y.min(), y.max())
        if self.title is not None:
            plt.suptitle(self.title)
        else:
            plt.suptitle("Psi [%s]" % self.type)
        nrows, ncols = d.shape[:2]
        for irow in range(nrows):
            for icol in range(ncols):
                ax = plt.subplot(nrows, ncols, irow * ncols + icol + 1)
                di = d[irow, icol]
                if show_type == 'rgba':
                    plt.imshow(complex2rgbalin(di), extent=extent)
                    insertColorwheel(left=.02, bottom=.0, width=.1, height=.1, text_col='black', fs=10)
                else:
                    img = np.abs(di) if show_type == 'amplitude' else np.angle(di)
                    # A plain colormap name is used: plt.cm.get_cmap was removed in matplotlib 3.9
                    plt.imshow(img, extent=extent, cmap='gray')
                    plt.colorbar()

                # Labels are set on the subplot axes explicitly, in case insertColorwheel
                # changed the current axes
                if icol == 0:
                    ax.set_ylabel("Y (%s) [view #%d]" % (unit_name, irow))
                if irow == 0:
                    ax.set_title("Z = %8.5fm" % (z[icol]))
                if irow == nrows - 1:
                    ax.set_xlabel("X (%s)" % (unit_name))
        plt.tight_layout()

        # Insert the colorwheel before drawing so that it appears in this refresh
        if show_type == 'rgba':
            insertColorwheel(left=.002, bottom=.0, width=.1, height=.1, text_col='black', fs=10)

        # Drawing is best-effort: some backends (e.g. non-interactive) may not support it
        try:
            plt.draw()
            plt.gcf().canvas.draw()
            if plt.get_backend() not in ['ipympl', 'widget']:  # This outputs the graphs again
                plt.pause(.001)
        except Exception:
            pass
        return pci

    def timestamp_increment(self, pci):
        """Displaying does not modify pci, so the timestamp is not incremented."""
        pass
