#!/usr/bin/env python
# Scala.py
# Maintained by G.Winter
# 5th April 2004
# A part of the second generation scheduler.
# 
# This is a wrapper for the CCP4 program "Scala", which will scale and merge
# integrated intensities from Mosflm, which have been sorted by sortmtz.
# The output from this program is a set of scaled intensities, and some useful
# statistics.
# 
# $Id: Scala.py,v 1.11 2005/11/22 13:38:15 svensson Exp $

import os

# Both the CCP4 suite and the DNA installation must be configured in the
# environment; fail at import time rather than part-way through a job.
if 'CCP4' not in os.environ:
    raise RuntimeError('CCP4 not found')
if 'DNAHOME' not in os.environ:
    raise RuntimeError('DNAHOME not defined')

import sys

dna = os.environ['DNAHOME']

# Put the DNA source directories on the import path (in this order) so
# that the project modules imported below can be located.
sys.path.extend([dna + '/xsd/python',
                 dna + '/scheduler/Scheduler/Driver',
                 dna + '/scheduler/Scheduler/Mosflm',
                 dna + '/expertise/python/graeme/Expert'])

import Driver
import Output
import Exist
import Expert
import CCP4
import CCP4Translation
import XSD
import Anomalous

# singletons for messaging and storage

from Messenger import Messenger
from Somewhere import Somewhere

class Scala(Driver.Driver):
    '''A class to execute scala'''

    def __init__(self):
        Driver.Driver.__init__(self, 1, 'scala')
        self.setExecutable('scala')

        self.hklin = ''
        self.hklout = ''
        self.polish = None

        # input parameters - assume that 'None' corresponds
        # to not set
        
        self.cycle_limit = None
        self.anomalous_scattering = None
        self.bfactor_refinement = None
        self.spacing = None
        self.secondary = None
        self.resolution_upper = None
        self.resolution_lower = None
        self.tails = True

        self.start_batch = None
        self.end_batch = None

        self.sdadd_part = None
        self.sdfac_part = None
        self.sdadd_full = None
        self.sdfac_full = None

        # result data
        self.scaling_converged = True
        
    def setSD_full(self, sdfac_full, sdadd_full):
        self.sdfac_full = sdfac_full
        self.sdadd_full = sdadd_full

    def setSD_partial(self, sdfac_part, sdadd_part):
        self.sdfac_part = sdfac_part
        self.sdadd_part = sdadd_part

    def setHklout(self, hklout):
        self.hklout = hklout

    def setHklin(self, hklin):
        self.hklin = hklin

    def setPolish(self, polish):
        self.polish = polish

    def setStart_batch(self, start_batch):
        self.start_batch = start_batch

    def setEnd_batch(self, end_batch):
        self.end_batch = end_batch

    def setResolution(self, upper, lower = None):
        self.resolution_upper = upper
        if lower:
            self.resolution_lower = lower

    def setCycle_limit(self, cycle_limit):
        self.cycle_limit = cycle_limit

    def setAnomalous_scattering(self, anomalous_scattering):
        if anomalous_scattering == 'true' or \
           anomalous_scattering == 'on':
            self.anomalous_scattering = 'on'
        else:
            self.anomalous_scattering = 'off'

    def setBfactor_refinement(self, bfactor_refinement):
        if bfactor_refinement == 'true' or \
           bfactor_refinement == 'on':
            self.bfactor_refinement = 'on'
        else:
            self.bfactor_refinement = 'off'

    def setSpacing(self, spacing):
        self.spacing = spacing

    def setSecondary(self, secondary):
        self.secondary = secondary

    def scale(self):
        ''' actually start the scaling task'''
        if not self.hklout:
            raise CCP4.CCP4Exception, 'no HKLOUT specified'

        if self.hklin == '':
            raise CCP4.CCP4Exception, 'no HKLIN specified'

        if self.getWorkingDirectory() == '':
            here = Driver.getcwd()
        else:
            here = self.getWorkingDirectory()

        workingDirectory = here + '/scale'
        self.setWorkingDirectory(workingDirectory)

        try:
            os.mkdir(workingDirectory)
        except:
            pass

        if self.polish:
            self.hklout = '%s-broken' % self.hklout

        command = 'DNA scala.stf HKLIN ' + self.hklin + \
                  ' HKLOUT ' + self.hklout

        if self.polish:
            command += ' SCALEPACK %s' % self.polish

        self.start(command)

        # see if any batches need excluding
        duff_batches = Somewhere.get('duff_batches')
        if duff_batches:
            duff_command = ''
            for d in duff_batches:
                duff_command += '%d ' % d
            Messenger.log_write('Excluding batches: %s' % duff_command)

            start = self.start_batch
            end = self.end_batch

            batch_range = Somewhere.get('batch_range')
            if batch_range:
                if not start:
                    start = batch_range[0]
                if not end:
                    end = batch_range[1]

            batch_commands = Expert.BatchesForScaling(duff_batches,
                                                      start, end)

            for b in batch_commands:
                self.input(b)
            
        else:
            if self.start_batch and self.end_batch:
                self.input('run 1 batch %d to %d' %\
                           (self.start_batch, self.end_batch))
            else:
                self.input('run 1 all')
    
        if self.polish:
            Messenger.log_write('Outputting unmerged polish - ignore %s' \
                                % self.hklout)
            self.input('output unmerged polish')
            

        if self.sdfac_full and \
               self.sdadd_full and \
               self.sdfac_part and \
               self.sdadd_part:
            self.input('sdcorrection full %f %f partial %f %f' % \
                       (self.sdfac_full, self.sdadd_full, \
                        self.sdfac_part, self.sdadd_part))
            
        if self.resolution_upper:
            resolution_command = 'resolution high ' + \
                                 str(self.resolution_upper)
            if self.resolution_lower:
                resolution_command += ' lower ' + \
                                      str(resolution_lower)
            self.input(resolution_command)

        # now input some of the 'optional' keywords
        if self.cycle_limit != None:
            self.input('cycles ' + str(self.cycle_limit))

        if self.anomalous_scattering != None:
            self.input('anomalous ' + str(self.anomalous_scattering))

        scaling_command = 'scales rotation '

        if self.spacing != None:
            scaling_command += 'spacing ' + str(self.spacing) + ' '

        if self.secondary != None:
            scaling_command += 'secondary ' + str(self.secondary) + ' '

        if self.bfactor_refinement != None:
            scaling_command += 'bfactor ' + str(self.bfactor_refinement) + ' '

        if self.tails:
            scaling_command += 'tails'

        self.input(scaling_command)

        self.close()

        while 1:
            try:
                line = self.output()

                if line[:29] == ' Scaling has not converged!!':
                    self.scaling_converged = False

            except DriverException, e:
                raise CCP4.CCP4Exception, e

            if not line:
                break

        self.kill()

        # gather up the output

        if not self.scaling_converged:
            Messenger.log_write('Warning - the scaling has not converged!')

        if self.getWorkingDirectory():
            self.result = Output.Output(self.getWorkingDirectory() + \
                                        '/scala.stf')
        else:
            self.result = Output.Output('scala.stf')

        Somewhere.store('scaling_results', self.result)
        if self.hklout[0] == '/':
            Somewhere.store('scaling_mtz', self.hklout)
        else:
            Somewhere.store('scaling_mtz', '%s/%s' %
                            (self.getWorkingDirectory(), self.hklout))

        self.dnaResult = \
                       CCP4Translation.Scale_reflections_response(self.result)

        if not self.scaling_converged:
            status = self.dnaResult.getStatus()
            status.setCode('warning')
            status.setMessage('The scaling has not converged in %d cycles' % \
                              self.cycle_limit)
            self.dnaResult.setStatus(status)

        # analyse the NORMPLOT and ANOMPLOT if we're doing anomalous,
        # otherwise just have a look at the normplot

        if self.anomalous_scattering == 'on':
            normplot = self.getWorkingDirectory() + '/NORMPLOT'
            anomplot = self.getWorkingDirectory() + '/ANOMPLOT'

            chi_sq = Anomalous.AnalyseNPPAnom(normplot, anomplot)

            Messenger.log_write('Analysis of Normal Probability Statistics')
            Messenger.log_write('Over sigma range [%f, %f]' % \
                                (chi_sq[4], chi_sq[5]))
            Messenger.log_write('Mean      = %f' % chi_sq[0])
            Messenger.log_write('Full      = %f' % chi_sq[1])
            Messenger.log_write('Partial   = %f' % chi_sq[2])
            Messenger.log_write('Anomalous = %f' % chi_sq[3])

        self.setWorkingDirectory(here)
        return self.dnaResult

if __name__ == '__main__':
    s = Scala()

    s.setHklin('graeme_sorted.mtz')
    s.setHklout('graeme_scaled.mtz')

    s.setAnomalous_scattering('off')
    s.setBfactor_refinement('off')
    s.setCycle_limit(20)
    s.setSpacing(5)
    s.setSecondary(6)

    result = s.scale()

    print result.marshal()
