import numpy as n
from xfab import tools,detector
import variables
import sys
from ImageD11 import blobcorrector
import logging
logging.basicConfig(level=logging.DEBUG,format='%(levelname)s %(message)s')
import time

# Column-name -> column-index map for the reflection arrays (grain.refs)
# built by find_refl.run(); e.g. A[:, A_id['omega']] is the omega column.
A_id = variables.refarray().A_id

class find_refl:
    """Simulate the diffraction spots produced by a set of grains.

    The constructor derives geometry constants from the parameter dictionary
    and generates the Miller indices inside the requested theta range.
    ``run`` computes the reflections of every grain, ``overlap`` and
    ``overlap_new`` flag overlapping spots, and ``save`` writes one ``.ref``
    file per grain.
    """

    # Number of columns in each reflection row assembled by run().
    _NCOLS = 23
    # Column index of frame_center (the last entry of each row built in run()).
    _FRAME_CENTER_COL = 22

    def __init__(self, param):
        """Store ``param`` and precompute geometry constants and the hkl list.

        param: dict of simulation parameters. Keys read here: 'theta_min',
            'theta_max' (degrees), 'wavelength', 'sbox_y', 'sbox_z',
            'tilt_x', 'tilt_y', 'tilt_z', 'wedge', 'chi' (degrees),
            'spatial', 'unit_cell', 'sysconditions', 'crystal_system'.
        """
        self.param = param
        self.grain = []

        # Simple transforms of input and set constants
        wavelength = self.param['wavelength']
        sintlmin = n.sin(self.param['theta_min']*n.pi/180)/wavelength
        sintlmax = n.sin(self.param['theta_max']*n.pi/180)/wavelength
        self.K = -2*n.pi/wavelength
        self.S = n.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        # Floor division keeps the original (Python 2) integer result
        # under Python 3 as well.
        self.sbox_y_half = (self.param['sbox_y']-1)//2
        self.sbox_z_half = (self.param['sbox_z']-1)//2

        # Detector tilt correction matrix
        self.R = tools.detect_tilt(self.param['tilt_x'],
                                   self.param['tilt_y'],
                                   self.param['tilt_z'])

        # wedge NB! wedge is in degrees
        # The sign is reversed for wedge as the parameter in
        # tools.find_omega_general is right handed and in ImageD11
        # it is left-handed (at this point wedge is defined as in ImageD11)
        self.wy = -1.*self.param['wedge']*n.pi/180.
        self.wx = self.param['chi']*n.pi/180.

        # Spatial distortion
        if self.param['spatial'] is not None:
            self.spatial = blobcorrector.correctorclass(self.param['spatial'])

        # Generate Miller indices for reflections within a certain resolution
        print('Generating reflections')
        self.hkl = tools.genhkl(self.param['unit_cell'],
                                self.param['sysconditions'],
                                sintlmin, sintlmax,
                                self.param['crystal_system'])
        print('Finished generating reflections\n')

    def run(self):
        """Compute the reflections of the first param['no_grains'] grains.

        For every grain id in param['grain_list'] a grain container is
        appended to self.grain, with attributes ``id``, ``pos`` and ``refs``.
        ``refs`` holds one row per reflection (columns indexed via A_id),
        sorted by omega, with ref_id renumbered 0..nrefl-1. spot_id keeps
        counting across grains.

        Raises ValueError if a grain has neither U_grains_<id> nor
        UBI_grains_<id> in param.
        """
        spot_id = 0
        for grainno in range(self.param['no_grains']):
            grain_id = self.param['grain_list'][grainno]
            grain, UB = self._grain_orientation(grain_id)
            self.grain.append(grain)
            gr_pos = self.param['pos_grains_%s' % (grain_id)]
            print('GRAIN POSITION of grain %s: %s' % (grain_id, gr_pos))
            print('GRAIN NO: %s' % (grainno,))
            grain.id = grain_id
            grain.pos = gr_pos

            rows, spot_id = self._grain_reflections(grain_id, gr_pos, UB, spot_id)
            grain.refs = self._sorted_refs(rows)

    def _grain_orientation(self, grain_id):
        """Return (grain container, UB matrix) from U_grains_<id> or UBI_grains_<id>."""
        u_key = 'U_grains_%s' % (grain_id)
        ubi_key = 'UBI_grains_%s' % (grain_id)
        if u_key in self.param:
            U = n.array(self.param[u_key])
            U.shape = (3, 3)
            B_mat = n.array(tools.form_b_mat(self.param['unit_cell']))
            return variables.grain_cont(U), n.dot(U, B_mat)
        if ubi_key in self.param:
            UBI = n.array(self.param[ubi_key])
            UBI.shape = (3, 3)
            return variables.grain_cont(UBI=UBI), n.linalg.inv(UBI)*2*n.pi
        # Previously this only logged and then reused the previous grain's
        # orientation; fail loudly instead.
        logging.error('No correct grain orientaion type given')
        raise ValueError('No U or UBI orientation given for grain %s' % (grain_id,))

    def _grain_reflections(self, grain_id, gr_pos, UB, spot_id):
        """Return (rows, next_spot_id) with one row per reflection of a grain.

        A row is produced for every omega solution lying strictly inside one
        of the omega ranges whose shoebox lies on the detector.
        All angles stored in a row are in radians.
        """
        rows = []
        wavelength = self.param['wavelength']
        for hkl in self.hkl:
            Gw = n.dot(self.S, n.dot(UB, hkl))
            tth = tools.tth2(Gw, wavelength)
            costth = n.cos(tth)
            Qw = Gw*wavelength/(4.*n.pi)
            (Omega, Eta) = tools.find_omega_general(Qw, tth, self.wx, self.wy)

            for omega, eta in zip(Omega, Eta):
                for omrange in range(self.param['n_omega_ranges']):
                    om_limits = self.param['omega_range'][omrange]
                    if not ((om_limits[0]*n.pi/180) < omega < (om_limits[1]*n.pi/180)):
                        continue
                    Om = tools.form_omega_mat_general(omega, self.wx, self.wy)
                    Gt = n.dot(Om, Gw)

                    # Crystal position at present omega: only the x-y
                    # components of the grain position are rotated.
                    [tx, ty] = n.dot(Om[:2, :2], gr_pos[:2])
                    tz = gr_pos[2]

                    # Calc detector coordinate for peak
                    (dety, detz) = detector.det_coor(Gt,
                                                     costth,
                                                     wavelength,
                                                     self.param['distance'],
                                                     self.param['y_size'],
                                                     self.param['z_size'],
                                                     self.param['dety_center'],
                                                     self.param['detz_center'],
                                                     self.R,
                                                     tx, ty, tz)
                    (detyd, detzd) = self._distorted(dety, detz)

                    # If shoebox extends outside detector exclude it
                    if not self._shoebox_on_detector(detyd, detzd):
                        continue

                    frame_center, frame_limits, border = self._frame_limits(omega, omrange)
                    L, P = self._lorentz_polarization(tth, costth, eta)
                    overlaps = 0  # set the number overlaps to zero
                    rows.append([grain_id, len(rows), spot_id,
                                 hkl[0], hkl[1], hkl[2],
                                 tth, omega, eta,
                                 dety, detz,
                                 detyd, detzd,
                                 Gw[0], Gw[1], Gw[2],
                                 L, P,
                                 frame_limits[0], frame_limits[1],
                                 border, overlaps, frame_center])
                    spot_id += 1
        return rows, spot_id

    def _distorted(self, dety, detz):
        """Return the spatially distorted (detyd, detzd), or the input if no spline is set."""
        if self.param['spatial'] is None:
            return dety, detz
        # To match the coordinate system of the spline file
        return detector.distort([dety, detz],
                                self.param['o11'],
                                self.param['o12'],
                                self.param['o21'],
                                self.param['o22'],
                                self.param['dety_size'],
                                self.param['detz_size'],
                                self.spatial)

    def _shoebox_on_detector(self, detyd, detzd):
        """Return True unless the shoebox around (detyd, detzd) leaves the detector."""
        sbox_y = self.param['sbox_y']
        sbox_z = self.param['sbox_z']
        return not ((sbox_y > detyd) or
                    (detyd > self.param['dety_size'] - sbox_y) or
                    (sbox_z > detzd) or
                    (detzd > self.param['detz_size'] - sbox_z))

    def _frame_limits(self, omega, omrange):
        """Return (frame_center, [first, last], border) for a peak at omega (radians).

        If the omega box sticks out of the frame range, the offending limit is
        replaced by the negated range limit and border is -1 (start side) or
        +1 (end side); otherwise border is 0. Only one side is flagged.
        """
        frame_start = self.param['frame_range'][omrange][0]
        frame_end = self.param['frame_range'][omrange][1]
        sign = self.param['omega_sign']
        frame_center = frame_start + sign*\
            n.floor((omega*180/n.pi-self.param['omega_range'][omrange][0])/self.param['omega_step'])
        delta_sbox_omega = int((self.param['sbox_omega']-1)/2)
        frame_limits = [frame_center - sign*delta_sbox_omega,
                        frame_center + sign*delta_sbox_omega]

        # With a negative omega_sign the frame numbers run downwards, so
        # "outside" flips direction.
        if sign > 0:
            before_start = frame_limits[0] < frame_start
            after_end = frame_limits[1] > frame_end
        else:
            before_start = frame_limits[0] > frame_start
            after_end = frame_limits[1] < frame_end

        if before_start:
            # Negative value marks that the box extends outside the range
            frame_limits[0] = -1*frame_start
            border = -1
        elif after_end:
            frame_limits[1] = -1*frame_end
            border = 1
        else:
            border = 0
        return frame_center, frame_limits, border

    def _lorentz_polarization(self, tth, costth, eta):
        """Return (L, P), the Lorentz and polarization factors (angles in radians)."""
        # Polarization factor (Kahn et al, J. Appl. Cryst. (1982) 15, 330-337.)
        rho = n.pi/2.0 + eta + self.param['beampol_direct']*n.pi/180.0
        P = 0.5 * (1 + costth*costth +
                   self.param['beampol_factor']*n.cos(2*rho)*n.sin(tth)*n.sin(tth))
        # Lorentz factor
        if eta != 0:
            L = 1/(n.sin(tth)*abs(n.sin(eta)))
        else:
            L = n.inf
        return L, P

    def _sorted_refs(self, rows):
        """Return rows as an array sorted by omega, with ref_id and spot_id renumbered."""
        if not rows:
            # A grain without reflections gets an empty, correctly shaped array
            return n.zeros((0, self._NCOLS))
        A = n.array(rows)
        A = A[n.argsort(A[:, A_id['omega']]), :]  # sort rows according to omega
        A[:, A_id['ref_id']] = n.arange(A.shape[0])  # Renumber the reflections
        spot = A[:, A_id['spot_id']]
        A[:, A_id['spot_id']] = n.arange(n.min(spot), n.max(spot)+1)  # Renumber the spot_id
        return A

    def _all_refs(self):
        """Stack the refs arrays of the first param['no_grains'] grains."""
        return n.concatenate([grain.refs for grain in self.grain[:self.param['no_grains']]])

    def _grain_index(self):
        """Map each grain id to its position in self.grain."""
        return dict((int(grain.id), k) for k, grain in enumerate(self.grain))

    def _flag_overlap(self, row, grain_index):
        """Increment the 'overlaps' column of the reflection described by row."""
        grain = self.grain[grain_index[int(row[A_id['grain_id']])]]
        grain.refs[int(row[A_id['ref_id']]), A_id['overlaps']] += 1

    @staticmethod
    def _ellipse_filter(y_half, z_half):
        """Return a (2*y_half+1, 2*z_half+1) float array: 1 inside the ellipse, 0 outside."""
        yy, zz = n.ogrid[-y_half:y_half+1, -z_half:z_half+1]
        # max(..., 1) avoids 0/0 for a zero half-width (only offset 0 exists then)
        inside = (yy**2/float(max(y_half, 1))**2 +
                  zz**2/float(max(z_half, 1))**2) <= 1
        return inside.astype(float)

    @staticmethod
    def _clipped_window(center, half, size):
        """Clip the window [center-half, center+half] to the axis [0, size).

        Returns (lo, hi, f_lo, f_hi) so that data[lo:hi] and filt[f_lo:f_hi]
        (filt having length 2*half+1, centred on center) cover the same
        positions.
        """
        lo = center - half
        hi = center + half + 1
        f_lo = max(0, -lo)
        f_hi = 2*half + 1 - max(0, hi - size)
        return max(lo, 0), min(hi, size), f_lo, f_hi

    def overlap_new(self):
        """Flag reflections whose spot may be overlapped by another spot.

        Much faster than ``overlap``. It does not record which spots overlap;
        it only increments the 'overlaps' column of each overlapped
        reflection.

        For each frame range, an image stack is built holding a count of 1 at
        the (frame, detyd, detzd) centre of every reflection. For each
        reflection, the stack is summed over 2*overlap_specs[0]+1 frames,
        weighted by an elliptical mask. The mask has half-widths
        overlap_specs[1] (y) and overlap_specs[2] (z) pixels and is centred
        on the peak. A sum larger than 1 means the centre of at least one
        other spot lies inside the region considered to give overlaps.
        """
        fc_col = self._FRAME_CENTER_COL
        A = self._all_refs()
        logging.debug('Ready to compare all %i reflections', A.shape[0])
        grain_index = self._grain_index()

        om_half = int(self.param['overlap_specs'][0])
        y_half = int(self.param['overlap_specs'][1])
        z_half = int(self.param['overlap_specs'][2])
        filt = self._ellipse_filter(y_half, z_half)
        dety_size = int(self.param['dety_size'])
        detz_size = int(self.param['detz_size'])

        for frame_range in self.param['frame_range']:
            # Frame ranges run downwards for a negative omega_sign; use the
            # ascending span so both directions are handled.
            first = int(min(frame_range[0], frame_range[1]))
            last = int(max(frame_range[0], frame_range[1]))

            # Only reflections centred within this rotation range
            centers = A[:, fc_col]
            Atest = A[(centers >= first) & (centers <= last)]
            if Atest.shape[0] == 0:
                continue
            frame_idx = Atest[:, fc_col].astype(int) - first
            ys = Atest[:, A_id['detyd']].astype(int)
            zs = Atest[:, A_id['detzd']].astype(int)

            # Dense stack with one count at each reflection centre
            frames = n.zeros((last - first + 1, dety_size, detz_size), n.int8)
            for f, y, z in zip(frame_idx, ys, zs):
                frames[f, y, z] += 1

            for row, f, y, z in zip(Atest, frame_idx, ys, zs):
                y_lo, y_hi, fy_lo, fy_hi = self._clipped_window(y, y_half, dety_size)
                z_lo, z_hi, fz_lo, fz_hi = self._clipped_window(z, z_half, detz_size)
                box = frames[max(f - om_half, 0):f + om_half + 1, y_lo:y_hi, z_lo:z_hi]
                # Subtract 1 for the reflection's own centre
                no_over = (box*filt[fy_lo:fy_hi, fz_lo:fz_hi]).sum() - 1
                if no_over > 0:
                    self._flag_overlap(row, grain_index)

    def overlap(self):
        """Flag overlapping spots by comparing pairs of nearby reflections.

        Reflections of all grains are sorted by 2theta. Each one is compared
        with the preceding reflections that are less than 1 degree lower in
        2theta. A pair overlaps when:
        - the omega difference is below omega_step*sbox_omega (degrees), and
        - the distorted detector positions are closer than
          (sbox_y+sbox_z)/2 pixels.
        Both reflections of an overlapping pair get their 'overlaps' column
        incremented.
        """
        dtth = 1*n.pi/180.  # Don't compare position of refs further apart than dtth
        tth_col = A_id['tth']
        om_col = A_id['omega']
        y_col = A_id['detyd']
        z_col = A_id['detzd']

        # build one big array of reflection info of all grains
        A = self._all_refs()
        logging.debug('Finished concatenating ref arrays')
        A = A[n.argsort(A[:, tth_col]), :]  # sort rows according to tth
        logging.debug('Sorted full ref array after twotheta')
        nrefl = A.shape[0]
        grain_index = self._grain_index()
        om_tol = n.pi/180.0*self.param['omega_step']*self.param['sbox_omega']
        dist_tol = (self.param['sbox_y']+self.param['sbox_z'])/2.0

        nover = n.zeros((nrefl))
        logging.debug('Ready to compare all %i reflections', nrefl)
        for i in range(1, nrefl):
            if i % 1000 == 0:
                logging.debug('Comparing reflection %i', i)
            j = i - 1
            while j > -1 and A[i, tth_col] - A[j, tth_col] < dtth:
                if abs(A[i, om_col] - A[j, om_col]) < om_tol:
                    peak_distance = n.sqrt((A[i, y_col]-A[j, y_col])**2 +
                                           (A[i, z_col]-A[j, z_col])**2)
                    if peak_distance < dist_tol:
                        self._flag_overlap(A[i], grain_index)
                        self._flag_overlap(A[j], grain_index)
                        nover[i] = 1
                        nover[j] = 1
                j = j - 1
        print('Number of overlaps %i out of %i refl.' % (n.sum(nover), nrefl))

    def save(self, grainno=None):
        """Write one '.ref' text file per grain.

        grainno: None to save every grain in self.grain, an iterable of
            indices into self.grain, or a single integer index.

        Each file is '<direc>/<stem>_gr<id>_set0000.ref'. The first line is a
        '#' header listing the A_id column names in column-index order.
        Each following line is one reflection; tth, omega and eta are
        written in degrees.
        """
        import numbers

        if grainno is None:
            savegrains = range(len(self.grain))
        elif isinstance(grainno, numbers.Integral):
            savegrains = [grainno]
        else:
            savegrains = grainno

        fmt = "%d "*6 + "%f "*12 + "%d "*4 + "\n"
        int_head = ('grain_id', 'ref_id', 'spot_id', 'h', 'k', 'l')
        angle_cols = ('tth', 'omega', 'eta')
        float_cols = ('dety', 'detz', 'detyd', 'detzd',
                      'gv1', 'gv2', 'gv3', 'L', 'P')
        int_tail = ('frame_start', 'frame_end', 'border', 'overlaps')

        # Sort by column index so the header order is deterministic
        A_col = dict((v, k) for k, v in A_id.items())
        header = '#' + ''.join(' %s' % A_col[col] for col in sorted(A_col)) + "\n"

        for index in savegrains:
            A = self.grain[index].refs
            grain_id = self.grain[index].id
            setno = 0
            filename = '%s/%s_gr%0.4d_set%0.4d.ref' \
                % (self.param['direc'], self.param['stem'], grain_id, setno)
            with open(filename, 'w') as f:
                f.write(header)
                for row in A:
                    values = ([row[A_id[c]] for c in int_head] +
                              [row[A_id[c]]*180/n.pi for c in angle_cols] +
                              [row[A_id[c]] for c in float_cols] +
                              [row[A_id[c]] for c in int_tail])
                    f.write(fmt % tuple(values))
            


#     #save Grain
    
    
    
