#from numpy.oldnumeric import *
import numpy as n
from xfab import tools,detector
import variables
import sys
from ImageD11 import blobcorrector
import logging
# Configure the root logger: show every message from DEBUG level upwards,
# formatted as "<level name> <message>".
logging.basicConfig(level=logging.DEBUG,format='%(levelname)s %(message)s')

# Lookup used throughout this module to address the columns of the
# reflection arrays by name (e.g. A_id['omega']); obtained from
# variables.refarray().
A_id = variables.refarray().A_id

class find_refl:
    """Compute, flag and save the reflections of a set of grains.

    The instance keeps the input parameters in ``self.param`` and one
    container per grain in ``self.grain``.  ``run()`` fills
    ``self.grain[i].refs`` with one row per reflection, ``overlap()``
    counts spatially overlapping reflections in the 'overlaps' column of
    these arrays and ``save()`` writes the arrays to text files.

    Columns of the reflection arrays are addressed by name through the
    module-level ``A_id`` lookup.
    """

    # Number of values in one reflection row, see the list built in run()
    # and the output format used in save().
    _NCOLS = 22

    def __init__(self,param):
        """Store the parameters and generate the Miller indices.

        param -- mapping with the input parameters.  The constructor reads
                 'theta_min' and 'theta_max' (converted from degrees to
                 radians here), 'wavelength', 'unit_cell', 'sbox_y',
                 'sbox_z', 'tilt_x', 'tilt_y', 'tilt_z', 'spatial' and
                 'sysconditions'; the other methods read further keys.
        """
        self.param = param
        self.grain = []

        # Simple transforms of input and set constants
        wavelength = self.param['wavelength']
        sintlmin = n.sin(self.param['theta_min']*n.pi/180)/wavelength
        sintlmax = n.sin(self.param['theta_max']*n.pi/180)/wavelength
        self.K = -2*n.pi/wavelength
        self.S = n.array([[1, 0, 0],[0, 1, 0],[0, 0, 1]])
        self.B = n.array(tools.FormB(self.param['unit_cell']))
        self.V = n.array(tools.CellVolume(self.param['unit_cell']))
        # NOTE(review): with integer box sizes this is an integer division
        # under Python 2 but a true division under Python 3 - confirm which
        # one is wanted before porting.
        self.sbox_y_half = (self.param['sbox_y']-1)/2
        self.sbox_z_half = (self.param['sbox_z']-1)/2

        # Detector tilt correction matrix
        self.R = tools.detect_tilt(self.param['tilt_x'],
                                   self.param['tilt_y'],
                                   self.param['tilt_z'])

        # Spatial distortion; self.spatial only exists when a value is given
        if self.param['spatial'] is not None:
            self.spatial = blobcorrector.correctorclass(self.param['spatial'])

        # Generate Miller indices for reflections within a certain resolution
        print('Generating reflections')
        self.hkl = tools.genhkl(self.param['unit_cell'],
                                self.param['sysconditions'],
                                sintlmin,
                                sintlmax)
        print('Finished generating reflections\n')

    def _add_grain(self,grainno):
        """Append the container of grain `grainno` and return its UB matrix.

        The orientation is read from param['U_grains_<grainno>'] or, if
        that key is absent, from param['UBI_grains_<grainno>'] (nine
        values each, reshaped to 3x3).

        Raises ValueError if neither key is present.
        """
        u_key = 'U_grains_%s' %(grainno)
        ubi_key = 'UBI_grains_%s' %(grainno)
        if u_key in self.param:
            U = n.array(self.param[u_key]).reshape(3,3)
            self.grain.append(variables.grain_cont(U))
            return n.dot(U,self.B)
        if ubi_key in self.param:
            UBI = n.array(self.param[ubi_key]).reshape(3,3)
            self.grain.append(variables.grain_cont(UBI=UBI))
            return 2*n.pi*n.linalg.inv(UBI)
        # Without an orientation nothing sensible can be calculated, so
        # stop here instead of continuing with undefined matrices.
        logging.error('No correct grain orientaion type given')
        raise ValueError('No U_grains_%s or UBI_grains_%s entry in the input'
                         %(grainno,grainno))

    def _frame_limits(self,omega,omrange):
        """Return ([first frame, last frame], border) of a reflection.

        omega   -- rotation angle of the reflection in radians
        omrange -- index into param['omega_range'] / param['frame_range']

        The box spans int((sbox_omega-1)/2) frames on either side of the
        frame containing omega.  If the box extends beyond the start of
        the frame range the first limit is replaced by minus the first
        frame number and border is -1; otherwise, if it extends beyond the
        end, the last limit is replaced by minus the last frame number and
        border is 1.  border is 0 when the box is inside the range.
        """
        sign = self.param['omega_sign']
        frame_range = self.param['frame_range'][omrange]
        frame_center = frame_range[0] + sign*\
            n.floor((omega*180/n.pi-self.param['omega_range'][omrange][0])/
                    self.param['omega_step'])
        delta_sbox_omega = int((self.param['sbox_omega']-1)/2)
        limits = [frame_center - sign*delta_sbox_omega,
                  frame_center + sign*delta_sbox_omega]

        # The frame numbers increase or decrease with omega depending on
        # the sign, which flips the direction of the range tests.
        if sign > 0:
            before_start = limits[0] < frame_range[0]
            after_end = limits[1] > frame_range[1]
        else:
            before_start = limits[0] > frame_range[0]
            after_end = limits[1] < frame_range[1]

        if before_start:
            # negative value marks that the box extends outside the range
            limits[0] = -1*frame_range[0]
            border = -1
        elif after_end:
            limits[1] = -1*frame_range[1]
            border = 1
        else:
            border = 0
        return limits, border

    def run(self):
        """Calculate the reflections of every grain.

        For each grain a container is appended to self.grain, its `pos`
        attribute is set from param['pos_grains'] and its `refs`
        attribute is set to a 2D array with one row per accepted
        reflection:

            grain no, reflection no, spot id, h, k, l, tth, omega, eta,
            dety, detz, detyd, detzd, Gw[0], Gw[1], Gw[2], Lorentz factor,
            polarisation factor, first frame, last frame, border, overlaps

        tth, omega and eta are stored in radians (save() converts them to
        degrees).  A reflection is accepted when omega lies strictly
        inside one of param['omega_range'] (given in degrees) and its
        distorted detector position is at least one box size away from
        the detector edges.  The rows are sorted by omega and the
        reflection numbers and spot ids are renumbered in that order.  A
        grain without accepted reflections gets an empty (0, 22) array.

        Raises ValueError if a grain has no orientation in the input.
        """
        wavelength = self.param['wavelength']
        spot_id = 0
        for grainno in range(self.param['no_grains']):
            UB = self._add_grain(grainno)

            gr_pos = n.array(self.param['pos_grains'][grainno])
            print('GRAIN POSITION of grain  %s :  %s' %(grainno,gr_pos))
            print('GRAIN NO:  %s' %(grainno))
            self.grain[grainno].pos = gr_pos

            rows = []
            nrefl = 0
            for hkl in self.hkl:
                Gw = n.dot(self.S,n.dot(UB,hkl))
                tth = tools.tth2(Gw,wavelength)
                costth = n.cos(tth)
                sintth = n.sin(tth)

                (Omega, Eta) = tools.find_omega_wedge(Gw,tth,self.param['wedge'])

                for (omega, eta) in zip(Omega,Eta):
                    for omrange in range(self.param['n_omega_ranges']):
                        omega_range = self.param['omega_range'][omrange]
                        if not (omega_range[0]*n.pi/180 < omega
                                < omega_range[1]*n.pi/180):
                            continue

                        Om = tools.OMEGA(omega)
                        Gt = n.dot(Om,Gw)

                        # Calc crystal position at present omega
                        [tx,ty] = n.dot(Om[:2,:2],gr_pos[:2])
                        tz = gr_pos[2]

                        # Calc detector coordinate for peak
                        (dety, detz) = detector.det_coor(Gt,
                                                         costth,
                                                         wavelength,
                                                         self.param['distance'],
                                                         self.param['y_size'],
                                                         self.param['z_size'],
                                                         self.param['dety_center'],
                                                         self.param['detz_center'],
                                                         self.R,
                                                         tx,ty,tz)

                        if self.param['spatial'] is not None:
                            # To match the coordinate system of the spline file
                            (detyd,detzd) = detector.distort([dety,detz],
                                                             self.param['o11'],
                                                             self.param['o12'],
                                                             self.param['o21'],
                                                             self.param['o22'],
                                                             self.param['dety_size'],
                                                             self.param['detz_size'],
                                                             self.spatial)
                        else:
                            detyd = dety
                            detzd = detz

                        # If shoebox extends outside detector exclude it
                        if (self.param['sbox_y'] > detyd) or \
                                (detyd > self.param['dety_size']-self.param['sbox_y']) or \
                                (self.param['sbox_z'] > detzd) or \
                                (detzd > self.param['detz_size']-self.param['sbox_z']):
                            continue

                        (frame_limits, border) = self._frame_limits(omega,omrange)

                        # Polarization factor
                        # (Kahn et al, J. Appl. Cryst. (1982) 15, 330-337.)
                        rho = n.pi/2.0 + eta + self.param['beampol_direct']*n.pi/180.0
                        P = 0.5 * (1 + costth*costth +\
                                   self.param['beampol_factor']*n.cos(2*rho)*sintth*sintth)

                        # Lorentz factor
                        if eta != 0:
                            L = 1/(sintth*abs(n.sin(eta)))
                        else:
                            L = n.inf

                        overlaps = 0 # the overlaps are counted in overlap()
                        rows.append([grainno,nrefl,spot_id,
                                     hkl[0],hkl[1],hkl[2],
                                     tth,omega,eta,
                                     dety,detz,
                                     detyd,detzd,
                                     Gw[0],Gw[1],Gw[2],
                                     L,P,
                                     frame_limits[0],frame_limits[1],
                                     border,overlaps])
                        nrefl   += 1
                        spot_id += 1

            if rows:
                refs = n.array(rows)
                # sort rows according to omega
                refs = refs[n.argsort(refs[:,A_id['omega']]),:]
                # Renumber the reflections and the spot ids in omega order;
                # this grain owns the last nrefl spot ids handed out.
                refs[:,A_id['ref_id']] = n.arange(nrefl)
                refs[:,A_id['spot_id']] = n.arange(spot_id-nrefl,spot_id)
            else:
                logging.warning('No reflections found for grain %i',grainno)
                refs = n.zeros((0,self._NCOLS))
            self.grain[grainno].refs = refs

    def overlap(self):
        """Count the overlapping reflections of all grains.

        Two reflections are considered to overlap when they differ by
        less than one degree in tth, by less than omega_step*sbox_omega
        degrees in omega, and their distorted detector positions are
        closer than (sbox_y+sbox_z)/2.  For every such pair the 'overlaps'
        column of both reflections in self.grain[...].refs is incremented.

        The number of reflections taking part in at least one overlap is
        printed.  run() must have been called first.
        """
        dtth = 1*n.pi/180.  # Don't compare position of refs further apart than dtth
        domega = n.pi/180.0*self.param['omega_step']*self.param['sbox_omega']
        max_distance = (self.param['sbox_y']+self.param['sbox_z'])/2.0

        c_tth = A_id['tth']
        c_omega = A_id['omega']
        c_detyd = A_id['detyd']
        c_detzd = A_id['detzd']
        c_grain = A_id['grain_id']
        c_ref = A_id['ref_id']
        c_over = A_id['overlaps']

        # build one big array of reflection info of all grains
        A = n.concatenate([self.grain[grainno].refs
                           for grainno in range(self.param['no_grains'])])
        logging.debug('Finished concatenating ref arrays')
        A = A[n.argsort(A[:,c_tth]),:] # sort rows according to tth
        logging.debug('Sorted full ref array after twotheta')
        nrefl = A.shape[0]

        # True for the rows of A that overlap with at least one other row
        has_overlap = n.zeros(nrefl,dtype=bool)
        logging.debug('Ready to compare all %i reflections',nrefl)
        for i in range(1,nrefl):
            if i%1000 == 0:
                logging.debug('Comparing reflection %i', i)
            # A is sorted by tth, so walk backwards until the gap gets too big
            j = i-1
            while j > -1 and A[i,c_tth]-A[j,c_tth] < dtth:
                if abs(A[i,c_omega]-A[j,c_omega]) < domega:
                    peak_distance = n.sqrt((A[i,c_detyd]-A[j,c_detyd])**2+\
                                           (A[i,c_detzd]-A[j,c_detzd])**2)
                    if peak_distance < max_distance:
                        # The ids are stored as floats in A and have to be
                        # converted before they can be used as indices.
                        for row in (i,j):
                            grain_id = int(A[row,c_grain])
                            ref_id = int(A[row,c_ref])
                            self.grain[grain_id].refs[ref_id,c_over] += 1
                            has_overlap[row] = True
                j = j - 1
        print('Number of overlaps %i out of %i refl.' %(n.sum(has_overlap),nrefl))

    def save(self,grainno=None):
        """Write the reflections of the grains to text files.

        grainno -- None (default) to save all grains, a single grain
                   number, or a sequence of grain numbers.

        One file '<direc>/<stem>_gr<grain no>_set0000.ref' is written per
        grain.  It starts with a '#' line holding the column names in
        column order, followed by one line per reflection with tth, omega
        and eta converted to degrees.
        """
        if grainno is None:
            savegrains = range(len(self.grain))
        elif isinstance(grainno,(int,n.integer)):
            savegrains = [grainno]
        else:
            savegrains = grainno

        # Header: the column names ordered by their column index
        A_col = dict((v,k) for (k,v) in A_id.items())
        header = "#" + "".join(' %s' %A_col[col] for col in sorted(A_col)) + "\n"

        row_format = "%d "*6 + "%f "*12 + "%d "*4 + "\n"
        columns = ('grain_id','ref_id','spot_id','h','k','l',
                   'tth','omega','eta',
                   'dety','detz','detyd','detzd',
                   'gv1','gv2','gv3',
                   'L','P',
                   'frame_start','frame_end','border','overlaps')
        angles = ('tth','omega','eta')  # stored in radians, written in degrees
        indices = [A_id[name] for name in columns]
        in_radians = [name in angles for name in columns]

        for grainno in savegrains:
            A = self.grain[grainno].refs
            setno = 0
            filename = '%s/%s_gr%0.4d_set%0.4d.ref' \
                %(self.param['direc'],self.param['stem'],grainno,setno)
            # the with statement closes the file even if a write fails
            with open(filename,'w') as f:
                f.write(header)
                for row in A:
                    values = []
                    for (idx, is_angle) in zip(indices,in_radians):
                        if is_angle:
                            values.append(row[idx]*180/n.pi)
                        else:
                            values.append(row[idx])
                    f.write(row_format %tuple(values))
            


#     #save Grain
    
    
    
