
import numpy as n
from scipy.sparse import csc_matrix
from Fabric import td_dda,file_io,pcgls
from xfab import tools
from time import clock

class reconstruct:
    """
    Reconstruct the ODF of one grain from its uv-maps.

    Method dependencies: setup() builds the system matrix self.A and data
    vector self.b, run() solves A x = b with pcgls, optimal() chooses the
    stopping iteration, and save_sol()/save_odf() write the results to
    <direc>/<stem>_grNNNN.x and <direc>/<stem>_grNNNN.odf.
    """
    def __init__(self,param=None,grain=None,grainno=None,reconst=None):
        """
        param    -- dict of settings; keys read by this class are 'unit_cell',
                    'direc', 'stem', 'u_max', 'uv_bin_factor', 'odf_range',
                    'uv_sub_scale', 'u_scale' and 'niter'.  Must be given:
                    the default None fails on the 'unit_cell' lookup below.
        grain    -- container indexed by grainno; its items must provide
                    retrieve('spot_id', sid, 'h'|'k'|'l')
        grainno  -- index of the grain to reconstruct
        reconst  -- indices of the uv-maps (reflections) to use
                    (len(reconst) is taken as the number of reflections)
        """
        self.param = param
        self.grain = grain
        self.grainno = grainno
        self.B = tools.FormB(self.param['unit_cell'])
        self.reconst = reconst

    def setup(self):
        """
        Read the uv-maps, bin them and build the linear system A x = b.

        Sets self.uvmapsfull (unbinned maps), self.nr_refl, self.odf_range,
        self.odf_scale, self.uvmap (selected maps), self.b and self.A.  When
        uv_bin_factor > 1 it also sets self.uvm2 (binned maps) and self.uvm3
        (binned and rotated maps).

        Raises ValueError if the binned uv-map size is even: A is built for
        2*uv_max+1 pixels per axis and would then not match the size of b.
        """
        scale = self.param['uv_bin_factor']
        # Floor division: only whole bins are kept (same as Python 2 '/').
        subsize = (2*self.param['u_max']+1)//scale
        if subsize % 2 == 0:
            raise ValueError(
                'binned uv-map size %i is even; choose u_max and '
                'uv_bin_factor so that (2*u_max+1)//uv_bin_factor is odd'
                % subsize)

        uvmaps = file_io.readuvm(self.param['direc'],self.param['stem'])
        uvmaps.readall(self.grainno,0)
        self.uvmapsfull = uvmaps.uvm.copy()
        if scale > 1:
            no_maps = uvmaps.uvm.shape[0]
            self.uvm2 = n.zeros((no_maps,subsize,subsize))
            self.uvm3 = n.zeros((no_maps,subsize,subsize))
            for mm in range(no_maps):
                # Sum each scale x scale block of pixels into one binned pixel.
                uvm_mm = uvmaps.uvm[mm]
                for i in range(subsize):
                    for j in range(subsize):
                        self.uvm2[mm,i,j] = uvm_mm[i*scale:(i+1)*scale,
                                                   j*scale:(j+1)*scale].sum()
                # The reason the uv-maps must be rotated by 90 degrees has
                # not been determined yet; n.rot90 equals
                # transpose(fliplr(uvm)) as used in the MATLAB version.
                self.uvm3[mm] = n.rot90(self.uvm2[mm])
            uvmaps.uvm2 = self.uvm3
        # NOTE(review): with uv_bin_factor <= 1 no rotation is applied here
        # and uvmaps.uvm2 must already have been set by readall - confirm.

        self.nr_refl = len(self.reconst)
        half_width = (subsize-1)//2
        uv_max = n.array([half_width, half_width])
        print(uv_max)
        self.odf_range = n.array(self.param['odf_range'])
        uv_sub_scale = self.param['uv_sub_scale']
        # Scale factor written into the headers of the .x and .odf files.
        self.odf_scale = self.param['uv_bin_factor']*self.param['u_scale']*180.0/n.pi/2.0

        spot_id = n.array(uvmaps.spot_id)[self.reconst]
        self.uvmap = uvmaps.uvm2[self.reconst].copy()

        # Miller indices of every selected reflection, looked up by spot id.
        grain = self.grain[self.grainno]
        hkl = n.array([[grain.retrieve('spot_id',sid,'h'),
                        grain.retrieve('spot_id',sid,'k'),
                        grain.retrieve('spot_id',sid,'l')]
                       for sid in spot_id])

        self.b = make_b_odf(self.uvmap)
        self.A = make_A_odf(self.B,hkl,self.odf_range,uv_max, uv_sub_scale)

    def run(self):
        """Solve A x = b with pcgls for param['niter'] iterations.

        Sets self.x and self.r from the values pcgls.pcgls returns.
        """
        print('Reconstructing ODF.......')
        (self.x,self.r) = pcgls.pcgls(self.A,self.b,self.param['niter'],2)

    def optimal(self):
        """Choose the stopping iteration with pcgls.stoppingrule; sets self.opt."""
        uv_range = (2*self.param['u_max']+1)//self.param['uv_bin_factor']
        self.opt = pcgls.stoppingrule(self.A,
                                      self.b,
                                      self.x,
                                      uv_range,
                                      self.nr_refl)

    def save_sol(self):
        """
        Write every iterate of the solution to <direc>/<stem>_grNNNN.x.

        The file has a three-line header (ODF size, ODF scale, number of
        iterations) followed by one line per iteration holding column i of
        self.x.
        """
        filename = '%s/%s_gr%0.4d.x' %(self.param['direc'],
                                       self.param['stem'],
                                       self.grainno)
        with open(filename,'w') as fh:
            fh.write('ODF size: %i %i %i\n' %(self.odf_range[0],
                                              self.odf_range[1],
                                              self.odf_range[2]))
            fh.write('ODF scale: %f\n' %(self.odf_scale))
            fh.write('Iterations: %i\n' %(self.param['niter']))
            for i in range(self.param['niter']):
                self.x[:,i].tofile(fh,sep=' ',format='%f')
                fh.write('\n')

    def save_odf(self):
        """
        Store the solution at iteration self.opt as self.odf and write it to
        <direc>/<stem>_grNNNN.odf.

        If self.opt is missing or unusable, optimal() is called first.  The
        file has a two-line header (ODF size, ODF scale) followed by the
        values of self.odf, one first-axis slice after another, each slice
        followed by a space.
        """
        try:
            x_opt = self.x[:,int(self.opt)]
        except (AttributeError, TypeError, ValueError, IndexError):
            # No valid stopping iteration available yet: compute it now.
            self.optimal()
            x_opt = self.x[:,int(self.opt)]
        # A C-order reshape already gives the correct axes, unlike the
        # MATLAB version where a permutation was needed.
        self.odf = x_opt.reshape(self.odf_range[0],
                                 self.odf_range[1],
                                 self.odf_range[2])

        filename = '%s/%s_gr%0.4d.odf' %(self.param['direc'],
                                         self.param['stem'],
                                         self.grainno)
        with open(filename,'w') as fh:
            fh.write('ODF size: %i %i %i\n' %(self.odf_range[0],
                                              self.odf_range[1],
                                              self.odf_range[2]))
            fh.write('ODF scale: %f\n' %(self.odf_scale))
            for i in range(int(self.odf_range[0])):
                self.odf[i,:,:].tofile(fh,sep=' ',format='%f')
                fh.write(' ')

def make_b_odf(uvmap):
    """
    Stack the uv-maps into the right-hand side vector b of A x = b.

    uvmap -- array of shape (nr_refl, nu, nv), one uv-map per reflection.

    Each map is normalised so that its entries sum to 1.  It is then
    flattened row by row (C order), with the maps following each other,
    giving a vector of length nr_refl*nu*nv.  This ordering matches the rows
    of the matrix built by make_A_odf.

    Data are converted to float first, so integer maps are not truncated by
    integer division under Python 2.  A map whose sum is zero cannot be
    normalised; it is left as zeros instead of filling b with NaNs.
    """
    print('Generating b vector...')
    maps = n.asarray(uvmap, dtype=float)
    nr_refl = maps.shape[0]
    # One row per reflection, pixels in C order (row mm, column nn).
    flat = maps.reshape(nr_refl, int(n.prod(maps.shape[1:])))
    sums = flat.sum(axis=1)
    nonzero = sums != 0
    b = n.zeros(flat.shape)
    b[nonzero] = flat[nonzero]/sums[nonzero][:, n.newaxis]
    b = b.ravel()
    print('b.shape %s' % (b.shape,))
    return b
   
# MATLAB equivalent of filling one normalised uv-map D into b:
# b(1+nr_pix*(refl_nr-1):nr_pix*refl_nr) = reshape(D',1,nr_pix)




def make_A_odf(B, hkl, r_range, uv_max, uv_sub_scale):
    """
    Setup geometric matrix A relating ODF voxel grid x and uv-maps b

    B            -- matrix applied to each hkl row, G = B.hkl
                    (presumably 3x3)
    hkl          -- (nr_refl, 3) array of Miller indices, one row per uv-map
    r_range      -- numpy array with the number of ODF voxels along each of
                    the three axes
    uv_max       -- numpy array with the half width (in pixels) of a uv-map
                    along u and v; a map has 2*uv_max+1 pixels per axis
    uv_sub_scale -- number of sub-rays per pixel along each of u and v

    Row refl_nr*nu*nv + i*nv + j of A corresponds to pixel (i, j) of uv-map
    refl_nr, matching the C-order layout produced by make_b_odf.  Column
    (ix*r_range[1] + iy)*r_range[2] + iz corresponds to ODF voxel
    (ix, iy, iz).  Each entry sums column 3 of the td_dda.td_dda output
    over all sub-rays of the pixel, divided by uv_sub_scale**2.

    Returns a scipy.sparse csc_matrix of shape (nr_refl*nu*nv, prod(r_range)).
    Raises ValueError if B.hkl is zero or parallel to z, because the u axis
    of the projection plane is then undefined.
    """
    uv_range = uv_max*2+1
    nu = int(uv_range[0])
    nv = int(uv_range[1])
    nr_pix = nu*nv
    nr_refl = hkl.shape[0]
    # Centre the sub-rays symmetrically inside each pixel.  This is identical
    # to the former n.round((s-1)/2) for odd s; that rounding shifted the
    # sampling off-centre for even s.
    uv_sub_center = (uv_sub_scale-1.)/2.
    axes = n.ones(3)
    grid_origo = 0.5*r_range
    N = int(n.prod(r_range))

    z = n.array([0,0,1],dtype=float)
    print('Generating A matrix...')
    # COO triplets; the csc_matrix constructor sums duplicate (row, col)
    # pairs, which replaces slow element-wise += on a CSC matrix.
    rows = []
    cols = []
    vals = []
    for refl_nr in range(nr_refl):
        print('Projecting uvmap no %i of %i' %(refl_nr,nr_refl))
        t1 = clock()
        Gc0 = n.dot(B,hkl[refl_nr,:])
        y0 = Gc0/n.linalg.norm(Gc0)
        # u, v span the plane perpendicular to y0 in which the uv-map lies.
        u = n.cross(y0, z)
        u_norm = n.linalg.norm(u)
        if not u_norm > 0:  # also catches NaN from a zero-length G
            raise ValueError('reflection %i: B.hkl is zero or parallel to z; '
                             'uv axes are undefined' % refl_nr)
        u = u/u_norm
        v = n.cross(u, y0)
        for i in range(nu):
            for j in range(nv):
                mm = refl_nr*nr_pix + i*nv + j
                for i1 in range(uv_sub_scale):
                    for j1 in range(uv_sub_scale):
                        dy = (i-uv_max[0]+(i1-uv_sub_center)/uv_sub_scale)*u +\
                             (j-uv_max[1]+(j1-uv_sub_center)/uv_sub_scale)*v
                        # Ray through the voxel grid from y0+dy to -y0+dy.
                        pixels = td_dda.td_dda(y0+dy,-y0+dy,axes,grid_origo,r_range)
                        if pixels is None:
                            continue
                        pixels = n.asarray(pixels)
                        if pixels.shape[0] == 0:
                            continue
                        # Assumes columns 0-2 hold voxel indices and column 3
                        # a weight (presumably path length) - TODO confirm.
                        vox = pixels[:,:3].astype(int)
                        nn = ((vox[:,0]*r_range[1] + vox[:,1])*r_range[2] +
                              vox[:,2]).astype(int)
                        rows.append(n.repeat(mm, nn.shape[0]))
                        cols.append(nn)
                        vals.append(pixels[:,3])
        print('That took %f seconds' %(clock()-t1))

    shape = (nr_refl*nr_pix, N)
    if not vals:
        return csc_matrix(shape)
    # Average over the uv_sub_scale x uv_sub_scale sub-rays of each pixel.
    data = n.concatenate(vals)/float(uv_sub_scale*uv_sub_scale)
    A = csc_matrix((data, (n.concatenate(rows), n.concatenate(cols))),
                   shape=shape)
    return A
