# ImageD11_v1.1 Software for beamline ID11
# Copyright (C) 2005 - 2008  Jon Wright
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import numpy as n
import logging
import math
from ImageD11 import transform, unitcell, columnfile
from ImageD11.parameters import par, parameters

# Default parameter set installed on every transformer instance.
# Each entry gives: name, default value, a help string, whether the
# parameter is refined by default (vary), whether it is allowed to be
# refined at all (can_vary) and, for refinable ones, the initial simplex
# step size.
PARAMETERS = [
    par("omegasign", 1.0,
        helpstring="Sign of the rotation about z "
                   "(normally +1 for right handed)",
        vary=False,
        can_vary=False),
    par('z_center', 1024.0,
        helpstring="Beam centre in vertical, pixels",
        vary=True,
        can_vary=True,
        stepsize=1.0),
    par('y_center', 1024.0,
        helpstring="Beam centre in horizontal, pixels",
        vary=True,
        can_vary=True,
        stepsize=1.0),
    par('distance', 50000.0,
        helpstring="sample detector distance, same units as pixel size",
        vary=True,
        can_vary=True,
        stepsize=100.0),
    par('z_size', 48.08150,
        helpstring="pixel size in vertical, same units as distance",
        vary=False,
        can_vary=False),  # this could actually vary - a bit crazy?
    par('y_size', 46.77648,
        helpstring="pixel size in horizontal, same units as distance",
        vary=False,
        can_vary=False),  # this could actually vary - a bit crazy?
    par('tilt_z', 0.0,
        helpstring="detector tilt, right handed around z",
        vary=True,
        can_vary=True,
        stepsize=transform.radians(0.1)),
    par('tilt_y', 0.0,
        helpstring="detector tilt, right handed around y",
        vary=True,
        can_vary=True,
        stepsize=transform.radians(0.1)),
    par('tilt_x', 0.0,
        helpstring="detector tilt, right handed around x",
        vary=False,
        can_vary=True,
        stepsize=transform.radians(0.1)),
    par('fit_tolerance', 0.05,
        helpstring="tolerance to decide which peaks to use",
        vary=False,
        can_vary=False),
    par('wavelength', 0.155,
        helpstring="wavelength, normally angstrom, "
                   "same units as unit cell",
        vary=False,
        can_vary=True,  # but you'll be lucky!
        stepsize=0.0001),
    par('wedge', 0.0,
        helpstring="wedge, rotation around y under omega",
        vary=False,
        can_vary=True,
        stepsize=transform.radians(0.1)),
    par('chi', 0.0,
        helpstring="chi, rotation around x under omega",
        vary=False,
        can_vary=True,
        stepsize=transform.radians(0.1)),
    par('cell__a', 4.1569,
        helpstring="unit cell par, same units as wavelength",
        vary=False,
        can_vary=True,
        stepsize=0.01),
    par('cell__b', 4.1569,
        helpstring="unit cell par, same units as wavelength",
        vary=False,
        can_vary=True,
        stepsize=0.01),
    par('cell__c', 4.1569,
        helpstring="unit cell par, same units as wavelength",
        vary=False,
        can_vary=True,
        stepsize=0.01),
    par('cell_alpha', 90.0,
        helpstring="unit cell par, degrees",
        vary=False,
        can_vary=True,
        stepsize=0.01),
    par('cell_beta', 90.0,
        helpstring="unit cell par, degrees",
        vary=False,
        can_vary=True,
        stepsize=0.01),
    par('cell_gamma', 90.0,
        helpstring="unit cell par, degrees",
        vary=False,
        can_vary=True,
        stepsize=0.01),
    par('cell_lattice_[P,A,B,C,I,F,R]', "P",
        helpstring="lattice centering type. Try P if you are not sure",
        vary=False,
        can_vary=False),
    # o11..o22 form the 2x2 detector orientation (flip) matrix
    par('o11', 1,
        helpstring="detector flip element +1 for frelon & quantix",
        vary=False,
        can_vary=False),
    par('o12', 0,
        helpstring="detector flip element 0 for frelon & quantix",
        vary=False,
        can_vary=False),
    par('o21', 0,
        helpstring="detector flip element 0 for frelon & quantix",
        vary=False,
        can_vary=False),
    par('o22', -1,
        helpstring="detector flip element -1 for frelon & +1 for quantix",
        vary=False,
        can_vary=False),
    par('t_x', 0,
        helpstring="crystal translation, units as distance/pixels",
        vary=False,
        can_vary=True,
        stepsize=1.),
    par('t_y', 0,
        helpstring="crystal translation, units as distance/pixels",
        vary=False,
        can_vary=True,
        stepsize=1.),
    par('t_z', 0,
        helpstring="crystal translation, units as distance/pixels",
        vary=False,
        can_vary=True,
        stepsize=1.),
    par('no_bins', 10000,
        helpstring="Number of bins to use in histogram based filters",
        vary=False,
        can_vary=False),
    par('min_bin_prob', 1e-5,
        helpstring="Minimum bin probability for histogram based filters",
        vary=False,
        can_vary=False),
    ]



class transformer:
    """
    Handles the algorithmic, fitting and state information for
    fitting parameters to give experimental calibrations

    The peak data live in ``self.colfile`` (a columnfile) and the
    geometry / unit cell parameters in ``self.parameterobj``. Computed
    quantities (xl, yl, zl, tth, eta, gx, gy, gz, histogram
    probabilities) are stored back into the columnfile as extra columns.
    """

    def __init__(self, parfile=None, fltfile=None):
        """
        parfile : optional name of a parameter file to load
        fltfile : optional name of a filtered peaks (column) file to load

        With no arguments the defaults from PARAMETERS are used and no
        peaks are loaded; call loadfileparameters / loadfiltered later.
        """
        self.unitcell = None
        # this sets defaults according to the module level PARAMETERS list
        self.parameterobj = parameters()
        for p in PARAMETERS:
            self.parameterobj.addpar(p)
        # Interesting - is this an alias for the dict?
        self.pars = self.parameterobj.get_parameters()
        self.colfile = None
        self.xname = None
        self.yname = None
        self.omeganame = None
        if parfile is not None:
            self.loadfileparameters(parfile)
        if fltfile is not None:
            self.loadfiltered(fltfile)

    def updateparameters(self):
        """ Refresh self.pars from the parameter object """
        self.pars = self.parameterobj.get_parameters()

    def get_variable_list(self):
        """ Return the parameter object's variable list """
        return self.parameterobj.get_variable_list()

    def getvars(self):
        """ decide what is refinable """
        return self.parameterobj.varylist

    def setvars(self, varlist):
        """ set the things to refine """
        self.parameterobj.varylist = varlist

    def loadfiltered(self, filename):
        """
        Read in 3D peaks from peaksearch

        If the first three column titles are (sc, fc, omega) or
        (xc, yc, omega) they are selected as the x, y, omega columns.
        A "spot3d_id" column (0..nrows-1) is added if missing.
        """
        self.colfile = columnfile.columnfile(filename)
        leading = self.colfile.titles[0:3]
        if leading == ["sc", "fc", "omega"]:
            self.setxyomcols("sc", "fc", "omega")
        if leading == ["xc", "yc", "omega"]:
            self.setxyomcols("xc", "yc", "omega")
        if "spot3d_id" not in self.colfile.titles:
            # list(...) so that a real sequence is passed on python 3 too
            self.colfile.addcolumn(list(range(self.colfile.nrows)),
                                   "spot3d_id")

    def setxyomcols(self, xname, yname, omeganame):
        """ Choose which columns hold detector x, y and omega """
        self.xname = xname
        self.yname = yname
        self.omeganame = omeganame
        logging.warning("titles are %s  %s  %s"%(self.xname,
                                                 self.yname,
                                                 self.omeganame))

    def getcols(self):
        """ Return the column titles of the loaded peaks """
        return self.colfile.titles

    def loadfileparameters(self, filename):
        """ Read in beam center etc from file """
        self.parameterobj.loadparameters(filename)

    def saveparameters(self, filename):
        """ Save beam center etc to file """
        self.parameterobj.saveparameters(filename)

    def applyargs(self, args):
        """ for use with simplex/gof function, alter parameters """
        self.parameterobj.set_variable_values(args)

    def getcolumn(self, name):
        """Return the data of the named column"""
        return self.colfile.getcolumn(name)

    def addcolumn(self, col, name):
        """Store col in the columnfile under the title name"""
        return self.colfile.addcolumn(col, name)

    def compute_tth_eta(self):
        """
        Compute the twotheta and eta for peaks previous read in

        Also stores the laboratory coordinates as columns xl, yl, zl
        and the angles as columns tth, eta.
        Returns (tth, eta). Raises Exception if no x/y/omega columns
        have been selected.
        """
        if None in [self.xname, self.yname, self.omeganame]:
            raise Exception("No peaks loaded")
        pars = self.parameterobj.get_parameters()
        peaks = [self.getcolumn(self.xname),
                 self.getcolumn(self.yname)]
        peaks_xyz = transform.compute_xyz_lab(peaks, **pars)
        # Store these in the columnfile
        self.addcolumn(peaks_xyz[0], "xl")
        self.addcolumn(peaks_xyz[1], "yl")
        self.addcolumn(peaks_xyz[2], "zl")
        omega = self.getcolumn(self.omeganame)
        tth, eta = transform.compute_tth_eta_from_xyz(peaks_xyz,
                                                      omega,
                                                      **pars)
        self.addcolumn(tth, "tth")
        self.addcolumn(eta, "eta")
        return tth, eta

    def compute_histo(self, colname):
        """
        Compute the histogram over the named column for peaks previous
        read in. Filtering is moved to a separate function (filter_min)

        The per-peak bin probability is stored as colname+"_hist_prob".
        Returns (bins, hist).
        """
        if colname not in self.colfile.titles:
            raise Exception("Cannot find column "+colname)
        bins, hist, hpk = transform.compute_tth_histo(
            self.getcolumn(colname),
            **self.parameterobj.get_parameters())
        self.addcolumn(hpk, colname+"_hist_prob")
        return bins, hist

    def compute_tth_histo(self):
        """ Give hardwire access to tth """
        if "tth" not in self.colfile.titles:
            self.compute_tth_eta()
        return self.compute_histo("tth")

    def filter_min(self, col, minval):
        """
        Filter peaks out where the value in column col is not greater
        than minval (e.g. col="tth_hist_prob" removes peaks falling in
        bins with less than min_prob)

        If col is "tth_hist_prob" and has not been computed yet the
        two theta histogram is computed first.
        """
        if col == "tth_hist_prob" and col not in self.colfile.titles:
            self.compute_tth_histo()
        # honour the requested column (it used to be silently ignored)
        mask = self.colfile.getcolumn(col) > minval
        logging.info("Number of peaks before filtering = %d"%(
            self.colfile.nrows))
        self.colfile.filter(mask)
        logging.info("Number of peaks after filtering = %d"%(
            self.colfile.nrows))

    def tth_entropy(self):
        """
        Compute the entropy of the two theta values
        ... this may depend on the number of tth bins (perhaps?)
        ... to be used for cell parameter free line straightening
        S = sum( -p log p ) taken over the tth_hist_prob column
        """
        if "tth_hist_prob" not in self.colfile.titles:
            self.compute_tth_histo()
        hpk = self.getcolumn("tth_hist_prob")
        return n.sum(-hpk * n.log(hpk))

    def gof(self, args):
        """
        Compute how good is the fit of obs/calc peak positions in tth

        args : values for the currently refined variables
        Returns 1e3 * mean squared tth difference over the peaks that
        fit() assigned to rings (self.indices / self.fitds).
        """
        self.applyargs(args)
        tth, eta = self.compute_tth_eta()
        # Take the wavelength from the parameters as they are *now*, so
        # that refining the wavelength actually moves the computed rings
        w = float(self.parameterobj.get_parameters()['wavelength'])
        self.wavelength = w
        sumsq = 0.
        npeaks = 0
        for i in range(len(self.tthc)):
            self.tthc[i] = transform.degrees(
                math.asin(self.fitds[i]*w/2)*2)
            diff = n.take(tth, self.indices[i]) - self.tthc[i]
            sumsq = sumsq + n.sum(diff*diff)
            npeaks = npeaks + len(diff)
        return sumsq / npeaks * 1e3

    def fit(self, tthmin=0, tthmax=180):
        """
        Apply simplex to improve fit of obs/calc tth

        tthmin, tthmax : range of computed ring positions (degrees)
                         to use in the fit
        Observed peaks within fit_tolerance of a computed ring are
        assigned to that ring, then the refinable parameters are
        optimised in two simplex passes (the second with step sizes
        divided by 10). Returns None.
        """
        tthmin = float(tthmin)
        tthmax = float(tthmax)

        import simplex
        # theoryds only exists once addcellpeaks has been called
        if getattr(self, "theoryds", None) is None:
            self.addcellpeaks()
        # Assign observed peaks to rings
        self.indices = []  # which peaks used
        self.tthc = []     # computed two theta values
        self.fitds = []    # d-star of the rings which are used
        pars = self.parameterobj.get_parameters()
        w = float(pars['wavelength'])
        self.wavelength = w
        self.fit_tolerance = float(pars['fit_tolerance'])
        print("Tolerance for assigning peaks to rings %s , min tth %s"
              " , max tth %s" % (self.fit_tolerance, tthmin, tthmax))
        tth, eta = self.compute_tth_eta()
        for dsc in self.theoryds:
            tthcalc = math.asin(dsc*w/2)*360./math.pi  # degrees
            if tthcalc > tthmax:
                break
            elif tthcalc < tthmin:
                continue
            logicals = n.logical_and(
                n.greater(tth, tthcalc - self.fit_tolerance),
                n.less(tth, tthcalc + self.fit_tolerance))
            if logicals.any():
                self.tthc.append(tthcalc)
                self.fitds.append(dsc)
                self.indices.append(n.nonzero(logicals)[0])
        guess = self.parameterobj.get_variable_values()
        inc = self.parameterobj.get_variable_stepsizes()
        if len(guess) == 0:
            # There is nothing to fit.
            logging.warning("You try to fit with no variables!?")
            return None
        if len(self.indices) == 0:
            # gof would divide by zero peaks
            logging.warning("No peaks were assigned to rings, cannot fit")
            return None
        s = simplex.Simplex(self.gof, guess, inc)
        newguess, error, niter = s.minimize()
        # second pass, starting from the first result with finer steps
        inc = [v/10. for v in inc]
        s = simplex.Simplex(self.gof, newguess, inc)
        newguess, error, niter = s.minimize()
        self.parameterobj.set_variable_values(newguess)
        # leave the columns (tth, eta, ...) matching the final values
        self.gof(newguess)
        print(newguess)

    def addcellpeaks(self, limit=None):
        """
        Adds unit cell predicted peaks for fitting against
        Optional arg limit gives highest angle to use

        Depends on parameters:
            'cell__a','cell__b','cell__c', 'wavelength' # in angstrom
            'cell_alpha','cell_beta','cell_gamma' # in degrees
            'cell_lattice_[P,A,B,C,I,F,R]'

        Sets self.unitcell, self.dslimit, self.theorypeaks,
        self.theoryds and self.theorytth (degrees).
        """
        pars = self.parameterobj.get_parameters()
        cell = [pars[name] for name in ['cell__a', 'cell__b', 'cell__c',
                                        'cell_alpha', 'cell_beta',
                                        'cell_gamma']]
        lattice = pars['cell_lattice_[P,A,B,C,I,F,R]']
        self.unitcell = unitcell.unitcell(cell, lattice)
        if "tth" not in self.colfile.titles:
            self.compute_tth_eta()
        # Find last peak in radius
        if limit is None:
            highest = n.maximum.reduce(self.getcolumn("tth"))
        else:
            highest = limit
        w = pars['wavelength']
        ds = 2*n.sin(transform.radians(highest)/2.) / w
        self.dslimit = ds
        self.theorypeaks = self.unitcell.gethkls(ds)
        self.unitcell.makerings(ds)
        self.theoryds = self.unitcell.ringds
        tths = [n.arcsin(w*dstar/2)*2
                for dstar in self.unitcell.ringds]
        self.theorytth = transform.degrees(n.array(tths))

    def computegv(self):
        """
        Using tth, eta and omega angles, compute x,y,z of spot
        in reciprocal space. Stored as columns gx, gy, gz.
        """
        if "tth" not in self.colfile.titles:
            self.compute_tth_eta()
        pars = self.parameterobj.get_parameters()
        if "omegasign" in pars:
            om_sgn = pars["omegasign"]
        else:
            om_sgn = 1.0
        gv = transform.compute_g_vectors(
            self.getcolumn("tth"),
            self.getcolumn("eta"),
            self.getcolumn("omega")*om_sgn,
            **pars)
        self.addcolumn(gv[0], "gx")
        self.addcolumn(gv[1], "gy")
        self.addcolumn(gv[2], "gz")

    def getaxis(self):
        """
        Compute the rotation axis
        This handles omegasign in a more elegant way

        Returns the unit vector along z multiplied by omegasign.
        """
        # builtin float: the numpy.float alias is gone from recent numpy
        zaxis = n.array([0, 0, 1], float)
        return zaxis * self.parameterobj.get("omegasign")

    def savegv(self, filename):
        """
        Save g-vectors into a file
        Use crappy .ass format from previous for now (testing)

        The file holds the unit cell, the wavelength, wedge and axis,
        a copy of all parameters, the computed (ds, h, k, l) list and
        then one line per peak sorted by two theta.
        """
        if "gz" not in self.colfile.titles:
            self.computegv()
        if self.unitcell is None:
            self.addcellpeaks()
        pars = self.parameterobj.get_parameters()
        # Collect everything before the file is opened, so a missing
        # column does not leave a half written file behind
        tth = self.getcolumn("tth")
        ome = self.getcolumn("omega")
        eta = self.getcolumn("eta")
        gx = self.getcolumn("gx")
        gy = self.getcolumn("gy")
        gz = self.getcolumn("gz")
        x = self.getcolumn(self.xname)
        y = self.getcolumn(self.yname)
        spot3d_id = self.getcolumn("spot3d_id")
        xl = self.getcolumn("xl")
        yl = self.getcolumn("yl")
        zl = self.getcolumn("zl")
        order = n.argsort(tth)
        ds = 2*n.sin(transform.radians(tth/2))/pars["wavelength"]
        print("%s %s" % (n.maximum.reduce(ome), n.minimum.reduce(ome)))
        fmt = "%f "*8+"%d "+"%f "*3+"\n"
        f = open(filename, "w")
        try:
            f.write(self.unitcell.tostring())
            f.write("\n")
            f.write("# wavelength = %f\n"%(float(pars['wavelength'])))
            f.write("# wedge = %f\n"%(float(pars['wedge'])))
            # Handle the axis direction somehow
            f.write("# axis %f %f %f\n"%tuple(self.getaxis()))
            # Put a copy of all the parameters in the gve file
            for k in sorted(pars.keys()):
                f.write("# %s = %s \n"%(k, pars[k]))
            f.write("# ds h k l\n")
            for peak in self.theorypeaks:
                f.write("%10.7f %4d %4d %4d\n"%(peak[0],
                                                peak[1][0],
                                                peak[1][1],
                                                peak[1][2]))
            f.write("#  gx  gy  gz  xc  yc  ds  eta  omega"
                    "  spot3d_id  xl  yl  zl\n")
            for i in order:
                f.write(fmt % (gx[i], gy[i], gz[i],
                               x[i], y[i],
                               ds[i], eta[i], ome[i],
                               spot3d_id[i],
                               xl[i], yl[i], zl[i]))
        finally:
            # always release the handle, also when a write fails
            f.close()

    def write_colfile(self, filename):
        """
        Save out the column file (all info stored we hope)
        The current parameters are attached to the columnfile first.
        """
        self.colfile.parameters = self.parameterobj
        self.colfile.writefile(filename)

    def write_graindex_gv(self, filename):
        """
        Write the g-vectors via ImageD11.write_graindex_gv

        Intensities are taken from sum_intensity when present, else
        from avg_intensity times the pixel count column
        (Number_of_pixels or npixels), else zeros are written.
        """
        from ImageD11 import write_graindex_gv
        titles = self.colfile.titles
        if "gx" not in titles:
            self.computegv()
        if self.unitcell is None:
            # same precaution as in savegv
            self.addcellpeaks()
        gv = [self.getcolumn("gx"),
              self.getcolumn("gy"),
              self.getcolumn("gz")]

        if "sum_intensity" in titles:
            ints = self.getcolumn("sum_intensity")
        elif ("avg_intensity" in titles) and \
             ("Number_of_pixels" in titles):
            ints = self.getcolumn("avg_intensity") * \
                   self.getcolumn("Number_of_pixels")
        elif ("avg_intensity" in titles) and \
             ("npixels" in titles):
            ints = self.getcolumn("avg_intensity") * \
                   self.getcolumn("npixels")
        else:
            ints = n.zeros(self.colfile.nrows)

        pars = self.parameterobj.get_parameters()
        if "omegasign" in pars:
            om_sgn = pars["omegasign"]
        else:
            om_sgn = 1.0

        write_graindex_gv.write_graindex_gv(filename,
                                            n.array(gv),
                                            self.getcolumn("tth"),
                                            self.getcolumn("eta"),
                                            self.getcolumn("omega") * om_sgn,
                                            ints,
                                            self.unitcell)
        
