import numpy as n
import logging
logging.basicConfig(level=logging.DEBUG,format='%(levelname)s %(message)s')

from fabio import openimage
import variables
    
# Column-index lookup for the reflection arrays (grain.refs), keyed by column
# name; used throughout as refs[reflno, A_id['frame_start']] etc.
A_id = variables.refarray().A_id

class harvest_refl:
    """Harvest shoebox intensities around reflections from detector images.

    For every reflection of every grain, the pixel region ("shoebox") around
    its detector position is cut out of each frame the reflection spans.
    The stacked shoebox is then written to a per-grain ``.hst`` text file.

    Parameters
    ----------
    param : dict
        Run parameters. Keys read by this class: ``'flood'``, ``'dark'``,
        ``'darkoffset'``, ``'omega_sign'``, ``'frame_numbers'``,
        ``'sbox_omega'``, ``'sbox_y'``, ``'sbox_z'``, ``'direc'``, ``'stem'``,
        and the detector orientation entries ``'o11'``, ``'o12'``,
        ``'o21'``, ``'o22'``.
    grain : sequence
        Grain objects. Each exposes ``.id`` and a 2-D ``.refs`` array whose
        columns are addressed through the module-level ``A_id`` lookup.
    frames : sequence
        Frame objects. Each exposes ``.name``, ``.nrefl`` and a ``.refs``
        list; the latter two are filled in by :meth:`frameinfo`.
    """

    def __init__(self, param=None, grain=None, frames=None):
        self.param = param
        self.grain = grain
        self.frames = frames
        if self.grain is not None:
            self.ngrains = len(self.grain)
        # Optional flood-field and dark images, given as file names.
        if self.param['flood'] is not None:
            self.flood = openimage.openimage(self.param['flood']).data
        if self.param['dark'] is not None:
            self.dark = openimage.openimage(self.param['dark']).data

        # printf-style format used when writing pixels of a given dtype.
        self.numtype = {n.dtype('int32')  : '%i',
                        n.dtype('int64')  : '%i',
                        n.dtype('uint16') : '%i',
                        n.dtype('uint32') : '%i',
                        n.dtype('uint64') : '%i',
                        n.dtype('float')  : '%f',
                        n.dtype('float32'): '%f',
                        n.dtype('float64'): '%f'}
        # Unsigned dtypes that must be promoted before dark subtraction.
        self.uint_types = [n.dtype('uint16'),
                           n.dtype('uint32'),
                           n.dtype('uint64')]

    def frameinfo(self):
        """Register, on each frame, the reflections whose omega range covers it.

        For every reflection of every grain, consider each frame number from
        ``|frame_start|`` to ``|frame_end|``, inclusive, stepping by
        ``param['omega_sign']``. If that number is in
        ``param['frame_numbers']``, the matching frame gets ``nrefl``
        incremented and ``[grainno, reflno]`` appended to its ``refs``.
        """
        step = self.param['omega_sign']
        # Build the frame-number -> frame-index map once. This replaces two
        # linear list.index() searches per (reflection, frame) pair.
        # setdefault keeps the first occurrence, matching list.index().
        frame_index = {}
        for idx, frameno in enumerate(self.param['frame_numbers']):
            frame_index.setdefault(frameno, idx)

        for grainno, grain in enumerate(self.grain):
            refs = grain.refs
            for reflno in range(refs.shape[0]):
                first = int(abs(refs[reflno, A_id['frame_start']]))
                last = int(abs(refs[reflno, A_id['frame_end']]))
                for frameno in range(first, last + step, step):
                    idx = frame_index.get(frameno)
                    if idx is None:
                        continue
                    frame = self.frames[idx]
                    frame.nrefl += 1
                    frame.refs.append([grainno, reflno])

    def run(self):
        """Harvest shoeboxes from all frames and write them to ``.hst`` files.

        Expects :meth:`frameinfo` to have been called, so that each frame
        lists the reflections it contributes to. One file per grain is
        opened in ``param['direc']``. All opened files are closed even if
        harvesting raises.
        """
        setno = 0
        hstout = []
        try:
            self._open_hst_files(hstout, setno)
            self._harvest(hstout)
        finally:
            for fobj in hstout:
                fobj.close()

    def _open_hst_files(self, hstout, setno):
        """Open one ``.hst`` file per grain and write its two header lines.

        Each opened file is appended to ``hstout`` as soon as it exists, so
        the caller can close it even if a later open fails. The list ends
        up indexed by grain number.
        """
        param_out = ['sbox_omega', 'sbox_y', 'sbox_z']
        # Header line 1: parameter names. Header line 2: their integer values.
        header_names = '# ' + ''.join('%s ' % p for p in param_out) + '\n'
        header_values = ''.join('%i ' % self.param[p] for p in param_out) + '\n'
        for grainno in range(self.ngrains):
            filename = '%s/%s_gr%0.4d_set%0.4d.hst' \
                % (self.param['direc'], self.param['stem'],
                   self.grain[grainno].id, setno)
            fobj = open(filename, 'w')
            hstout.append(fobj)
            fobj.write(header_names)
            fobj.write(header_values)

    def _harvest(self, hstout):
        """Cut shoeboxes out of every frame and write each completed one."""
        sbox_omega = self.param['sbox_omega']
        # Half-widths of the shoebox. Floor division keeps these integral
        # (and therefore valid slice bounds) on Python 3 as well.
        sbox_yh = int((self.param['sbox_y'] - 1) // 2)
        sbox_zh = int((self.param['sbox_z'] - 1) // 2)

        # Temporary shoebox container: one list of omega layers per reflection.
        tmppeaks = []
        for i in range(self.ngrains):
            tmppeaks.append(variables.grain_cont())
            for j in range(len(self.grain[i].refs)):
                tmppeaks[i].generic_add()
                tmppeaks[i].generic[j].peak = \
                    [n.array([], dtype=n.float32)] * sbox_omega

        fmt = None
        for i, frame in enumerate(self.frames):
            logging.info('Harvesting frame: %s', frame.name)
            img = self._correct_image(openimage.openimage(frame.name).data)
            # Flip image to match standard orientation.
            img = self.detector_flips(img)
            # The output number format is fixed by the first corrected image.
            if fmt is None:
                fmt = self._output_format(img.dtype)

            frameno = self.param['frame_numbers'][i]
            for ref in frame.refs:
                grainno = int(ref[0])
                reflno = int(ref[1])
                refl = self.grain[grainno].refs[reflno]

                wlayer = self._omega_layer(refl, frameno)

                # Area of interest centred on the rounded detector position.
                yc = int(round(refl[A_id['detyd']]))
                zc = int(round(refl[A_id['detzd']]))
                container = tmppeaks[grainno].generic[reflno]
                container.peak[wlayer] = img[yc - sbox_yh:yc + sbox_yh + 1,
                                             zc - sbox_zh:zc + sbox_zh + 1]

                # The shoebox is complete once its last layer is filled or
                # the reflection's last frame has been reached.
                if (wlayer == sbox_omega - 1 or
                        frameno == abs(refl[A_id['frame_end']])):
                    self.writehst(hstout[grainno], container.peak,
                                  grainno, reflno, format=fmt)
                    # Drop the written shoebox to free its memory.
                    container.peak = []

        logging.info("Finished harvesting reflection intensities")

    def _correct_image(self, img):
        """Apply the dark, flood and dark-offset corrections to a raw image.

        Each correction is applied only if its parameter is not None.
        Unsigned images are promoted to int32 before the dark image is
        subtracted, so pixels below the dark level go negative instead of
        wrapping around. The dark image is subtracted for every image dtype.
        """
        if self.param['dark'] is not None:
            if img.dtype in self.uint_types:
                img = img.astype(n.int32)
            img = img - self.dark
        if self.param['flood'] is not None:
            img = img.astype(n.float32)
            img = img / self.flood
        if self.param['darkoffset'] is not None:
            img = img + self.param['darkoffset']
        return img

    def _output_format(self, dtype):
        """Return the printf-style format used to write pixels of ``dtype``.

        Dtypes listed in ``self.numtype`` use their mapped format. Any other
        integer dtype falls back to ``'%i'``, and any other dtype to
        ``'%f'``, instead of raising ``KeyError``.
        """
        try:
            return self.numtype[dtype]
        except KeyError:
            return '%i' if dtype.kind in 'iu' else '%f'

    def _omega_layer(self, refl, frameno):
        """Return the shoebox omega-layer index of frame ``frameno`` for ``refl``.

        If the reflection's ``border`` value is negative, layers are counted
        back from ``frame_end``, whose frame maps to layer
        ``sbox_omega - 1``. Otherwise they are counted forward from
        ``|frame_start|``. Both counts are scaled by ``param['omega_sign']``.
        """
        sign = self.param['omega_sign']
        if refl[A_id['border']] < 0:
            return int(self.param['sbox_omega'] - 1 - sign *
                       (refl[A_id['frame_end']] - frameno))
        return sign * int(frameno - abs(refl[A_id['frame_start']]))

    def detector_flips(self, img):
        """Reorient ``img`` according to the detector orientation matrix.

        The matrix is ``[[o11, o12], [o21, o22]]``, read from ``self.param``.
        If pretransposed to get img -> img(dety, detz), then:

            [[ -1,  0],[  0,  1]]  => flipud
            [[  1,  0],[  0, -1]]  => fliplr
            [[  0,  1],[  1,  0]]  => transpose
            [[  0, -1],[ -1,  0]]  => transpose fliplr flipud
            [[  0, -1],[  1,  0]]  => flipud transpose
            [[  0,  1],[ -1,  0]]  => fliplr transpose

        Raises ``ValueError`` if the matrix is not a signed permutation
        matrix.
        """
        if abs(self.param['o11']) == 1:
            if (abs(self.param['o22']) != 1) or \
                    (self.param['o12'] != 0) or \
                    (self.param['o21'] != 0):
                raise ValueError('detector orientation makes no sense 1')
            # Transpose so that img[i, j] is the standard img[dety, detz].
            img = n.transpose(img)
            if self.param['o11'] == -1:
                img = n.fliplr(img)
            if self.param['o22'] == -1:
                img = n.flipud(img)
            return img
        if abs(self.param['o12']) == 1:
            if abs(self.param['o21']) != 1 or \
                    (self.param['o11'] != 0) or \
                    (self.param['o22'] != 0):
                raise ValueError('detector orientation makes no sense 2')
            # No transpose needed: the raw image is already transposed.
            if self.param['o12'] == -1:
                img = n.flipud(img)
            if self.param['o21'] == -1:
                img = n.fliplr(img)
            return img
        raise ValueError('detector orientation makes no sense 3')

    def writehst(self, file = None, peak = None, grainno = None, reflno = None, format= '%i'):
        """Write one reflection's shoebox to an open ``.hst`` file.

        The output has three parts:

        - a ``REFL_ID = <id>`` line;
        - a ``SPOT_ID = <id>`` line;
        - one line holding all ``sbox_omega * sbox_y * sbox_z`` values,
          separated by spaces.

        The value order is omega layer slowest, then y, then z fastest.
        Missing (zero-size) layers in ``peak`` are replaced in place by
        zero arrays of shape ``(sbox_y, sbox_z)``. Each value is written
        with the printf-style ``format``.

        A line could presumably be read back with something like
        ``n.fromstring(line, sep=' ').reshape(sbox_omega, sbox_y, sbox_z)``.
        That is untested here; verify against the reader.
        """
        file.write('REFL_ID = %i\n' % self.grain[grainno].refs[reflno, A_id['ref_id']])
        file.write('SPOT_ID = %i\n' % self.grain[grainno].refs[reflno, A_id['spot_id']])
        for i in range(self.param['sbox_omega']):
            # size (not len) also catches shapes like (k, 0), which hold no
            # pixels.
            if n.size(peak[i]) == 0:
                peak[i] = n.zeros((self.param['sbox_y'], self.param['sbox_z']))
            peak[i].tofile(file, sep=' ', format=format)
            file.write(' ')
        file.write('\n')

