#!/usr/bin/env python
# AutoPointgroup.py
# Maintained by G.Winter
# 3rd November 2005
# 
# A jiffy to autoindex (triclinic) a small wedge of input data, integrate 
# the wedges and then feed the results in to pointless.
# 
# Structure:
# TriclinicIntegrater
# PointlessRunner
# DecisionMaker
# StrategyPredicter (Will require input about the currently collecting
#                    data set)


__doc__ = '''The idea is to basically run the scripts:
#!/bin/bash

ipmosflm HKLOUT pointless.mtz << eof
directory /media/data1/graeme/jcsg/mad/1vpj/12287
template 12287_1_E1_###.img
symmetry p1
beam 108.93 105.04
resolution 2.0
autoindex dps image 1
autoindex dps image 90
go
mosaic 0.5
postref fix all
process 1 3
go
process 88 90 
go
eof

pointless hklin pointless.mtz

and use the output...'''

import os, sys, math
import re, string
import exceptions

# Name of the environment variable which holds the root directory of the
# installation the helper modules are loaded from (now DNAHOME).
environ_key = 'DNAHOME'

# fail immediately, with a clear message, if the variable is not set -
# nothing below can work without it
if environ_key not in os.environ:
    raise RuntimeError('%s not defined' % environ_key)

environ = os.environ[environ_key]

# Provide access to the Driver module - this should be all we need for this
# case, since this is designed to be a closed module.
# communications to the rest of the DNA system
sys.path.append(os.path.join(environ, 'scheduler', 'Scheduler',
                             'Mosflm'))
sys.path.append(os.path.join(environ, 'scheduler', 'Scheduler', 'Driver'))
# also need DiffractionImage
sys.path.append(os.path.join(environ, 'scheduler', 'DiffractionImage',
                             'lib'))
sys.path.append(os.path.join(environ, 'xsd', 'python'))

from Driver import Driver, DriverException

from DiffractionImage import DiffractionImage

from Messenger import Messenger

class TriclinicIntegraterWrapper(Driver):
    '''A jiffy class to autoindex and integrate the input wedges in P1
    with ipmosflm and produce a new MTZ file with the results.

    Configure it with the set* methods: the template, directory, list of
    wedge(s) and output file are required, the beam position and the
    resolution limit (default 2.0) are optional. Then call run(), which
    raises a RuntimeError if the input is incomplete or malformed.'''

    # wedges should be like [(1, 3), (91, 93)]

    def __init__(self):
        # generic Driver stuff
        Driver.__init__(self)
        self.setExecutable('ipmosflm')

        # the data handling stuff - all of these must be assigned before
        # run() is called

        self.hklout = None
        self.template = None
        self.directory = None
        self.wedges = None

        # optional parameters

        self.beam = None
        self.resolution = 2.0

    def setHklout(self, hklout):
        '''Set the name of the MTZ file to write.'''
        self.hklout = hklout

    def setTemplate(self, template):
        '''Set the image name template, e.g. 12287_1_E1_###.img.'''
        self.template = template

    def setDirectory(self, directory):
        '''Set the directory which contains the images.'''
        self.directory = directory

    def setWedges(self, wedges):
        '''Set the wedges to integrate, as a list of (start, end) image
        numbers.'''
        self.wedges = wedges

    def setBeam(self, x, y):
        '''Set the beam position passed to the "beam" keyword.'''
        self.beam = (x, y)

    def setResolution(self, resolution):
        '''Set the limit passed to the "resolution" keyword.'''
        self.resolution = resolution

    def run(self):
        '''Actually perform the process of integrating the input
        wedges in a triclinic spacegroup.

        Nothing is returned and the program output is read but not
        analysed. Raises a RuntimeError if a required property has not
        been assigned, if the wedges are not a non-empty list of
        (start, end) pairs, or if the only wedge is too short to
        autoindex from.'''

        # validate input - this should be a standard way of performing
        # this operation.
        required_properties = ['hklout', 'template', 'directory', 'wedges']

        for name in required_properties:
            if not hasattr(self, name):
                raise RuntimeError('%s undefined' % name)

            if getattr(self, name) is None:
                raise RuntimeError('%s not assigned' % name)

        # next check the value of the wedges property - this should be
        # a list of tuples of (start, end)

        if not isinstance(self.wedges, list):
            description = '[(start, end), (start, end)]'
            raise RuntimeError('property wedges should be a list of %s' %
                               description)

        # an empty list would otherwise fail with an IndexError below
        if not self.wedges:
            raise RuntimeError('no wedges to integrate')

        for i, wedge in enumerate(self.wedges):
            if len(wedge) != 2:
                raise RuntimeError('wedge %d should be (start, end)' % i)

        # pick images for autoindex - this should probably be doing some
        # analysis of the image headers but that can wait until later

        if len(self.wedges) == 1:
            # use first and last frame from this wedge. The test here used
            # to compare the length of the (start, end) pair itself with 3,
            # which rejected every single wedge; it is interpreted as
            # requiring the wedge to span at least three images.
            start, end = self.wedges[0]
            if end - start + 1 < 3:
                raise RuntimeError(
                    'Too little information to perform autoindex')
            autoindex = (start, end)
        else:
            # use first frame from the first two wedges
            autoindex = (self.wedges[0][0], self.wedges[1][0])

        # ok, that should be the input validated & set up

        self.start('HKLOUT %s DNAOUT %s' % (self.hklout, 'mosflm.stf'))

        # only write the beam if provided as input

        if isinstance(self.beam, tuple) and len(self.beam) == 2:
            self.input('beam %f %f' % self.beam)

        self.input('resolution %f' % self.resolution)
        self.input('directory %s' % self.directory)
        self.input('template %s' % self.template)
        self.input('symmetry p1')
        for image in autoindex:
            self.input('autoindex dps image %d' % image)
        self.input('go')
        self.input('mosaic 0.5')
        self.input('postref fix all')
        for wedge in self.wedges:
            # tuple() so that wedges given as lists format correctly too
            self.input('process %d %d' % tuple(wedge))
            self.input('go')

        self.close()

        # I should probably be analysing the output to see if anything is
        # going wrong here... for the moment just read it all

        while 1:
            line = self.output()

            if not line:
                break

        self.kill()

        # should probably return something useful here... but for the moment
        # don't bother

def TriclinicIntegrate(template, directory, wedges, hklout,
                       resolution = None, beam = None):
    '''Configure a TriclinicIntegraterWrapper from the arguments, run it
    and return whatever its run() method returns. The resolution and
    beam are only passed on when they are given (and not zero / empty).'''

    integrater = TriclinicIntegraterWrapper()

    # the four required properties
    integrater.setTemplate(template)
    integrater.setDirectory(directory)
    integrater.setWedges(wedges)
    integrater.setHklout(hklout)

    # the optional ones - otherwise the wrapper defaults are kept
    if resolution:
        integrater.setResolution(resolution)

    if beam:
        beam_x = beam[0]
        beam_y = beam[1]
        integrater.setBeam(beam_x, beam_y)

    return integrater.run()

class PointlessRunnerWrapper(Driver):
    '''A class to run pointless on an input MTZ file and then decide
    which is probably the correct space group.'''

    def __init__(self):
        Driver.__init__(self)
        self.setExecutable('pointless')

        # input information

        self.hklin = None

    def setHklin(self, hklin):
        '''Set the name of the MTZ file to give to pointless.'''
        self.hklin = hklin

    def run(self):
        '''Actually run pointless and parse its standard output.

        Returns a tuple (spacegroup, reindex_operator, confidence) where
        confidence is the number of "*" characters counted on the row
        numbered 1 of the Laue group table. Raises a RuntimeError if
        hklin has not been set or if no solution could be found in the
        output.'''

        if not self.hklin:
            raise RuntimeError('hklin not set')

        self.start('hklin %s' % self.hklin)

        # no standard input
        self.close()

        # need to get some useful information from the pointless
        # standard output

        # FIXME - this needs to be looked at - the Z score and '***'
        # rating need to be passed through... (the Z scores are parsed
        # below but not yet returned)

        spacegroup = None
        reindex_operator = None
        found_solution = False
        collect_Z = False
        Z_scores = { }
        conf = { }

        while 1:
            # keep the original line so that the character positions
            # are known
            line_orig = self.output()

            if not line_orig:
                break

            line = line_orig.strip()
            tokens = line.split()

            # now grep the output for the information I want

            # this is all about finding the right point group - while
            # inside solution 1 every line starting with '<' updates the
            # spacegroup, taking the first name on the line

            if found_solution and line.startswith('<'):
                spacegroup = line.split('> <')[0].replace('<', '').replace(
                    '>', '')

            if tokens[:2] == ['>>>>>', '1']:
                # the reindexing operator is the text between [ and ]
                reindex_operator = line.split(
                    '[')[1].split(']')[0].replace(',', ', ')
                found_solution = True
            if tokens[:2] == ['>>>>>', '2']:
                found_solution = False

            # this is all about establishing the Z scores

            if tokens[:3] == ['More', 'information', 'about']:
                collect_Z = False

            if collect_Z and line:
                index = int(line_orig[:2])
                Z_scores[index] = float(line_orig[19:25])
                # NOTE(review): the stars are counted on the stripped
                # line while the other columns use the original one -
                # confirm against real pointless output
                conf[index] = line[14:24].count('*')

            # mod to work with pointless 0.5.2 - the DNA 1.1 version
            if tokens[:3] == ['Laue', 'Group', 'NetZcc'] or \
                   tokens[:3] == ['Laue', 'Group', 'NetZc']:
                collect_Z = True

        # all of the output has been read, so tidy up before deciding
        # whether the run was a success
        self.kill()

        if spacegroup is None or reindex_operator is None or \
               1 not in conf:
            raise RuntimeError('no solution found in pointless output')

        # return something useful
        return spacegroup, reindex_operator, conf[1]


def PointlessRun(hklin):
    '''Run pointless on the MTZ file hklin through a
    PointlessRunnerWrapper and return the result of its run() method.'''

    runner = PointlessRunnerWrapper()
    runner.setHklin(hklin)

    return runner.run()

def staticUnitTest():
    '''A quick test to see if this is behaving as I expect.

    This test is disabled: it always raises a RuntimeError, since it
    refers to a hard coded data set.'''

    raise RuntimeError('do not use this test')

    # never reached - kept to show what the test used to do
    TriclinicIntegrate('12287_1_E1_###.img',
                       '/media/data1/graeme/jcsg/mad/1vpj/12287',
                       [(1, 3), (88, 90)],
                       'pointless.mtz',
                       beam = (108.93, 105.04))
    PointlessRun('pointless.mtz')

##########################################
# THESE SHOULD BE SOMEWHERE IN A LIBRARY #
##########################################

def image_to_template_and_directory(image):
    '''Take the name of an image (optionally full path) and return a
    suitable template and directory for use elsewhere.

    The name should look like [directory/]prefix_nnn.extension. The
    digits of the image number are replaced with # characters to make
    the template, and the current working directory is returned if the
    name includes no directory. Raises a RuntimeError if the name does
    not have this form.'''

    # looking for foo_bar_1_nnn.exten

    match = re.match(r'(.*)/(.*)_([0-9]*)\.(.*)', image)

    if match:
        directory, prefix, number, extension = match.groups()
    else:
        # perhaps there was no directory in the expression - try without
        match = re.match(r'(.*)_([0-9]*)\.(.*)', image)
        if match is None:
            raise RuntimeError('cannot make a template from image name %s' %
                               image)
        directory = os.getcwd()
        prefix, number, extension = match.groups()

    # number contains only digits, so this is one # per digit
    template = prefix + '_' + '#' * len(number) + '.' + extension

    return template, directory

def get_image_list(template, directory):
    '''Get a list of images (as integers) which match template in
    directory.

    A file matches if it is the template with every # replaced by a
    digit; everything else in the template is taken literally. The
    numbers are returned in ascending order. A template without any #
    characters gives an empty list.'''

    first_hash = template.find('#')

    if first_hash < 0:
        # nowhere to read an image number from
        return []

    # compose a regular expression to find these files - the image number
    # goes where the first # is and has as many digits as there are #
    # characters, so that a matching name is as long as the template

    prefix = template[:first_hash]
    suffix = template[first_hash:].replace('#', '')

    # escape the literal parts: the '.' before the extension must not
    # match any character
    expression = re.escape(prefix) + \
                 '([0-9]{%d})' % template.count('#') + \
                 re.escape(suffix) + r'\Z'
    regexp = re.compile(expression)

    images = []

    for name in os.listdir(directory):
        match = regexp.match(name)
        if match:
            images.append(int(match.group(1)))

    # always tidy to return things in order

    images.sort()
    return images

def template_and_directory_and_number_to_image(template, directory, image):
    '''Given a template, directory and a number create a full path to an
    image.

    The template should look like prefix_###.extension; the image number
    is zero padded to the number of # characters. Raises a RuntimeError
    if the template does not have this form.'''

    match = re.match(r'(.*)_(\#*)\.(.*)', template)

    if match is None:
        raise RuntimeError('cannot interpret template %s' % template)

    prefix, hashes, extension = match.groups()

    # the field width is taken from the number of # characters
    number = '%0*d' % (len(hashes), image)

    return directory + '/' + prefix + '_' + number + '.' + extension

def get_header(image):
    '''Read the header information from this image to a dictionary.

    The keys are time, beam_x, beam_y, pix_x, pix_y, width, height,
    dist, wave, message (a list), start, end, osc and - when the first
    bytes of the file are recognised - format. Raises a RuntimeError if
    the image cannot be opened or if the beam centre values in the
    header are inconsistent.'''

    try:
        # I think that the 0 means "just open the header"
        # this is correct (checked!)
        d = DiffractionImage(image, 0)
    except Exception:
        # sys.exc_info rather than "except ..., e" to avoid syntax which
        # is specific to one version of Python
        e = sys.exc_info()[1]
        raise RuntimeError('Error %s opening image %s' % (e, image))

    header = { }
    header['time'] = d.getExposureTime()
    header['beam_x'] = d.getBeamX()
    header['beam_y'] = d.getBeamY()
    header['pix_x'] = d.getPixelX()
    header['pix_y'] = d.getPixelY()
    header['width'] = d.getWidth()
    header['height'] = d.getHeight()

    # not reading the date information - you have been warned! The
    # canonical implementation is not this one...

    # check to see if these are in mm or in pixels - assume
    # that this is somewhere near the middle of the
    # detector - so if value > detector width in mm then * pixel
    # size

    x_too_big = header['beam_x'] > header['width'] * header['pix_x']
    y_too_big = header['beam_y'] > header['height'] * header['pix_y']

    if x_too_big and y_too_big:
        header['beam_x'] = header['beam_x'] * header['pix_x']
        header['beam_y'] = header['beam_y'] * header['pix_y']
    elif x_too_big or y_too_big:
        # only one of the two looks like pixels
        raise RuntimeError('inconsistent beam centre values: %f %f' %
                           (header['beam_x'], header['beam_y']))

    header['dist'] = d.getDistance()
    header['wave'] = d.getWavelength()
    header['message'] = d.getMessage().split(';')
    header['start'] = d.getPhiStart()
    header['end'] = d.getPhiEnd()
    header['osc'] = header['end'] - header['start']

    # not including the squareness of the diffraction image here

    # identify the format from the first four bytes of the file - read
    # in binary mode, and make sure that the file is closed again
    handle = open(image, 'rb')
    try:
        magic = handle.read(4)
    finally:
        handle.close()

    # no format key is set if none of these match
    if magic == '{\nHE':
        header['format'] = 'smv'
    elif magic[:2] == 'II':
        header['format'] = 'tiff'
    elif magic[:2] == 'RA':
        header['format'] = 'raxis'
    elif magic == '\xd2\x04\x00\x00':
        header['format'] = 'marip'

    del d

    return header

def nint(a):
    '''Return the integer nearest to a. A value exactly half way between
    two integers is rounded towards zero.'''

    # int() truncates towards zero, so the remainder has the sign of a
    b = int(a)
    remainder = a - b

    if remainder > 0.5:
        return b + 1

    # the same for negative numbers, which used to be truncated
    if remainder < -0.5:
        return b - 1

    return b

##########################################
#                 END                    #
##########################################

def PointlessDNAThing(template, directory, determine_beam = True, beam = None):
    '''Taking the information from a DNA fileinfo object (the template
    and directory) index, integrate the frames in P1 (could be > 1
    frame set e.g. cell refinement data) and then run pointless.

    If determine_beam is set labelit is used to find the beam position;
    if that fails, or if determine_beam is not set, the beam passed in
    (if any) is used.

    This will return the result of PointlessRun - the spacegroup, the
    reindexing operator and the confidence - or
    ("unknown", "none", "none") if there is not enough data to run
    pointless.'''

    # what is returned if pointless cannot be run
    no_results = "unknown", "none", "none"

    # get list of available images
    images = get_image_list(template, directory)

    if not images:
        Messenger.log_write("Not enough data to run pointless")
        return no_results

    # construct the image name for the first image
    first = template_and_directory_and_number_to_image(template,
                                                       directory,
                                                       images[0])

    # get the header information for the first image in the list
    header = get_header(first)

    # decide from headers which images to use to simulate
    # 2x3 degrees of data
    phi_width = header['osc']

    if phi_width == 0:
        # the wedge sizes below cannot be worked out
        Messenger.log_write("Zero oscillation width in image header")
        return no_results

    wedges = []

    # want ideally 3 degrees of data, and never less than 3 images
    wedge = int(3.0 / phi_width) - 1
    if wedge < 2:
        wedge = 2
    end_image = wedge + images[0]
    if end_image in images:
        wedges.append((images[0], end_image))

    # next work from the end... ideally the second wedge is 90 degrees
    # on from the first, else it finishes on the last image available

    last_image = images[0] + wedge + nint(90.0 / phi_width)
    if last_image not in images:
        last_image = images[-1]
    start_image = last_image - wedge

    if start_image in images:
        wedges.append((start_image, last_image))

    if not wedges:
        # something has gone horribly wrong - quietly abort....
        Messenger.log_write("Not enough data to run pointless")
        return no_results

    # determine the best beam position
    if determine_beam:
        first = template_and_directory_and_number_to_image(template,
                                                           directory,
                                                           wedges[0][0])

        # use the first image of the second wedge if there is one, else
        # the last image of the only wedge
        if len(wedges) > 1:
            second_number = wedges[1][0]
        else:
            second_number = wedges[0][1]

        second = template_and_directory_and_number_to_image(template,
                                                            directory,
                                                            second_number)

        # 3/MAR/06 allow for the fact that something going wrong
        # with the labelit run shouldn't break DNA - record it and carry
        # on with the beam passed in (none by default)

        try:
            beam = LabelitBeam(first, second)
        except Exception:
            Messenger.log_write('Labelit beam determination failed: %s' %
                                sys.exc_info()[1])

    # run the job - TriclinicIntegrate ignores a beam of None
    TriclinicIntegrate(template,
                       directory,
                       wedges,
                       'pointless.mtz',
                       beam = beam)

    return PointlessRun('pointless.mtz')

def dynamicUnitTest(image, beam = None, determine_beam = True, ideal = 3.0):
    '''A more thorough unit test which will take an input image
    and produce all of the required input for this process.

    ideal is the width in degrees wanted for each wedge. If
    determine_beam is set labelit is used to find the beam position
    (errors from it are not caught), otherwise the beam passed in (if
    any) is used. Returns the result of PointlessRun, or
    ("unknown", "none", "none") if there is not enough data.'''

    # this one can be run with just a pointer to an image from
    # within the set e.g. /tmp/frames/postref-bert_1_001.img

    # what is returned if pointless cannot be run
    no_results = "unknown", "none", "none"

    # parse image to template and directory
    template, directory = image_to_template_and_directory(image)

    # get list of available images
    images = get_image_list(template, directory)

    if not images:
        Messenger.log_write("Not enough data to run pointless")
        return no_results

    # construct the image name for the first image
    first = template_and_directory_and_number_to_image(template,
                                                       directory,
                                                       images[0])

    # get the header information for the first image in the list
    header = get_header(first)

    # decide from headers which images to use to simulate
    # 2 x ideal degrees of data
    phi_width = header['osc']

    if phi_width == 0:
        # the wedge sizes below cannot be worked out
        Messenger.log_write("Zero oscillation width in image header")
        return no_results

    wedges = []

    # want ideally "ideal" degrees of data, and never less than 2 images
    wedge = int(ideal / phi_width) - 1
    if wedge < 2:
        wedge = 1

    Messenger.log_write('Using %f degree wedges of data' % (wedge * phi_width))

    end_image = wedge + images[0]

    # if this fails could do something clever like picking all images
    # contiguous with the first one - can I do this?
    if end_image in images:
        wedges.append((images[0], end_image))

    # next work from the end... ideally the second wedge is 90 degrees
    # on from the first, else it finishes on the last image available

    last_image = images[0] + wedge + nint(90.0 / phi_width)
    if last_image not in images:
        last_image = images[-1]
    start_image = last_image - wedge

    # if this fails could do something clever like picking all images
    # contiguous with the last one - can I do this?
    if start_image in images:
        wedges.append((start_image, last_image))

    if not wedges:
        # something has gone horribly wrong - quietly abort....
        Messenger.log_write("Not enough data to run pointless")
        return no_results

    # determine the best beam position
    if determine_beam:
        first = template_and_directory_and_number_to_image(template,
                                                           directory,
                                                           wedges[0][0])

        # use the first image of the second wedge if there is one, else
        # the last image of the only wedge
        if len(wedges) > 1:
            second_number = wedges[1][0]
        else:
            second_number = wedges[0][1]

        second = template_and_directory_and_number_to_image(template,
                                                            directory,
                                                            second_number)
        beam = LabelitBeam(first, second)

    # run the job - TriclinicIntegrate ignores a beam of None
    TriclinicIntegrate(template,
                       directory,
                       wedges,
                       'pointless.mtz',
                       beam = beam)

    return PointlessRun('pointless.mtz')

# use labelit to get the correct direct beam position using the
# --index_only option - optionally!

class LabelitBeamWrapper(Driver):
    '''A really lightweight driver class to use Labelit to get the
    correct beam centre ONLY - this will not actually be used to provide
    anything useful in terms of the actual indexing.'''

    def __init__(self):
        Driver.__init__(self)
        self.setExecutable('labelit.screen')

        # the two images to index from, assigned with setImages
        self.images = None

    def setImages(self, first, second):
        '''Set the (full) names of the two images to give to labelit.'''
        self.images = (first, second)

    def run(self):
        '''Actually run labelit to get the "correct" beam centre out...

        Returns the beam centre as (x, y), taken from the last line of
        the output which starts with "Beam center x". The distance
        between this and the beam centre in the header of the first
        image is written to the log. Raises a RuntimeError if the images
        have not been assigned or if no beam centre is found.'''

        if not self.images:
            raise RuntimeError('images not assigned')

        self.start('--index_only %s %s' % self.images)
        self.close()

        beam = None

        while 1:
            line = self.output()

            if not line:
                break

            if line.startswith('Beam center x'):
                tokens = line.split()
                # the values come with the units attached ("mm" or
                # "mm,") - remove these before converting
                x = float(tokens[3].replace('mm', '').replace(',', ''))
                y = float(tokens[5].replace('mm', '').replace(',', ''))
                beam = (x, y)

        self.kill()

        if beam is None:
            raise RuntimeError('beam not found')

        # need to get the beam centre recorded in the image header for the
        # first image here

        header = get_header(self.images[0])
        header_x, header_y = header['beam_x'], header['beam_y']

        # this should be sent to the messenger
        Messenger.log_write('Labelit beam: %6.2f %6.2f' % beam)
        Messenger.log_write('Header beam:  %6.2f %6.2f' %
                            (header_x, header_y))

        # distance between the two beam centres
        delta_x = beam[0] - header_x
        delta_y = beam[1] - header_y
        offset = math.sqrt((delta_x * delta_x) + (delta_y * delta_y))

        # presume here that the pixels are square...
        if offset > 2.0 * header['pix_x']:
            Messenger.log_write('Beam centre off by > 2 pixels: %6.2f mm' %
                                offset)
        else:
            Messenger.log_write('Beam centre ok, out by %6.2f mm' %
                                offset)

        return beam

def LabelitBeam(first, second):
    '''Run a LabelitBeamWrapper on the two images and return the beam
    position which its run() method gives.'''

    wrapper = LabelitBeamWrapper()
    wrapper.setImages(first, second)
    beam = wrapper.run()

    return beam

if __name__ == '__main__':

    # usage: AutoPointgroup.py [image [ideal wedge width in degrees]]

    if len(sys.argv) == 1:
        staticUnitTest()

    else:

        if len(sys.argv) > 2:
            ideal = float(sys.argv[2])
        else:
            ideal = 3.0

        # determine_beam has to be passed by keyword: the second
        # positional argument of dynamicUnitTest is the beam. The
        # confidence is written with %s since it is the string "none"
        # when there was not enough data.
        Messenger.log_write('Spag: %s (%s) Conf: %s' % \
                            dynamicUnitTest(sys.argv[1],
                                            determine_beam = True,
                                            ideal = ideal))
