import numpy as n

from xfab import tools
from xfab import detector
from fabio import edfimage,tifimage

import time 
import sys


class make_profiles(object):
    """Simulate reflection profiles ("shoeboxes") and write them to a ``.prf`` file.

    Intended call order: ``readref()`` loads the reflection list,
    ``setup_odf()`` builds the orientation distribution (ODF) map and the
    orientation matrix of every ODF voxel, and ``run()`` projects every ODF
    voxel of every reflection onto the detector and writes the shoeboxes.

    NOTE(review): ``file_io`` (used by ``readref``) and ``generate_grains``
    (used by ``setup_odf`` for ``odf_type == 1``) are not imported in this
    file.  Those code paths raise NameError unless the modules are imported.
    """

    # Intensity added to a shoebox voxel by an ODF voxel of weight 1.
    INTENSITY = 1000

    def __init__(self, param, grainno=0, setno=0):
        """Store the parameters and derive constants used by the simulation.

        :param param: parameter dictionary (keys such as 'wavelength',
            'tilt_x', 'tilt_y', 'tilt_z', 'direc', 'stem', ... are read).
        :param grainno: grain number used to locate the reflection file and
            to name the output file.
        :param setno: set number used the same way as ``grainno``.
        """
        self.param = param
        self.grainno = grainno
        self.setno = setno
        # -2*pi/wavelength; its absolute value enters the Bragg-angle formula.
        self.K = -2 * n.pi / self.param['wavelength']
        # Tilt matrix; passed as-is to detector.det_coor.
        self.R = tools.detect_tilt(self.param['tilt_x'],
                                   self.param['tilt_y'],
                                   self.param['tilt_z'])

    def readref(self):
        """Load the reflection list for (self.grainno, self.setno) into ``self.ref``.

        Afterwards ``self.ref.refinfo`` (2-D array, one row per reflection)
        and ``self.ref.colinfo`` (column name -> column index) are available.
        """
        self.ref = file_io.readref(self.param['direc'], self.param['stem'])
        self.ref._open(self.grainno, self.setno)
        self.ref.read(self.grainno, self.setno)

    def setup_odf(self):
        """Build the ODF map and the orientation matrix of every ODF voxel.

        ``param['odf_type']`` selects the ODF source:

        * 1 -- map from ``generate_grains.gen_odf`` with
          ``sigma = (mosaicity / 4) / odf_scale`` voxels, cut at 3 sigma;
        * 2 -- map read from ``param['odf_file']`` (the format written by
          this method), optionally sub-sampled by ``param['odf_sub_sample']``;
        * 3 -- binary sphere (1 inside the radius, 0 outside), for debugging.

        Sets ``self.odf``, ``self.odf_cut`` (voxels with weight <= this value
        are skipped by ``run``) and ``self.Uodf``.  For types 1 and 3 the map
        is also saved to ``param['stem'] + '.odf'`` so it can be reloaded
        with type 2.

        :returns: ``self.Uodf``, an array of shape ``self.odf.shape + (3, 3)``.
        :raises ValueError: if ``odf_type`` is not 1, 2 or 3.
        """
        odf_type = self.param['odf_type']
        odf_scale = self.param['odf_scale']

        if odf_type == 1:
            self.odf, odf_center = self._gaussian_odf(odf_scale)
        elif odf_type == 3:
            self.odf, odf_center = self._sphere_odf(odf_scale)
        elif odf_type == 2:
            odf_file = self.param['odf_file']
            print('Read ODF from file_ %s' % odf_file)
            self.odf, odf_scale = self._read_odf_file(odf_file)
            if self.param['odf_sub_sample'] > 1:
                sub = int(self.param['odf_sub_sample'])
                print('subscale = %s' % sub)
                self.odf = self._subsample_odf(self.odf, sub)
                odf_scale = odf_scale / float(sub)
                print('odf_scale %s' % odf_scale)
            odf_center = [dim // 2 for dim in self.odf.shape]
            print(odf_center)
            print(self.odf.shape)
        else:
            raise ValueError('Unknown odf_type %r (expected 1, 2 or 3)' % (odf_type,))

        if self.param['odf_cut'] is not None:
            self.odf_cut = self.odf.max() * self.param['odf_cut']
        else:
            self.odf_cut = 0.0

        # odf_scale is the angular step per voxel in degrees; the offset of
        # each voxel from the ODF centre is converted to radians and turned
        # into a 3x3 matrix by tools.rod2U.
        to_rad = odf_scale * n.pi / 180.
        center = n.asarray(odf_center, dtype=float)
        self.Uodf = n.zeros(self.odf.shape + (3, 3))
        for idx in n.ndindex(*self.odf.shape):
            self.Uodf[idx] = tools.rod2U(to_rad * (n.array(idx) - center))

        if odf_type != 2:
            self._write_odf(odf_scale)

        return self.Uodf

    def _gaussian_odf(self, odf_scale):
        """Return ``(odf, center)`` for odf_type 1, built by ``generate_grains.gen_odf``."""
        spread_grid = (self.param['mosaicity'] / 4.0) / odf_scale
        r_max = int(n.ceil(3 * spread_grid))
        r_range = 2 * r_max + 1
        sigma = spread_grid * n.ones(3)
        # Integer sizes so gen_odf can use them as array dimensions.
        mapsize = r_range * n.ones(3, dtype=int)
        center = r_max * n.ones(3)
        print('size of ODF map %s' % (mapsize,))
        return generate_grains.gen_odf(sigma, center, mapsize), center

    def _sphere_odf(self, odf_scale):
        """Return ``(odf, center)`` for odf_type 3: 1 inside a sphere of radius r_max, else 0."""
        spread_grid = (self.param['mosaicity'] / 4.0) / odf_scale
        r_max = int(n.ceil(3 * spread_grid))
        r_range = 2 * r_max + 1
        print('size of ODF map %s' % (r_range * n.ones(3),))
        offsets = n.indices((r_range, r_range, r_range)) - r_max
        # Compare squared integer distances to avoid sqrt rounding at the edge.
        odf = ((offsets ** 2).sum(axis=0) <= r_max ** 2).astype(float)
        return odf, r_max * n.ones(3)

    @staticmethod
    def _read_odf_file(path):
        """Read an ODF written by ``_write_odf``; return ``(odf, odf_scale)``.

        Expected layout: ``ODF size: n1 n2 n3``, then ``ODF scale: s``, then
        all voxel values whitespace-separated on the following line.
        """
        with open(path, 'r') as fh:
            dims = [int(v) for v in fh.readline().split(':', 1)[1].split()]
            scale = float(fh.readline().split(':', 1)[1])
            values = n.array(fh.readline().split(), dtype=float)
        count = dims[0] * dims[1] * dims[2]
        return values[:count].reshape(dims), scale

    @staticmethod
    def _subsample_odf(odf, sub):
        """Split every voxel into sub**3 voxels, each carrying 1/sub**3 of its weight."""
        fine = odf.repeat(sub, axis=0).repeat(sub, axis=1).repeat(sub, axis=2)
        return fine / float(sub ** 3)

    def _write_odf(self, odf_scale):
        """Save ``self.odf`` to ``param['stem'] + '.odf'`` in the format ``_read_odf_file`` reads."""
        with open(self.param['stem'] + '.odf', 'w') as fh:
            fh.write('ODF size: %i %i %i\n' % tuple(self.odf.shape))
            fh.write('ODF scale: %f\n' % (odf_scale))
            for plane in self.odf:
                plane.tofile(fh, sep=' ', format='%f')
                fh.write(' ')

    def run(self):
        """Simulate every reflection of every grain and write the shoeboxes.

        Output goes to ``<direc>/<stem>_grNNNN_setNNNN.prf``: a ``#`` header
        naming the shoebox dimensions, a line with their values, then one
        record per reflection written by ``writeprf``.

        Requires ``readref()`` and ``setup_odf()`` to have been called.
        """
        p = self.param
        filename = '%s/%s_gr%0.4d_set%0.4d.prf' \
            % (p['direc'], p['stem'], self.grainno, self.setno)
        param_out = ['sbox_omega', 'sbox_y', 'sbox_z']
        B = tools.FormB(p['unit_cell'])  # same unit cell for every grain

        # The file stays open for all grains.  Previously it was closed at the
        # end of the first grain, so a second grain wrote to a closed file.
        with open(filename, 'w') as prf_out:
            prf_out.write('# ' + ''.join('%s ' % name for name in param_out) + '\n')
            prf_out.write(''.join('%i ' % p[name] for name in param_out) + '\n')

            # NOTE(review): every grain is simulated against the same
            # reflection list (read for self.grainno) -- confirm intended.
            for grainno in range(p['no_grains']):
                gr_pos = n.array(p['pos_grains'][grainno])
                U = n.array(p['U_grains_%s' % grainno]).reshape(3, 3)
                for nref in range(len(self.ref.refinfo)):
                    print('Doing reflection: %i' % nref)
                    shoebox = self._simulate_shoebox(nref, grainno, U, B, gr_pos)
                    self.writeprf(prf_out, shoebox, grainno, nref)

    def _simulate_shoebox(self, nref, grainno, U, B, gr_pos):
        """Return the float32 (sbox_omega, sbox_y, sbox_z) shoebox of reflection ``nref``.

        Each ODF voxel above ``self.odf_cut`` rotates the reflection's
        scattering vector.  The resulting spot is located on the detector and,
        if it falls inside the box centred on the tabulated reflection
        position, adds ``INTENSITY * weight`` to the matching box voxel.
        """
        p = self.param
        row = self.ref.refinfo[nref]
        col = self.ref.colinfo
        sbox_omega, sbox_y, sbox_z = p['sbox_omega'], p['sbox_y'], p['sbox_z']
        shoebox = n.zeros((sbox_omega, sbox_y, sbox_z), dtype=n.float32)

        # Lower corner of the box on the detector, centred on the tabulated spot.
        y0 = int(round(row[col['dety']])) - int((sbox_y - 1) // 2)
        z0 = int(round(row[col['detz']])) - int((sbox_z - 1) // 2)
        omega_ref = row[col['omega']]  # degrees, compared with Omega in degrees
        frame_start = row[col['frame_start']]
        omega_start = p['omega_range'][0][0]
        omega_end = p['omega_range'][0][1]
        omega_min = omega_start * n.pi / 180
        omega_max = omega_end * n.pi / 180
        two_k = 2 * abs(self.K)

        hkl = n.array([row[col['h']], row[col['k']], row[col['l']]])
        Gc = n.dot(B, hkl)

        # n.argwhere yields voxels in the same C order as nested i/j/k loops,
        # so the accumulation order into the shoebox is unchanged.
        for i, j, k in n.argwhere(self.odf > self.odf_cut):
            weight = self.odf[i, j, k]
            Gw = n.dot(U, n.dot(self.Uodf[i, j, k], Gc))
            tth = 2 * n.arcsin(n.sqrt(n.dot(Gw, Gw)) / two_k)
            costth = n.cos(tth)
            (Omega, eta) = tools.find_omega_wedge(Gw, tth, p['wedge'])
            Omega = n.asarray(Omega)
            # Check for "no solution" first; the old bare except could leave
            # minpos stale from a previous voxel.
            if len(Omega) == 0:
                continue
            # Of the possible solutions keep the one closest to the tabulated omega.
            omega = Omega[n.argmin(n.abs(Omega * (180.0 / n.pi) - omega_ref))]
            if omega < omega_min or omega > omega_max:
                continue
            Om = tools.OMEGA(omega)
            Gt = n.dot(Om, Gw)

            # Crystal position at the present omega.
            tx, ty = n.dot(Om[:2, :2], gr_pos[:2])
            tz = gr_pos[2]

            # NOTE: spatial distortion is not applied here.
            (dety, detz) = detector.det_coor(Gt,
                                             costth,
                                             p['wavelength'],
                                             p['distance'],
                                             p['y_size'],
                                             p['z_size'],
                                             p['dety_center'],
                                             p['detz_center'],
                                             self.R,
                                             tx, ty, tz)

            iy = int(round(dety)) - y0
            iz = int(round(detz)) - z0
            frame_no = p['start_frame'][grainno] + \
                n.floor((omega * 180 / n.pi - omega_start) / p['omega_step'])
            wlayer = int(frame_no - frame_start)

            if 0 <= iy < sbox_y and 0 <= iz < sbox_z and 0 <= wlayer < sbox_omega:
                shoebox[wlayer, iy, iz] += self.INTENSITY * weight
        return shoebox

    def correct_image(self):
        """Post-process and write every frame in ``self.frames``.

        Adds ``param['bg']``, Poisson noise and a Gaussian point-spread
        function as configured.  Each frame is then clipped to the 16-bit
        range, converted to uint16, reoriented with
        ``detector.trans_orientation`` and written in the requested formats.

        NOTE(review): ``self.frames``, ``self.write_edf`` and
        ``self.write_tif`` are not defined in this class; they must be
        provided elsewhere before this method can run.
        """
        # Previously ndimage was only imported inside run(), so psf != 0
        # raised NameError here.
        from scipy import ndimage

        for frame_no in self.frames:
            t_start = time.time()  # time.clock() was removed in Python 3.8

            frame = self.frames[frame_no].toarray()
            if self.param['bg'] > 0:
                frame = frame + self.param['bg']
            if self.param['noise'] != 0:
                frame = n.random.poisson(frame)
            if self.param['psf'] != 0:
                frame = ndimage.gaussian_filter(frame, self.param['psf'] * 0.5)
            # Limit values to the 16-bit range and convert to integers.
            frame = n.uint16(n.clip(frame, 0, 2 ** 16 - 1))
            # Flip detector orientation according to input: o11, o12, o21, o22.
            frame = detector.trans_orientation(frame,
                                               self.param['o11'],
                                               self.param['o12'],
                                               self.param['o21'],
                                               self.param['o22'],
                                               'inverse')
            if '.edf' in self.param['output']:
                self.write_edf(frame_no, frame)
            if '.tif' in self.param['output']:
                self.write_tif(frame_no, frame)
            sys.stdout.write('\rDone frame %i took %8f s'
                             % (frame_no + 1, time.time() - t_start))
            sys.stdout.flush()

    def writeprf(self, file=None, peak=None, grainno=None, reflno=None, format='%i'):
        """Append one reflection record to an open ``.prf`` file.

        The record has a ``REFL_ID = ...`` line, a ``SPOT_ID = ...`` line,
        then all shoebox values on one space-separated line.  Omega layers
        vary slowest, then detector y, with detector z fastest.  It can be
        read back with
        ``n.array(line.split(), float).reshape(sbox_omega, sbox_y, sbox_z)``.

        :param file: open, writable file object (must be a real file:
            ``ndarray.tofile`` writes to its descriptor).
        :param peak: indexable of ``sbox_omega`` layers of shape
            ``(sbox_y, sbox_z)``.  An empty layer is written as zeros
            without modifying ``peak``.
        :param grainno: unused; kept for interface compatibility.
        :param reflno: row of ``self.ref.refinfo`` holding the IDs.
        :param format: unused; kept for interface compatibility.
        """
        row = self.ref.refinfo[reflno]
        col = self.ref.colinfo
        file.write('REFL_ID = %i\n' % row[col['ref_id']])
        file.write('SPOT_ID = %i\n' % row[col['spot_id']])
        for i in range(self.param['sbox_omega']):
            layer = n.asarray(peak[i])
            if len(layer) == 0:
                layer = n.zeros((self.param['sbox_y'], self.param['sbox_z']))
            layer.tofile(file, sep=' ')
            file.write(' ')
        file.write('\n')


