#!/usr/bin/env python

#
# Checking input  
#

from string import split
import sys, os 
from fabio import deconstruct_filename, jump_filename
from xfab import sg
import variables,file_io

import numpy as n
import logging

logging.basicConfig(level=logging.DEBUG,format='%(levelname)s %(message)s')

class parse_input:
    """Read, validate and post-process a keyword/value input file.

    Typical use is ``read()`` -> ``check()`` -> ``initialize()``.  All parsed
    and derived values are stored in the ``entries`` dictionary.

    Attributes set by the constructor:
        filename          -- path of the input file handed to ``read()``
        entries           -- dict of parsed / derived parameters
        grain_list_orient -- grain ids that have a U or UBI matrix
        grain_list_pos    -- grain ids that have a position
        needed_items      -- mandatory keys mapped to a help message
        optional_items    -- optional keys mapped to their default value
    """

    def __init__(self,input_file = None):
        """Store the input file name and set up defaults; no file I/O here."""
        self.filename = input_file
        self.entries = {}
        self.grainno = 0
        self.grain_list_orient = []
        self.grain_list_pos = []
        self.entries['omega_range'] = []
        self.entries['imagefile'] = []
        self.entries['frame_range'] = []
        self.no_pos = 0 # keeping track of no of grain position
        # Experimental setup.
        # omega_range and unit_cell are deliberately not listed as mandatory.
        self.needed_items = {
                    'wavelength' : 'Missing input: wavelength [wavelength in angstrom]',
                    'distance'   : 'Missing input: distance [sample-detector distance in mm)]',
                    'dety_center': 'Missing input: dety_center [beamcenter, y in pixel coordinatees]',
                    'detz_center': 'Missing input: detz_center [beamcenter, z in pixel coordinatees]',
                    'y_size'     : 'Missing input: y_size [Pixel size y in mm]',
                    'z_size'     : 'Missing input: z_size [Pixel size z in mm]',
                    'dety_size'  : 'Missing input: dety_size [detector y size in pixels]',
                    'detz_size'  : 'Missing input: detz_size [detector z size in pixels]',
                    'omega_step'      : 'Missing input: omega_step [Omega step size in degrees]',
                    'direc'      : 'Missing input: direc [directory to save output]',
                    'imagefile' :  'Missing input: imagefile [the first image file to be processed]'
                    }
        # sgno is deliberately not given a default here; initialize() falls
        # back to space group 1 when neither sgno nor sgname is present.
        self.optional_items = {
            'tilt_x'     : 0,
            'tilt_y'     : 0,
            'tilt_z'     : 0,
            'wedge'      : 0.,
            'chi'        : 0.,
            'beampol_factor' : 1,
            'beampol_angle' : 0.0,
            'spatial' : None,
            'flood' : None,
            'dark' : None,
            'darkoffset' : None,
            'grainsfile' : None,
            'overlap_test' : True,
            'overlap_specs' : [-1,0,0]
            }

    def read(self):
        """Parse ``self.filename`` into ``self.entries``.

        Each non-empty line is ``key value [value ...]``.  Everything after a
        ``#`` or ``%`` is treated as a comment.  A single value is evaluated as
        is; several values (or any number of values for ``imagefile``) are
        evaluated as one list.

        Raises:
            IOError     -- the file cannot be opened.
            SyntaxError -- a line has a key without value, or its value
                           cannot be evaluated.
        """
        try:
            with open(self.filename, 'r') as handle:
                self.input = handle.readlines()
        except (IOError, OSError, TypeError):
            raise IOError('\n\n No file named %s\n' % self.filename)

        for raw_line in self.input:
            # Drop comments (full-line or trailing) before tokenising.
            content = raw_line.split('#')[0].split('%')[0]
            tokens = content.split()
            if not tokens:
                continue

            key, values = tokens[0], tokens[1:]
            syntax_msg = 'Line starting with: %s is not following the syntax' % key

            if len(values) > 1 or key == 'imagefile':
                expression = '[' + ''.join(tok + ',' for tok in values) + ']'
            elif values:
                expression = values[0]
            else:
                # A bare key used to escape as an IndexError.
                raise SyntaxError(syntax_msg)

            try:
                # SECURITY NOTE: eval() executes arbitrary expressions found in
                # the input file - only read files from trusted sources.
                self.entries[key] = eval(expression)
            except Exception:
                raise SyntaxError(syntax_msg)

    def check(self):
        """Validate ``self.entries`` and fill in defaults.

        Sets ``self.missing`` to True (and logs a warning per item) when a
        mandatory item is absent, converts U/UBI grain matrices to 3x3 numpy
        arrays, applies ``optional_items`` defaults, rebuilds the grain id
        lists, reads the grains file if one is given and finally stores
        ``grain_list`` and ``no_grains`` in ``self.entries``.

        Raises:
            AssertionError -- inconsistent or malformed input.
        """
        self.missing = False

        for item in self.needed_items:
            value = self.entries.get(item)
            # 'imagefile' is pre-seeded with [] by __init__, so an empty list
            # has to count as "not given" as well.
            is_empty_list = isinstance(value, list) and len(value) == 0
            if item not in self.entries or is_empty_list:
                logging.warning("%s missing from your input file", item)
                self.missing = True

        assert len(self.entries['omega_range']) == \
            len(self.entries['imagefile'])*2,\
            'number of omega ranges does not match the number' + \
            ' of imagefile given'

        if  'sgname' in self.entries and \
                'sgno' in self.entries and \
                'sysconditions' not in self.entries:
            assert sg.sg(sgname=self.entries['sgname']).name == \
                sg.sg(sgno=self.entries['sgno']).name,\
                'both space group name (sgname) and number (sgno) '+\
                'are input and these do not correspond to the same space group'
        if 'sysconditions' in self.entries and \
                ('sgno' in self.entries or  'sgname' in self.entries):
            logging.warning("Both space group (sgno or sgname) is given " + \
                                "and so it systematic absences (sysconditions)")
            logging.warning("The systematic absent conditions " + \
                                "given by sysconditions will be used")

        # Only values are replaced below, so iterating the dict is safe.
        for key in self.entries:
            val = self.entries[key]
            if val is None or key == 'imagefile':
                continue
            if 'U_grains' in key or 'UBI_grains' in key:
                # Accept 9 flat values, a nested 3x3 list or a 3x3 array.
                matrix = n.array(val)
                assert matrix.size == 9, \
                    'Wrong number of arguments for %s' %key
                self.entries[key] = matrix.reshape(3, 3)
            if 'pos_grains' in key:
                assert hasattr(val, '__len__') and len(val) == 3, \
                    'Wrong number of arguments for %s' %key

        # Fill in default values for item in optional_items not specified by user
        for item in self.optional_items:
            if item not in self.entries:
                self.entries[item] = self.optional_items[item]

        # Collect grain ids from the U/UBI and pos keys.  The lists are rebuilt
        # from scratch so that a second call to check() does not duplicate ids.
        orient_ids = []
        pos_ids = []
        for item in self.entries:
            if '_grains_' not in item:
                continue
            if 'U' in item:  # matches both U_grains_* and UBI_grains_*
                orient_ids.append(int(item.split('_grains_')[1]))
            elif 'pos' in item:
                pos_ids.append(int(item.split('_grains_')[1]))
        orient_ids.sort()
        pos_ids.sort()
        self.grain_list_orient = orient_ids
        self.grain_list_pos = pos_ids

        assert len(self.grain_list_orient) != 0 or \
            self.entries['grainsfile'] is not None, 'No grain orientations given'

        if len(self.grain_list_pos) != 0:
            assert self.grain_list_orient == self.grain_list_pos,\
            'Specified grain numbers for U_grains (UBI_grains) and pos_grains disagree'

        if self.entries['grainsfile'] is not None:
            self.read_graininfo()

        self.entries['grain_list'] = self.grain_list_orient
        self.entries['no_grains'] = len(self.entries['grain_list'])

    def read_graininfo(self):
        """Load grain orientations / positions from ``entries['grainsfile']``.

        For a ``.gff`` file the grain ids replace ``grain_list_orient`` and the
        UBI (preferred) or U matrices, and positions when an ``x`` column is
        present, are stored in ``self.entries``.

        Raises:
            IOError -- the file extension is neither ``.gff`` nor ``.log``.
        """
        extension = os.path.splitext(self.entries['grainsfile'])[-1]

        if extension == '.gff':
            from ImageD11 import columnfile
            gff_file = columnfile.columnfile(self.entries['grainsfile'])
            # Builtin int is what the removed n.int alias pointed to.
            self.grain_list_orient = gff_file.grain_id.astype(int)

            # UBI columns take precedence over U columns.
            for prefix in ('UBI', 'U'):
                if prefix + '11' not in gff_file.titles:
                    continue
                for i in range(gff_file.nrows):
                    self.entries['%s_grains_%i' %(prefix, gff_file.grain_id[i])] = \
                        n.array([
                            [getattr(gff_file, '%s%i%i' %(prefix, row, col))[i]
                             for col in (1, 2, 3)]
                            for row in (1, 2, 3)
                            ])
                break

            if 'x' in gff_file.titles:
                self.grain_list_pos = []
                for i in range(gff_file.nrows):
                    self.entries['pos_grains_%i' %gff_file.grain_id[i]] = \
                        n.array([gff_file.x[i],gff_file.y[i],gff_file.z[i]])
                    self.grain_list_pos.append(gff_file.grain_id[i])

        elif extension == '.log':
            # NOTE(review): the parsed log is not used for anything here, so
            # .log support looks unfinished - confirm against file_io.readlog.
            file_io.readlog(self.entries['grainsfile'])

        else:
            raise IOError('Unknown file type - knwon types are .gff and .log')

    def initialize(self):
        """Derive run-time parameters from the checked input.

        Sets systematic-absence conditions, default grain positions, the
        omega-range array, overlap specifications, per-frame bookkeeping
        (``self.frameinfo`` plus ``frame_numbers``, ``frame_omega``,
        ``frame_range``, ``no_images``), creates the output directory and
        fills in ``theta_min`` / ``theta_max`` when they are not given.

        ``check()`` must have been called first.  ``omega_sign`` and, when
        ``overlap_specs`` is left at its default, ``sbox_y`` / ``sbox_z`` /
        ``sbox_omega`` have no default in this class and must be present in
        ``self.entries``, otherwise a KeyError is raised.
        """
        # Setting conditions for systematic absent reflections
        if 'sgno' in self.entries:
            spacegroup = sg.sg(sgno=self.entries['sgno'])
        elif 'sgname' in self.entries:
            spacegroup = sg.sg(sgname=self.entries['sgname'])
        else:
            spacegroup = sg.sg(sgno=1)
        self.entries['crystal_system'] = spacegroup.crystal_system
        if 'sysconditions' not in self.entries:
            self.entries['sysconditions'] = spacegroup.syscond

        # If grain position not given set positions to (0,0,0)
        if len(self.grain_list_pos) == 0:
            for grain_id in self.entries['grain_list']:
                self.entries['pos_grains_%s' %(grain_id)] = n.zeros((3))

        # Init image related parameters: one (start, end) row per omega range.
        # reshape(-1, 2) avoids the float produced by len()/2 under true
        # division.
        self.entries['omega_range'] = \
            n.array(self.entries['omega_range']).reshape(-1, 2)
        self.entries['n_omega_ranges'] = len(self.entries['omega_range'])

        # Init overlap specifications
        if self.entries['overlap_test']:
            if self.entries['overlap_specs'][0] < 0:
                self.entries['overlap_specs'] = [self.entries['sbox_y'],
                                                 self.entries['sbox_z'],
                                                 self.entries['sbox_omega']]

        self.entries['frame_numbers'] = [] # initilize frame number list

        if 'start_frame' not in self.entries:
            self.entries['start_frame'] = []
            use_fileinfo = True
        else:
            use_fileinfo = False
            # A single start_frame value is parsed as a scalar; wrap it so it
            # can be indexed per omega range below.
            if n.ndim(self.entries['start_frame']) == 0:
                self.entries['start_frame'] = [self.entries['start_frame']]

        for image in self.entries['imagefile']:
            fileinfo = deconstruct_filename(image)
            self.entries['filetype'] = fileinfo.format
            self.entries['stem']= fileinfo.stem
            if use_fileinfo:
                self.entries['start_frame'].append(fileinfo.num)

        # Generate FILENAME of frames
        omega_step = self.entries['omega_step']
        omega_sign = self.entries['omega_sign']

        #Initialize frameinfo container
        self.frameinfo = []
        # Created once, before the loop, so that frames of every omega range
        # are kept (it used to be reset for each range).
        self.entries['frame_omega'] = {}
        logging.info("Generating frame data...")
        for omrange in range(self.entries['n_omega_ranges']):
            omega_start, omega_end = self.entries['omega_range'][omrange]
            start_frame = self.entries['start_frame'][omrange]
            imagefile = self.entries['imagefile'][omrange]

            # The 1e-9 is a hack in order to make sure the omega_end is
            # also in the list. arange sometimes does it sometimes not :-)
            omegalist = omega_sign*n.arange(omega_start,omega_end+1e-9,omega_step)
            omegalist.sort()
            nframes = len(omegalist)-1

            if omega_sign > 0:
                filerange = n.arange(start_frame,start_frame+nframes)
            else:
                filerange = n.arange((start_frame-1)+nframes,
                                     (start_frame-1),
                                     omega_sign)
                # reverse omega_start/omega_end
                self.entries['omega_range'][omrange] = [omega_end*omega_sign,
                                                        omega_start*omega_sign]
            self.entries['frame_range'].append([filerange[0],filerange[-1]])

            # Frame k of this range is paired with the k-th sorted omega value.
            for no, omega in zip(filerange, omegalist):
                frame = variables.frameinfo_cont(no)
                frame.name = jump_filename(imagefile,no)
                frame.omega = omega
                frame.nrefl = 0 # Initialize number of reflections on frame
                frame.refs = [] # Initialize list of reflections on frame
                self.frameinfo.append(frame)
                self.entries['frame_numbers'].append(no)
                self.entries['frame_omega'][no] = omega

        self.entries['no_images'] = len(self.frameinfo)
        logging.debug("Printing frameinfo...")

        # Does output directory exist?  makedirs also creates missing parents.
        if not os.path.exists(self.entries['direc']):
            os.makedirs(self.entries['direc'])

        if 'theta_min' not in self.entries:
            # No lower limit given: generate reflections from theta = 0
            logging.warning("No theta_min given - sets theta_min to 0.0")
            self.entries['theta_min'] = 0.0

        if 'theta_max' not in self.entries:
            # Find maximum theta for generation of all possible reflections on
            # the detector from the detector specs: the largest distance from
            # the beam centre to any of the four detector corners.
            dety_center_mm = self.entries['dety_center'] * self.entries['y_size']
            detz_center_mm = self.entries['detz_center'] * self.entries['z_size']
            dety_size_mm = self.entries['dety_size'] * self.entries['y_size']
            detz_size_mm = self.entries['detz_size'] * self.entries['z_size']
            c2c_max = max(
                n.sqrt((dety_center_mm-corner_y)**2 + (detz_center_mm-corner_z)**2)
                for corner_y in (dety_size_mm, 0)
                for corner_z in (detz_size_mm, 0))
            theta_max = n.arctan(c2c_max/self.entries['distance'])/2. * 180/n.pi
            logging.warning("No theta_max given - " + \
                                "sets theta_max to %f for full detector coverage",
                            theta_max)
            self.entries['theta_max'] = theta_max



if __name__=='__main__':

    # Command line self-test: parse, validate and initialise one input file.
    if len(sys.argv) < 2:
        print('Usage: check_input.py  <input.inp>')
        sys.exit(1)
    filename = sys.argv[1]

    myinput = parse_input(input_file = filename)
    myinput.read()
    print(myinput.entries)
    myinput.check()
    if myinput.missing:
        # initialize() needs every mandatory item, so stop here.
        print('MISSING ITEMS')
        sys.exit(1)
    # parse_input has no evaluate() method; initialize() is the final stage.
    myinput.initialize()
