#! /opt/local/bin/python
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
from __future__ import division

import sys
import timeit
import time
import numpy as np
import h5py as h5
from pynx.cdi.selection import match_shape, match2
from pynx.utils.math import ortho_modes
from pynx.version import __version__

# Default values of the script options. Entries are overwritten while the
# command-line arguments are parsed in the main section below.
params = {
    'modes': False,
    'movie': False,
    'movie.type': 'complex',
    'movie.cmap': 'viridis',
    'subpixel': False,
}

# Usage text, printed when the 'help' keyword is found on the command line.
helptext = """
pynx-cdi-analysis: script to analyze a series of CDI reconstructions

Example:
    pynx-cdi-analysis *.cxi modes

command-line arguments:
    path/to/Run.cxi: path to cxi file with the reconstructions to analyze. Several can be supplied [mandatory]

    modes: if used, analyse the CDI reconstructions and decompose them into eigen-values. The first mode should
           represent most of the intensity.
    
    subpixel: if True, use subpixel registration to align the different solutions for modes analysis.
    
    movie: if used with 3D input data, a movie will be made, either from the single CXI or h5 (modes) data file,
           or from the first two files (if the .h5 mode file is listed, it is always considered first).
           Some options can be given:
           - movie=complex: to display the complex 3d data (the default)
           - movie=amplitude: to display the amplitude
           - movie=amplitude,grey: to display the amplitude using a grayscale rather than the default colormap.
                                   possible options are 'grey' and 'grey_r'. Otherwise viridis is used

"""

def parse_command_line(argv, options):
    """
    Interpret the command-line keywords and collect the input file names.

    :param argv: list of command-line words, WITHOUT the program name.
    :param options: dictionary of options (normally the module-level params),
        updated in place with the keywords which were found.
    :return: a tuple (cxi_files, h5_input, show_help): the list of .cxi/.npz
        files in command-line order, the last .h5 file listed (or None), and
        True if the 'help' keyword was given.
    """
    cxi_files = []
    h5_input = None
    show_help = False
    for arg in argv:
        if arg == 'help':
            show_help = True
        elif arg in ('modes', 'subpixel'):
            options[arg] = True
        elif arg == 'movie' or arg.startswith('movie='):
            # Only the exact keyword (with optional '=...') is recognised, so a
            # data file such as 'movie_run.cxi' is still treated as an input.
            options['movie'] = True
            movie_opts = arg[len('movie'):]
            # The type is only changed when explicitly requested, so a bare
            # 'movie' keeps the documented default.
            if 'amplitude' in movie_opts:
                options['movie.type'] = 'amplitude'
            elif 'complex' in movie_opts:
                options['movie.type'] = 'complex'

            # '_r' variants must be tested first since 'gray' is a substring of 'gray_r'
            if 'gray_r' in movie_opts or 'grey_r' in movie_opts:
                options['movie.cmap'] = 'gray_r'
            elif 'gray' in movie_opts or 'grey' in movie_opts:
                options['movie.cmap'] = 'gray'
        elif arg.endswith(('.cxi', '.npz')):
            cxi_files.append(arg)
        elif arg.endswith('.h5'):
            h5_input = arg
    return cxi_files, h5_input, show_help


def crop_to_support(d, threshold=0.15, margin=2):
    """
    Crop an array around the region where its amplitude exceeds a fraction of
    the maximum amplitude.

    :param d: the array to crop (any number of dimensions >= 1).
    :param threshold: points with abs(d) > threshold * abs(d).max() define the
        region to keep.
    :param margin: number of extra points kept on each side of that region
        (limited by the array borders).
    :return: the cropped array (a slice of d), or d itself if no point is above
        the threshold (e.g. an all-zero array).
    """
    ad = np.abs(d)
    sup = ad > (ad.max() * threshold)
    del ad  # free the amplitude array as early as possible, d can be large
    slices = []
    for axis in range(d.ndim):
        other_axes = tuple(a for a in range(d.ndim) if a != axis)
        idx = np.nonzero(sup.any(axis=other_axes))[0]
        if idx.size == 0:
            return d
        start = max(int(idx[0]) - margin, 0)
        # The end of a slice is exclusive: +1 keeps 'margin' points after the
        # last selected one, and clipping to the axis length (rather than
        # using -1) never drops the last plane of the array.
        stop = min(int(idx[-1]) + margin + 1, d.shape[axis])
        slices.append(slice(start, stop))
    return d[tuple(slices)]


def load_object(filename):
    """
    Load one object from a .npz or CXI (hdf5) file.

    For a .npz file, the first array with more than 1000 elements is used, or
    the last array of the file if none is that large. For other files, the
    dataset 'entry_1/image_1/data' is read, and the notes found under
    'entry_1/image_1/process_1' are searched for a 'llk_free(Poisson)' value.

    :param filename: path to the file to read.
    :return: a tuple (data, llk) with the array and the free log-likelihood
        value as a float, or None if it was not found in the file.
    :raises ValueError: if a .npz file does not contain any array.
    """
    if filename.endswith('.npz'):
        data = None
        with np.load(filename) as npz:
            for key in npz.files:
                data = npz[key]
                if data.size > 1000:
                    break
        if data is None:
            raise ValueError("No array found in: %s" % filename)
        return data, None

    llk = None
    with h5.File(filename, 'r') as h:
        data = h['entry_1/image_1/data'][()]
        if 'entry_1/image_1/process_1' in h:
            for k, v in h['entry_1/image_1/process_1'].items():
                if 'note_' not in k or 'description' not in v:
                    continue
                description = v['description'][()]
                if isinstance(description, bytes):
                    # Recent h5py versions return strings as bytes
                    description = description.decode('utf-8', 'replace')
                if description == 'llk_free(Poisson)':
                    value = float(v['data'][()])
                    if llk is None or value < llk:
                        llk = value
    return data, llk


def save_modes(filename, vo, vave, t_start, command):
    """
    Save the result of the modes analysis to an hdf5 file with NeXus attributes.

    :param filename: name of the hdf5 file to (over)write.
    :param vo: the modes, written to entry_1/image_1/data.
    :param vave: the average of the solutions, written to entry_1/image_2/data.
    :param t_start: start time of the script (seconds since the epoch).
    :param command: the command line used, stored in entry_1/image_1/process_1.
    """
    with h5.File(filename, "w") as f:
        # NeXus
        f.attrs['default'] = 'entry_1'
        f.attrs['creator'] = 'PyNX'
        f.attrs['HDF5_Version'] = h5.version.hdf5_version
        f.attrs['h5py_version'] = h5.version.version

        entry_1 = f.create_group("entry_1")
        entry_1.create_dataset("program_name", data="PyNX %s" % (__version__))
        entry_1.create_dataset("start_time", data=time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime(t_start)))
        entry_1.attrs['NX_class'] = 'NXentry'
        entry_1.attrs['default'] = 'data_1'

        image_1 = entry_1.create_group("image_1")
        image_1.create_dataset("data", data=vo, chunks=True, shuffle=True, compression="gzip")
        image_1.attrs['NX_class'] = 'NXdata'  # Is image_1 a well-formed NXdata or not ?
        image_1.attrs['signal'] = 'data'
        image_1.attrs['interpretation'] = 'image'
        image_1.attrs['label'] = 'modes'

        process_1 = image_1.create_group("process_1")
        process_1.create_dataset("command", data=command)

        # Add shortcut to the main data
        data_1 = entry_1.create_group("data_1")
        data_1["data"] = h5.SoftLink('/entry_1/image_1/data')
        data_1.attrs['NX_class'] = 'NXdata'
        data_1.attrs['signal'] = 'data'
        data_1.attrs['interpretation'] = 'image'

        image_2 = entry_1.create_group("image_2")
        image_2.create_dataset("data", data=vave, chunks=True, shuffle=True, compression="gzip")
        image_2.attrs['NX_class'] = 'NXdata'
        image_2.attrs['signal'] = 'data'
        image_2.attrs['label'] = 'average of solutions'


def make_movie(o1, o1n, o2, o2n, movie_type, cmap, filename="cdi-3d-slices.mp4"):
    """
    Write a movie going through the slices (along the first axis) of one 3D
    object, or of two objects displayed side by side.

    The script exits with status 1 if o1 is not 3D or if the ffmpeg movie
    writer of matplotlib is not available.

    :param o1: the first 3D object.
    :param o1n: name of the first object, used as plot title.
    :param o2: the second object, or None to display only the first one. When
        given, both objects are passed through match_shape() and match2()
        before plotting.
    :param o2n: name of the second object (ignored if o2 is None).
    :param movie_type: 'amplitude' to display abs() of the slices, any other
        value to display them through complex2rgbalin().
    :param cmap: the matplotlib colormap name used for amplitude display.
    :param filename: name of the movie file to write.
    """
    if o1.ndim != 3:
        print('Movie generation from CXI data only supported for 3D objects')
        sys.exit(1)

    import matplotlib

    # The 'warn' keyword formerly passed here is rejected by recent matplotlib versions
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.animation as manimation
    from pynx.utils.plot_utils import complex2rgbalin, insertColorwheel

    try:
        FFMpegWriter = manimation.writers['ffmpeg']
    except (RuntimeError, KeyError):
        print("Could not import FFMpeg writer for movie generation")
        sys.exit(1)

    writer = FFMpegWriter(fps=5, metadata=dict(title='3D CDI slices', artist='PyNX'))
    fontsize = 10
    show_complex = movie_type != 'amplitude'

    def show_slice(o, i, omax):
        # Display slice #i of o in the current axes, scaled by the object maximum amplitude
        if show_complex:
            plt.imshow(complex2rgbalin(o[i], smin=0, alpha=(0, np.abs(o[i]).max() / omax)))
        else:
            plt.imshow(abs(o[i]), vmin=0, vmax=omax, cmap=cmap)

    if o2 is None:
        fig = plt.figure(figsize=(6, 5))
        o1m = np.abs(o1).max()
        with writer.saving(fig, filename, dpi=100):
            for i in range(len(o1)):
                if (i % 10) == 0:
                    print(i)
                plt.clf()
                plt.title("%s - #%3d" % (o1n, i), fontsize=fontsize)
                show_slice(o1, i, o1m)
                if show_complex:
                    insertColorwheel(left=0.85, bottom=.0, width=.1, height=.1, text_col='black', fs=10)
                writer.grab_frame()
    else:
        print('Matching shape and orientation of objects for 3D CDI movie')
        o1, o2 = match_shape([o1, o2], method='median')
        o1, o2, r = match2(o1, o2, match_phase=True, verbose=False)
        print("R_match = %6.3f%% " % (r * 100))

        fig = plt.figure(figsize=(12, 5))
        o1m = np.abs(o1).max()
        o2m = np.abs(o2).max()

        with writer.saving(fig, filename, dpi=100):
            for i in range(len(o1)):
                if (i % 10) == 0:
                    print(i)
                plt.clf()
                plt.subplot(121)
                plt.title("%s" % o1n, fontsize=fontsize)
                show_slice(o1, i, o1m)

                plt.subplot(122)
                plt.title("%s" % o2n, fontsize=fontsize)
                plt.suptitle("%3d" % i)
                show_slice(o2, i, o2m)
                if show_complex:
                    insertColorwheel(left=0.90, bottom=.0, width=.1, height=.1, text_col='black', fs=10)
                writer.grab_frame()


def main(argv=None):
    """
    Entry point of the script: load the objects listed on the command line,
    then optionally run the modes analysis and/or generate a movie.

    :param argv: the full command line including the program name. If None,
        sys.argv is used.
    """
    t_start = time.time()
    if argv is None:
        argv = sys.argv
    # argv[0] is the program name, not an option
    cxi_files, h5_input, show_help = parse_command_line(argv[1:], params)
    if show_help:
        print(helptext)
    if not cxi_files and h5_input is None:
        if show_help:
            return
        sys.exit("No input file (.cxi, .npz or .h5) was given - use the 'help' keyword to list the options")

    print("Importing data files")
    crop_threshold = 0.15
    vd = {}
    best_llk = None
    best_object = cxi_files[0] if cxi_files else None
    for s in cxi_files:
        print("Loading: %s" % s)
        d, llk = load_object(s)
        # The file with the lowest free log-likelihood is used as reference
        if llk is not None and (best_llk is None or best_llk > llk):
            best_llk = llk
            best_object = s
        if d.size > 1e7:
            # Too large, try cropping
            shape0 = d.shape
            d = crop_to_support(d, threshold=crop_threshold, margin=2)
            print("   ...array too large, cropping at %5.3f*max, shape:" % crop_threshold, shape0, " -> ", d.shape)
        vd[s] = d

    if params['modes']:
        if not vd:
            sys.exit("The modes analysis requires at least one .cxi or .npz file")
        print('Matching arrays against the first one [%s] - this may take a while' % (best_object))
        # NOTE: the reference is removed from vd, so a movie made after the
        # modes analysis compares the first mode to another solution, if any.
        v = [vd.pop(best_object)] + list(vd.values())
        v = match_shape(v, method='median')
        upsample_factor = 20 if params['subpixel'] else 1
        v2 = [v[0]]
        t0 = timeit.default_timer()
        for i in range(1, len(v)):
            _, d2c, r = match2(v[0], v[i], match_phase=True, upsample_factor=upsample_factor, verbose=False)
            print("R_match(%d, %d) = %6.3f%% [%d arrays remaining]" % (0, i, r * 100, len(v) - i - 1))
            v2.append(d2c)

        dt = timeit.default_timer() - t0
        print('Elapsed time: %6.1fs' % dt)
        print("Analysing modes")
        vo = ortho_modes(v2)
        p = (abs(vo[0]) ** 2).sum() / (abs(vo) ** 2).sum()
        print("First mode represents %6.3f%%" % (p * 100))

        # Also compute average
        vave = v2[0].copy()
        for d in v2[1:]:
            vave += d
        vave /= len(v2)

        # Save result to hdf5. The new file replaces any .h5 file given as input
        h5_out_filename = 'modes.h5'
        h5_input = h5_out_filename
        print('Saving modes analysis to: %s' % h5_out_filename)
        save_modes(h5_out_filename, vo, vave, t_start, ' '.join(argv))

    if params['movie']:
        # Make a movie going through 3d slices, comparing two objects if at least 2 are available
        objects = list(vd.items())
        if h5_input is not None:
            # Use the first mode
            with h5.File(h5_input, 'r') as h:
                o1 = h['entry_1/image_1/data'][0]
            o1n = h5_input
            o2n, o2 = objects[0] if objects else (None, None)
        else:
            o1n, o1 = objects[0]
            o2n, o2 = objects[1] if len(objects) > 1 else (None, None)
        make_movie(o1, o1n, o2, o2n, params['movie.type'], params['movie.cmap'])


if __name__ == '__main__':
    main()