#!/usr/bin/env python
# QuickScaler.py
# Maintained by G.Winter
# 21st June 2005
# 
# A quick scaler - this is built on top of a scala wrapper to get 
# some quick results out. Note well that the scala wrapper is included,
# so this should be self contained.
# 

import os, time

__doc__ = \
        '''A module to provide quick scaling analysis, to ensure that
        data are being collected as expected during an X-Ray diffraction
        experiment. This is designed to work as a part of DNA and provide
        very rapid on-line feedback. Note that this also includes the
        sorting of reflection files.'''

__fixme__ = \
          '''The following features need to be fixed or incorporated:

          - inclusion of truncate to test for twinning
          - perhaps a patterson check
          - space group determination (optional)'''

__fixed__ = \
          '''The following features have been added from the fixme list:
          - extension to include more than one input reflection file'''
          

# Refuse to load unless both the CCP4 and DNAHOME environment variables
# are defined - everything below depends on them.
if 'CCP4' not in os.environ:
    raise RuntimeError('CCP4 not found')
if 'DNAHOME' not in os.environ:
    raise RuntimeError('DNAHOME not defined')

import sys, string

dna = os.environ['DNAHOME']

# Make <DNAHOME>/scheduler/Scheduler/Driver importable; presumably this is
# where the Driver and Output modules imported below live.
sys.path.append(os.path.join(dna, 'scheduler', 'Scheduler', 'Driver'))

import Driver
import Output

def RemovePrefix(list, prefix):
    '''Turn a list of prefixed names into a sorted list of integers.

    Each entry of list is converted with str() and is assumed to be the
    equivalent of "%s%d" % (prefix, number). Every occurrence of prefix is
    removed from the name and the remainder is converted with int(). The
    special name "__count__" is skipped.

    Returns a new list of the numbers in ascending order; int() raises
    ValueError if what remains of a name is not an integer.'''

    # NB the first parameter shadows the builtin "list" - the name is kept
    # because it is part of the public signature.
    names = [str(item) for item in list]

    return sorted([int(name.replace(prefix, ''))
                   for name in names if name != '__count__'])

def _qs_float(output, *keys):
    '''Return the first element of output.get(*keys) converted to float.'''
    return float(output.get(*keys)[0])


def _qs_int(output, *keys):
    '''Return the first element of output.get(*keys) converted to int.'''
    return int(output.get(*keys)[0])


def _qs_summary(output, row, suffix):
    '''Read one row ("overall" or "outer_shell") of the summary_data table
    and return it as a dictionary whose keys all end with suffix.'''

    return {
        'resolution_range' + suffix:
        (_qs_float(output, 'summary_data', row, 'high_resolution_limit'),
         _qs_float(output, 'summary_data', row, 'low_resolution_limit')),
        'r_merge' + suffix:
        _qs_float(output, 'summary_data', row, 'rmerge'),
        'i_over_sig' + suffix:
        _qs_float(output, 'summary_data', row, 'mean_i_over_sigma'),
        'completeness' + suffix:
        _qs_float(output, 'summary_data', row, 'completeness'),
        'multiplicity' + suffix:
        _qs_float(output, 'summary_data', row, 'multiplicity'),
        'n_ref_tot' + suffix:
        _qs_int(output, 'summary_data', row, 'total_observations'),
        'n_ref_unique' + suffix:
        _qs_int(output, 'summary_data', row, 'total_unique_observations'),
        }


def QuickScalerWrapperTranslation(output):
    '''Get some interesting information from the output object and return it
    in a dictionary.

    output must provide get(*keys) and searchforlist(table, prefix), as
    used below.

    If an error message can be read from the output the result is just
    {'status': {'code': 'error', 'messages': <message>}}. Otherwise the
    dictionary contains:

    status                 {'code': 'ok'}
    completeness_by_shell  list of (d_min, percent_possible) tuples
    batches                sorted list of batch numbers
    by_batch               r_merge, scale, n_ref, n_rej and b_factor, each
                           a dictionary keyed by batch number
    resolution_shells      sorted list of shell numbers
    by_shell               r_merge, fract_bias, i_over_sig and resolution,
                           each a dictionary keyed by shell number
    resolution_range, r_merge, i_over_sig, completeness, multiplicity,
    n_ref_tot, n_ref_unique
                           overall statistics
    ..._hr                 the same statistics for the outer shell'''

    # gw 21st June 2005 this should probably be XMLified

    results = { }

    try:
        message = output.get('error', 'error', 'message')[0]
    except (KeyboardInterrupt, SystemExit):
        # never hide a request to stop the program
        raise
    except:
        # deliberately best effort: failing to read an error record, for
        # whatever reason, is taken to mean that there was no error
        pass
    else:
        results['status'] = {'code':'error', 'messages':str(message)}
        return results

    # if we got this far then the status is not "error"

    results['status'] = {'code':'ok'}

    # completeness as a function of resolution - this used to be stored
    # under 'completeness', where it was overwritten by the overall value
    # further down and so never reached the caller

    completeness_by_shell = []

    for shell in output.searchforlist(
        'completeness_multiplicity_vs_resolution', 'shell_'):
        key = 'shell_' + shell
        percent_possible = _qs_float(
            output, 'completeness_multiplicity_vs_resolution', key,
            'percent_possible')
        resolution = _qs_float(
            output, 'completeness_multiplicity_vs_resolution', key, 'd_min')
        completeness_by_shell.append((resolution, percent_possible))

    results['completeness_by_shell'] = completeness_by_shell

    # maybe get the axial reflections out here - then maybe not
    # get the per-batch information out

    batches = RemovePrefix(output.get('scale_factors_by_batch').keys(),
                           'batch_')

    results['batches'] = batches

    n_ref_data = { }
    scale_factor_data = { }
    b_factor_data = { }
    n_rej_data = { }
    batch_r_merge_data = { }

    for b in batches:
        key = 'batch_' + str(b)

        n_ref_data[b] = _qs_int(output, 'scale_factors_by_batch',
                                key, 'n_ref')
        scale_factor_data[b] = _qs_float(output, 'scale_factors_by_batch',
                                         key, 'mean_k')
        b_factor_data[b] = _qs_float(output, 'scale_factors_by_batch',
                                     key, 'b_factor')
        n_rej_data[b] = _qs_int(output, 'scale_factors_by_batch',
                                key, 'n_rejected')

        # the merging R factor lives in a different table
        batch_r_merge_data[b] = _qs_float(output, 'analysis_against_batch',
                                          key, 'rmerge')

    results['by_batch'] = {'r_merge':batch_r_merge_data,
                           'scale':scale_factor_data,
                           'n_ref':n_ref_data,
                           'n_rej':n_rej_data,
                           'b_factor':b_factor_data}

    # information as a function of resolution shell next

    shells = sorted([int(s) for s in output.searchforlist(
        'analysis_against_resolution', 'shell_')])

    results['resolution_shells'] = shells

    resolution_data = { }
    shell_r_merge_data = { }
    i_over_sig_data = { }
    fract_bias_data = { }

    for s in shells:
        key = 'shell_%d' % s

        resolution_data[s] = _qs_float(
            output, 'analysis_against_resolution', key, 'd_min')
        shell_r_merge_data[s] = _qs_float(
            output, 'analysis_against_resolution', key, 'rmerge')
        i_over_sig_data[s] = _qs_float(
            output, 'analysis_against_resolution', key,
            'mean_intensity_sigma')
        fract_bias_data[s] = _qs_float(
            output, 'analysis_against_resolution', key, 'fractional_bias')

    results['by_shell'] = {'r_merge':shell_r_merge_data,
                           'fract_bias':fract_bias_data,
                           'i_over_sig':i_over_sig_data,
                           'resolution':resolution_data}

    # now the overall statistics for the data set

    results.update(_qs_summary(output, 'overall', ''))

    # now the overall statistics for the high resolution shell

    results.update(_qs_summary(output, 'outer_shell', '_hr'))

    # that's probably enough for now

    return results

class QuickScalerWrapper(Driver.Driver):
    '''A (lightweight) class to execute scala - note that this is about
    getting quick diagnostic results out, not scaled data.

    Set the input and output reflection files with setHklin() and
    setHklout(), optionally a resolution with setResolution(), then call
    scale().'''

    def __init__(self, mode = 'gw'):
        '''Create the wrapper. mode selects the number of scaling cycles
        and the scala commands to use and must be "gw" or "hrp"; anything
        else raises RuntimeError.'''

        Driver.Driver.__init__(self, 1, 'scala')
        self.setExecutable('scala')

        self.hklin = ''
        self.hklout = ''

        # 0.0 means "do not send a resolution command to scala"
        self.resolution = 0.0

        if mode == 'gw':
            self.cycle_limit = 3
            self.default_commands = ['run 1 all', 'anomalous on',
                                     'scales rotation spacing 5 ' + \
                                     'secondary 6 ' + \
                                     'bfactor on tails']

        elif mode == 'hrp':
            self.cycle_limit = 6
            self.default_commands = ['run 1 all', 'anomalous on',
                                     'scales rotation spacing 10 ']

        else:
            # without this cycle_limit and default_commands would be left
            # undefined and scale() would fail with an AttributeError
            raise RuntimeError('unknown mode: %s' % str(mode))

    def setHklout(self, hklout):
        '''Set the name of the output reflection file.'''
        self.hklout = hklout

    def setHklin(self, hklin):
        '''Set the name of the input reflection file.'''
        self.hklin = hklin

    def setResolution(self, resolution):
        '''Set the resolution passed to scala; None or 0.0 means that no
        resolution command is sent.'''
        self.resolution = resolution

    def scale(self):
        '''Actually start the scaling task.

        Runs scala in a "quick-scale" subdirectory of the current working
        directory (note that this is appended on every call), then reads
        scala.stf from that directory and returns the dictionary built by
        QuickScalerWrapperTranslation.

        Raises RuntimeError if HKLIN or HKLOUT has not been set, or if
        reading the program output fails.'''

        # CCP4.CCP4Exception was raised here originally, but no CCP4 module
        # is imported by this file so that could only ever give a
        # NameError - use RuntimeError like the rest of the file

        if not self.hklout:
            raise RuntimeError('no HKLOUT specified')

        if not self.hklin:
            raise RuntimeError('no HKLIN specified')

        here = self.getWorkingDirectory()
        if here == '':
            here = Driver.getcwd()

        if self.resolution is None:
            self.resolution = 0.0

        workingDirectory = here + '/quick-scale'
        self.setWorkingDirectory(workingDirectory)

        try:
            os.mkdir(workingDirectory)
        except OSError:
            # best effort - the directory may well exist already
            pass

        command = 'DNA scala.stf HKLIN ' + self.hklin + \
                  ' HKLOUT ' + self.hklout

        self.start(command)

        self.input('cycles %d' % self.cycle_limit)
        if self.resolution > 0.0:
            self.input('resolution %f' % self.resolution)

        for command in self.default_commands:
            self.input(command)

        self.close()

        # read the output until an empty line is returned - the content is
        # not used here, the results are taken from scala.stf below

        while True:
            try:
                line = self.output()

            # NOTE(review): assumes that the Driver module defines
            # DriverException (the bare name used previously was never
            # imported) - confirm
            except Driver.DriverException as e:
                raise RuntimeError(str(e))

            if not line:
                break

        self.kill()

        # gather up the output

        if self.getWorkingDirectory():
            self.result = Output.Output(self.getWorkingDirectory() + \
                                        '/scala.stf')
        else:
            self.result = Output.Output('scala.stf')

        return QuickScalerWrapperTranslation(self.result)

class QuickSorterWrapper(Driver.Driver):
    '''A class to wrap a no-options version of sortmtz.

    The input is either one reflection file, given with setHklin(), or
    several, given by repeated calls to addHklin(). The output file is
    given with setHklout() and the sort is run with sort().'''

    def __init__(self):
        '''Create the wrapper with no input or output files defined.'''

        Driver.Driver.__init__(self, 1, 'sortmtz')
        self.setExecutable('sortmtz')

        # None until set: then a string (setHklin) or a list (addHklin)
        self.hklin = None
        self.hklout = ''

        self.default_commands = ['H K L M/ISYM BATCH']

    def setHklout(self, hklout):
        '''Set the name of the sorted output reflection file.'''
        self.hklout = hklout

    def setHklin(self, hklin):
        '''Set the single input reflection file, replacing anything
        defined so far.'''
        self.hklin = hklin

    def addHklin(self, hklin):
        '''Add one more input reflection file. Raises RuntimeError if a
        single input file has already been given with setHklin().'''

        if self.hklin is None:
            self.hklin = []

        if not isinstance(self.hklin, list):
            raise RuntimeError('cannot add when you have called set')

        self.hklin.append(hklin)

    def sort(self):
        '''Actually sort the reflections - maybe this should inspect the
        reflection file before starting?

        Runs sortmtz in a "quick-sort" subdirectory of the current working
        directory (note that this is appended on every call).

        Raises RuntimeError if HKLIN or HKLOUT has not been set, if hklin
        is neither a string nor a list, or if reading the program output
        fails.'''

        # CCP4.CCP4Exception was raised here originally, but no CCP4 module
        # is imported by this file so that could only ever give a
        # NameError - use RuntimeError like the rest of the file

        if not self.hklout:
            raise RuntimeError('no HKLOUT specified')

        if not self.hklin:
            raise RuntimeError('no HKLIN specified')

        here = self.getWorkingDirectory()
        if here == '':
            here = Driver.getcwd()

        workingDirectory = here + '/quick-sort'
        self.setWorkingDirectory(workingDirectory)

        try:
            os.mkdir(workingDirectory)
        except OSError:
            # best effort - the directory may well exist already
            pass

        # what is happening here?
        # if hklin is one file, add on the command line
        # else add in the commands for the sort and follow with a list
        # of input reflection files (to fit in with the DNA way of
        # processing data)

        several = isinstance(self.hklin, list)

        if isinstance(self.hklin, str):
            self.start('HKLIN %s HKLOUT %s' % (self.hklin, self.hklout))
        elif several:
            self.start('HKLOUT %s' % (self.hklout))
        else:
            raise RuntimeError('unknown type for hklin')

        for c in self.default_commands:
            self.input(c)

        if several:
            for hklin in self.hklin:
                self.input(hklin)

        self.close()

        # read the output until an empty line is returned - the content is
        # not used

        while True:
            try:
                line = self.output()

            # NOTE(review): assumes that the Driver module defines
            # DriverException (the bare name used previously was never
            # imported) - confirm
            except Driver.DriverException as e:
                raise RuntimeError(str(e))

            if not line:
                break

        self.kill()

        return

def _qs_absolute(path):
    '''Return path unchanged if it starts with "/" or "~", otherwise
    return it joined on to the current working directory.'''

    if path[:1] in ('/', '~'):
        return path

    return os.path.join(os.getcwd(), path)


def QuickScale(hklin, hklout, mode = 'hrp', working_dir = None,
               resolution = None):
    '''Do a quick scaling to find out about the data.

    hklin        one reflection file name, or a sequence of them
    hklout       name of the output reflection file
    mode         passed to QuickScalerWrapper
    working_dir  if given, set as the working directory of both the sort
                 and the scaling step
    resolution   if given (and not zero), passed to the scaling step

    The input is first sorted with QuickSorterWrapper into
    "<first input file>-sort.tmp", which is then scaled with
    QuickScalerWrapper. Returns the results dictionary from the scaling
    step; raises RuntimeError if hklin is an empty sequence.'''

    # first check that the HKLIN and HKLOUT have decent file names - the
    # input could be a file or a list of files.
    #
    # This test used to be type(hklin) == type(str), which is never true
    # for a string, so a single file name was split into characters.

    single = isinstance(hklin, str)

    if single:
        hklin = _qs_absolute(hklin)
        first = hklin
    else:
        hklin = [_qs_absolute(h) for h in hklin]
        if not hklin:
            raise RuntimeError('no HKLIN specified')
        first = hklin[0]

    hklout = _qs_absolute(hklout)

    # first sort - the sorted reflections go next to the first input file

    sorted_file = '%s-sort.tmp' % first

    qsrtw = QuickSorterWrapper()
    if working_dir:
        qsrtw.setWorkingDirectory(working_dir)
    if single:
        qsrtw.setHklin(hklin)
    else:
        for h in hklin:
            qsrtw.addHklin(h)
    qsrtw.setHklout(sorted_file)
    qsrtw.sort()

    # then scale the sorted reflections

    qsw = QuickScalerWrapper(mode)
    if working_dir:
        qsw.setWorkingDirectory(working_dir)
    qsw.setHklin(sorted_file)
    qsw.setHklout(hklout)
    if resolution:
        qsw.setResolution(resolution)
    results = qsw.scale()

    # copy the log files somewhere useful

    return results

if __name__ == '__main__':
    # a simple test: scale the file(s) named on the command line and
    # print the overall merging R factor
    #
    # usage: QuickScaler.py hklin [hklin ...] hklout

    if len(sys.argv) < 3:
        raise RuntimeError('%s hklin hklout' % sys.argv[0])

    # every argument but the last is an input file, so more than one file
    # can be scaled together

    hklin = sys.argv[1:-1]
    hklout = sys.argv[-1]

    start = time.time()
    results = QuickScale(hklin, hklout)
    elapsed = time.time() - start

    # on failure the results hold only the status, so report the message
    # rather than fail with a KeyError on 'r_merge'
    if results['status']['code'] != 'ok':
        raise RuntimeError(results['status']['messages'])

    print(results['r_merge'])

    # the timing goes to stderr so that stdout still holds just the R factor
    sys.stderr.write('QuickScale took %.2f s\n' % elapsed)
