#!/usr/bin/env python
# Crystallographer.py
# Maintained by G.Winter
# 11th February 2004
# An implementation of a crystallographer for the second generation 
# Expert system for DNA.
# 
# This will implement common crystallographic methods, so that we can 
# compare the results with what Mosflm comes up with.
# 
# $Id: Crystallographer.py,v 1.4 2004/10/13 15:44:11 gwin Exp $

# we'll have plenty of sums to be doing so
import math

# and I'll want my mathematical/linear algebra hat on so
import Vector

# Definitions:
# There are a few common definitions of things which I will use - 
# with Vectors being a list of numbers and Matrices being a list of
# lists of numbers. See Vector.py.
# 
# a_star, b_star, c_star are reciprocal lattice vectors. A is a matrix
# made of these reciprocal lattice vectors. Niggli is a matrix made of
# a.a, b.b, c.c
# b.c, c.a, a.b
# 

def Equal(a, b):
    '''Return True when a and b differ by less than an absolute
    tolerance of 1.0e-5, and False otherwise.'''

    tolerance = 1.0e-5
    return math.fabs(b - a) < tolerance

def ReduceCell(a_star, b_star, c_star):
    '''Placeholder for the (unstable) cell reduction described by Kim,
    J. Appl. Cryst. (1989).

    For c_star, then a_star, then b_star, the axis is split into a part
    along the normal of the other two axes and a remainder, and two
    coefficients (p, q) are computed from that remainder. The
    coefficients are thrown away: no argument is modified and nothing
    is returned, so no reduction actually takes place yet.'''

    def coefficients(axis, first, second):
        # presumably Vector.Cross / Unit / Dot are the usual cross
        # product, normalisation and dot product - see Vector.py
        normal = Vector.Unit(Vector.Cross(first, second))

        # remove the component of axis along the normal
        along_normal = Vector.MultiplyScalarVector(
            Vector.Dot(axis, normal), normal)
        remainder = Vector.Subtract(axis, along_normal)

        l = Vector.Unit(Vector.Cross(normal, second))
        m = Vector.Unit(Vector.Cross(normal, first))

        p = Vector.Dot(remainder, l) / Vector.Dot(first, l)
        q = Vector.Dot(remainder, m) / Vector.Dot(second, m)

        return p, q

    # this is not yet properly implemented - the results are unused

    # the c_star axis against a_star, b_star
    coefficients(c_star, a_star, b_star)

    # the a_star axis against b_star, c_star
    coefficients(a_star, b_star, c_star)

    # finally the b_star axis against a_star, c_star
    coefficients(b_star, a_star, c_star)

    # note - at this moment there is no cell reduction going on...

def Invert(a, b, c):
    '''Generate the reciprocal (i.e. other space) cell.

    Returns the tuple (a_star, b_star, c_star), where each vector is
    the Vector.Cross of the other two input vectors (taken in cyclic
    order) multiplied by one over V = Vector.Dot(a, Vector.Cross(b, c)).'''

    volume = Vector.Dot(a, Vector.Cross(b, c))

    reciprocal = []

    # cyclic pairs give a_star, b_star, c_star in that order
    for first, second in ((b, c), (c, a), (a, b)):
        reciprocal.append(Vector.MultiplyScalarVector(
            1.0 / volume, Vector.Cross(first, second)))

    return tuple(reciprocal)

def InvertCell(a, b, c, alpha, beta, gamma):
    '''Invert a cell given as lengths and angles, using matrix
    manipulations rather than closed formulae.

    The cell is turned into three row vectors with BMatrix, these are
    inverted with Invert, and the result is converted back with Cell,
    whose return value (a, b, c, alpha, beta, gamma) is returned.'''

    # since the B matrix is only a special case of a real solution
    # its three rows can be inverted like any other set of axes
    rows = BMatrix(a, b, c, alpha, beta, gamma)

    return Cell(*Invert(rows[0], rows[1], rows[2]))

def InvertCellOld(a, b, c, alpha, beta, gamma):
    '''Invert a cell using the closed formulae for the reciprocal cell.

    a, b, c are the cell lengths and alpha, beta, gamma the cell angles
    in degrees. Returns the tuple
    (a_star, b_star, c_star, alpha_star, beta_star, gamma_star)
    with the angles in degrees.

    Raises ValueError when the three angles cannot describe a cell
    (the volume term under the square root is negative).'''

    pi = 4.0 * math.atan(1.0)
    dtor = pi / 180.0
    rtod = 180.0 / pi

    ca = math.cos(dtor * alpha)
    cb = math.cos(dtor * beta)
    cc = math.cos(dtor * gamma)
    sa = math.sin(dtor * alpha)
    sb = math.sin(dtor * beta)
    sc = math.sin(dtor * gamma)

    # V = a b c sqrt(1 - cos^2(alpha) - cos^2(beta) - cos^2(gamma)
    #                + 2 cos(alpha) cos(beta) cos(gamma))
    radicand = 1 - ca * ca - cb * cb - cc * cc + 2 * ca * cb * cc

    if radicand < 0.0:
        # math.sqrt would only report a bare "math domain error"
        raise ValueError('angles %s %s %s do not describe a cell' %
                         (alpha, beta, gamma))

    V = a * b * c * math.sqrt(radicand)

    a_star = b * c * sa / V
    b_star = a * c * sb / V
    c_star = a * b * sc / V

    def angle(cosine):
        # to get around math domain errors from rounding
        if cosine > 1.0:
            cosine = 1.0
        if cosine < -1.0:
            cosine = -1.0
        return rtod * math.acos(cosine)

    # the cosine form is used because the sine of the angle (which is
    # what V / (a b c sin sin) gives) cannot tell an obtuse reciprocal
    # angle from an acute one
    alpha_star = angle((cb * cc - ca) / (sb * sc))
    beta_star = angle((ca * cc - cb) / (sa * sc))
    gamma_star = angle((ca * cb - cc) / (sa * sb))

    return a_star, b_star, c_star, alpha_star, beta_star, gamma_star


def Type(a, b, c):
    '''Return 1 when the product of the three dot products a.b, b.c
    and c.a is greater than zero, otherwise return 2.'''

    product = Vector.Dot(a, b) * Vector.Dot(b, c) * Vector.Dot(c, a)

    if product > 0:
        return 1

    return 2

def Niggli(a, b, c):
    '''Return the list [a.a, b.b, c.c, b.c, c.a, a.b] of dot products
    of the three vectors.'''

    return [Vector.Dot(a, a), Vector.Dot(b, b), Vector.Dot(c, c),
            Vector.Dot(b, c), Vector.Dot(c, a), Vector.Dot(a, b)]

def TypePenalty(type, a, b, c):
    '''Return the penalty for the cell a, b, c failing the conditions
    checked for a cell of the given type.

    With A, B, C, D, E, F the values from Niggli(a, b, c), the amount
    by which each of A <= B, B <= C, 2|D| <= B, 2|E| <= A and
    2|F| <= A is violated is summed. To this is added Type1Penalty
    when type is 1 and Type2Penalty for any other value of type.'''

    A, B, C, D, E, F = Niggli(a, b, c)

    # (value, upper limit) - any excess over the limit is penalised
    limits = ((A, B),
              (B, C),
              (2 * math.fabs(D), B),
              (2 * math.fabs(E), A),
              (2 * math.fabs(F), A))

    total = 0.0

    for value, limit in limits:
        if value > limit:
            total += value - limit

    if type == 1:
        extra = Type1Penalty(A, B, C, D, E, F)
    else:
        extra = Type2Penalty(A, B, C, D, E, F)

    return total + extra

def Type1Penalty(A, B, C, D, E, F):
    '''Return the type 1 part of the penalty for the Niggli values
    A, B, C (a.a, b.b, c.c) and D, E, F (b.c, c.a, a.b).

    Each of D, E and F which is greater than zero is added to the
    penalty. The "special" conditions which follow are switched off by
    the local flag special, so they contribute nothing.'''

    penalty = 0.0

    # F used to be compared with the penalty (penalty == F) rather
    # than added to it, so it was silently ignored
    #
    # NOTE(review): this penalises positive D, E, F exactly as
    # Type2Penalty does - confirm the sign is intended for type 1
    for value in (D, E, F):
        if value > 0:
            penalty += value

    special = False
    if special:

        if Equal(A, B) and D > E:
            penalty += D - E

        if Equal(B, C) and F > E:
            penalty += F - E

        if Equal(2 * D, B) and F > 2 * E:
            penalty += F - 2 * E

        if Equal(2 * E, A) and F > 2 * D:
            penalty += F - 2 * D

        if Equal(2 * F, A) and E > 2 * D:
            penalty += E - 2 * D

    return penalty

def Type2Penalty(A, B, C, D, E, F):
    '''Return the type 2 part of the penalty for the Niggli values
    A, B, C (a.a, b.b, c.c) and D, E, F (b.c, c.a, a.b).

    The amount by which 2 (|D| + |E| + |F|) exceeds A + B is
    penalised, as is each of D, E and F which is greater than zero.
    The "special" conditions which follow are switched off by the
    local flag special, so they contribute nothing.'''

    total = 0.0

    twice_sum = 2 * (math.fabs(D) + math.fabs(E) + math.fabs(F))
    limit = A + B

    if twice_sum > limit:
        total += twice_sum - limit

    for value in (D, E, F):
        if value > 0:
            total += value

    special = False
    if not special:
        return total

    # everything below is disabled while special is False

    abs_d = math.fabs(D)
    abs_e = math.fabs(E)
    abs_f = math.fabs(F)

    if Equal(A, B) and abs_d > abs_e:
        total += abs_d - abs_e

    if Equal(B, C) and abs_e > abs_f:
        total += abs_e - abs_f

    if Equal(2 * D, B):
        total += abs_f

    if Equal(2 * E, A):
        total += abs_f

    if Equal(2 * F, A):
        total += abs_e

    if Equal(2 * (abs_d + abs_e + abs_f), A + B) and A > (2 * E + F):
        total += A - 2 * E - F

    return total

def GroupPenalty(group, a, b, c):
    '''Return the penalty for the axis lengths of the cell a, b, c not
    obeying the equalities of the given group.

    With A, B, C the first three values from Niggli(a, b, c):
    group 1 gives |B - A| + |C - B|, group 2 gives |B - A|, group 3
    gives |C - B| and group 4 (or any other value) gives 0.0.'''

    A, B, C = Niggli(a, b, c)[:3]

    if group == 1:
        return math.fabs(B - A) + math.fabs(C - B)

    if group == 2:
        return math.fabs(B - A)

    if group == 3:
        return math.fabs(C - B)

    # group 4 has no condition on the lengths
    return 0.0
        
def Penalty(lattice, a, b, c):
    '''Return the penalty for describing the cell a, b, c as the
    lattice with the given number.

    lattice is an integer from 1 to 44 (presumably the lattice
    character number of International Tables volume A - to be
    confirmed). Each number has its own branch below, which adds up
    TypePenalty for type 1 or 2, GroupPenalty for group 1, 2, 3 or 4
    and the absolute size of the terms in D, E, F (and A, B) listed
    in that branch. The total is divided by A + B + C before it is
    returned.

    Any other value of lattice matches no branch, so penalty is never
    assigned and the return statement raises UnboundLocalError. A
    cell with A + B + C equal to zero raises ZeroDivisionError.'''

    # A, B, C = a.a, b.b, c.c and D, E, F = b.c, c.a, a.b
    n = Niggli(a, b, c)

    A = n[0]
    B = n[1]
    C = n[2]
    D = n[3]
    E = n[4]
    F = n[5]

    # used to normalise the penalty in the return statement
    scale = A + B + C

    # lattices 1 to 8 use group 1, which penalises |B - A| + |C - B|
    if lattice == 1:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(1, a, b, c)
        penalty += math.fabs(2 * D - A)
        penalty += math.fabs(2 * E - A)
        penalty += math.fabs(2 * F - A)
    elif lattice == 2:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(1, a, b, c)
        penalty += math.fabs(E - D)
        penalty += math.fabs(F - D)
    elif lattice == 3:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(1, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(E)
        penalty += math.fabs(F)
    elif lattice == 4:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(1, a, b, c)
        penalty += math.fabs(E - D)
        penalty += math.fabs(F - D)
    elif lattice == 5:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(1, a, b, c)
        penalty += math.fabs(3 * D + A)
        penalty += math.fabs(3 * E + A)
        penalty += math.fabs(3 * F + A)
    elif lattice == 6:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(1, a, b, c)
        penalty += math.fabs(2 * math.fabs(D + E + F) - (A + B))
        penalty += math.fabs(E - D)
    elif lattice == 7:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(1, a, b, c)
        penalty += math.fabs(2 * math.fabs(D + E + F) - (A + B))
        penalty += math.fabs(F - E)
    elif lattice == 8:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(1, a, b, c)
        penalty += math.fabs(2 * math.fabs(D + E + F) - (A + B))
    # lattices 9 to 17 use group 2, which penalises |B - A|
    elif lattice == 9:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(2 * D - A)
        penalty += math.fabs(2 * E - A)
        penalty += math.fabs(2 * F - A)
    elif lattice == 10:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(E - D)
    elif lattice == 11:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(E)
        penalty += math.fabs(F)
    elif lattice == 12:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(E)
        penalty += math.fabs(2 * F + A)
    elif lattice == 13:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(E)
    elif lattice == 14:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(E - D)
    elif lattice == 15:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(2 * D + A)
        penalty += math.fabs(2 * E + A)
        penalty += math.fabs(F)
    elif lattice == 16:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(2 * math.fabs(D + E + F) - (A + B))
        penalty += math.fabs(E - D)
    elif lattice == 17:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(2, a, b, c)
        penalty += math.fabs(2 * math.fabs(D + E + F) - (A + B))
    # lattices 18 to 25 use group 3, which penalises |C - B|
    elif lattice == 18:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(3, a, b, c)
        penalty += math.fabs(4 * D - A)    
        penalty += math.fabs(2 * E - A)
        penalty += math.fabs(2 * F - A)
    elif lattice == 19:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(3, a, b, c)
        penalty += math.fabs(2 * E - A)
        penalty += math.fabs(2 * F - A)
    elif lattice == 20:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(3, a, b, c)
        penalty += math.fabs(F - E)
    elif lattice == 21:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(3, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(E)
        penalty += math.fabs(F)
    elif lattice == 22:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(3, a, b, c)
        penalty += math.fabs(2 * D + B)
        penalty += math.fabs(E)
        penalty += math.fabs(F)
    elif lattice == 23:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(3, a, b, c)
        penalty += math.fabs(E)
        penalty += math.fabs(F)
    elif lattice == 24:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(3, a, b, c)
        penalty += math.fabs(2 * math.fabs(D + E + F) - (A + B))
        penalty += math.fabs(3 * E + A)
        penalty += math.fabs(3 * F + A)
    elif lattice == 25:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(3, a, b, c)
        penalty += math.fabs(F - E)
    # lattices 26 to 44 use group 4, which adds no length penalty
    elif lattice == 26:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(4 * D - A)
        penalty += math.fabs(2 * E - A)
        penalty += math.fabs(2 * F - A)
    elif lattice == 27:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * E - A)
        penalty += math.fabs(2 * F - A)
    elif lattice == 28:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * E - A)
        penalty += math.fabs(2 * D - E)
    elif lattice == 29:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * D - E)
        penalty += math.fabs(2 * F - A)
    elif lattice == 30:
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * D - B)
        penalty += math.fabs(2 * E - F)
    elif lattice == 31:
        # only the type penalty applies here
        penalty = TypePenalty(1, a, b, c) + GroupPenalty(4, a, b, c)
    elif lattice == 32:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(E)
        penalty += math.fabs(F)
    elif lattice == 33:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(F)
    elif lattice == 34:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(E)
    elif lattice == 35:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(E)
        penalty += math.fabs(F)
    elif lattice == 36:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(2 * E + A)
        penalty += math.fabs(F)
    elif lattice == 37:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * E + A)
        penalty += math.fabs(F)
    elif lattice == 38:        
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(D)
        penalty += math.fabs(E)
        penalty += math.fabs(2 * F + A)
    elif lattice == 39:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(E)
        penalty += math.fabs(2 * F + A)
    elif lattice == 40:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * D + B)
        penalty += math.fabs(E)
        penalty += math.fabs(F)
    elif lattice == 41:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * D + B)
        penalty += math.fabs(F)
    elif lattice == 42:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * D + B)
        penalty += math.fabs(2 * E + A)
        penalty += math.fabs(F)
    elif lattice == 43:
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)
        penalty += math.fabs(2 * math.fabs(D + E + F) - (A + B))
        penalty += math.fabs(math.fabs(2 * D + F) - B)
    elif lattice == 44:
        # only the type penalty applies here
        penalty = TypePenalty(2, a, b, c) + GroupPenalty(4, a, b, c)

    # NOTE(review): penalty is unbound here when lattice matched no
    # branch above
    return penalty / scale

def LType(number):
    '''Return the cell type (1 or 2) listed for the lattice with the
    given number, which runs from 1 to 44.

    Raises IndexError for a number outside that range (previously zero
    and negative numbers silently indexed from the end of the table).'''

    types = [1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, \
             1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, \
             2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

    if not 1 <= number <= len(types):
        raise IndexError('lattice number ' + str(number) +
                         ' is not in the range 1 to 44')

    return types[number - 1]

def Lattice(spacegroup):
    '''Return the name of a lattice, e.g. 'mC'.

    There used to be two functions called Lattice in this file, one
    taking a lattice number and one taking a spacegroup name, and the
    second replaced the first - so looking up a number always gave
    None. Both uses are handled here:

    - an integer from 1 to 44 returns the lattice listed for that
      number, and raises IndexError outside that range;
    - anything else is treated as a spacegroup name (lower case, as
      returned by Spacegroups) and the first lattice whose spacegroups
      include it is returned, or None if there is no such lattice.'''

    if isinstance(spacegroup, int):
        lattices = ['cF', 'hR', 'cP', 'hR', 'cI', 'tI', 'tI', 'oI', \
                    'hR', 'mC', 'tP', 'hP', 'oC', 'mC', 'tI', 'oF', 'mC', \
                    'tI', 'oI', 'mC', 'tP', 'hP', 'oC', 'hR', 'mC', \
                    'oF', 'mC', 'mC', 'mC', 'mC', 'aP', 'oP', \
                    'mP', 'mP', 'mP', 'oC', 'mC', 'oC', 'mC', \
                    'oC', 'mC', 'oI', 'mI', 'aP']

        if not 1 <= spacegroup <= len(lattices):
            raise IndexError('lattice number ' + str(spacegroup) +
                             ' is not in the range 1 to 44')

        return lattices[spacegroup - 1]

    # the order matters - 'c2' is listed for both mC and mI, and mC
    # comes first
    for lattice in ['aP', 'mP', 'mC', 'oP', 'oC', 'tP', 'hP', 'oF',
                    'hR', 'cP', 'oI', 'tI', 'cF', 'cI', 'mI']:
        if spacegroup in Spacegroups(lattice):
            return lattice

    return None

def Sublattices(lattice):
    '''Return a new list with the given lattice followed by the
    lattices listed below it, always ending with 'aP'.

    'mC' and 'mI' used to fall through to the empty list even though
    'mC' is itself returned as a sublattice of 'oC' and 'hP'; they now
    return themselves followed by 'aP'. A name which is not known
    still returns [].'''

    table = (('aP', ('aP',)),
             ('mP', ('mP', 'aP')),
             ('mC', ('mC', 'aP')),
             ('mI', ('mI', 'aP')),
             ('oP', ('oP', 'mP', 'aP')),
             ('oC', ('oC', 'mC', 'aP')),
             ('oI', ('oI', 'aP')),
             ('oF', ('oF', 'aP')),
             ('tP', ('tP', 'oP', 'mP', 'aP')),
             ('tI', ('tI', 'oI', 'aP')),
             ('hR', ('hR', 'aP')),
             ('hP', ('hP', 'oC', 'mC', 'aP')),
             ('cP', ('cP', 'tP', 'oP', 'mP', 'aP')),
             ('cI', ('cI', 'tI', 'oI', 'aP')),
             ('cF', ('cF', 'oF', 'aP')))

    for name, chain in table:
        if lattice == name:
            # a fresh list, so that the caller may modify it
            return list(chain)

    return []

def Spacegroups(lattice):
    '''Return a new list of the spacegroup names (lower case) listed
    for the given lattice, e.g. ['p2', 'p21'] for 'mP'.

    A lattice which is not known returns the empty string "", as it
    always has, so the result is empty but is not a list.

    Two entries differ from the earlier table: 'p4123', which is not
    a spacegroup, is now 'p4132' in the cP list, and 'p321', which
    was missing, has been added to the hP list.'''

    table = (
        ("aP", ("p1",)),
        ("mP", ("p2", "p21")),
        ("oP", ("p222", "p2221", "p21212", "p212121")),
        ("mC", ("c2",)),
        ("oC", ("c222", "c2221")),
        ("tP", ("p4", "p41", "p42", "p43", "p422", "p4212", "p4122",
                "p41212", "p4222", "p42212", "p4322", "p43212")),
        ("hP", ("p3", "p31", "p32", "p312", "p321", "p3112", "p3121",
                "p3212", "p3221", "p6", "p61", "p65", "p62", "p64",
                "p63", "p622", "p6122", "p6522", "p6222", "p6422",
                "p6322")),
        ("oF", ("f222",)),
        ("hR", ("h3", "h32")),
        ("cP", ("p23", "p213", "p432", "p4232", "p4332", "p4132")),
        ("oI", ("i222", "i212121")),
        ("tI", ("i4", "i41", "i422", "i4122")),
        ("cF", ("f23", "f432", "f4132")),
        ("cI", ("i23", "i213", "i432", "i4132")),
        ("mI", ("c2",)),
        )

    for name, groups in table:
        if lattice == name:
            # a fresh list, so that the caller may modify it
            return list(groups)

    return ""


class Solution:
    '''One candidate lattice for a cell: the penalty, the lattice
    name, the cell type and the lattice number.

    Solutions compare by penalty alone, so sorting a list of them puts
    the lowest penalty first. The comparison methods are spelled out
    one by one because __cmp__ and cmp() only exist in Python 2;
    __cmp__ is kept, written without cmp(), for code which calls it.'''

    def __init__(self, penalty, lattice, type, number):
        self.penalty = penalty
        self.lattice = lattice
        self.type = type
        self.number = number

    def __cmp__(self, other):
        # -1, 0 or 1, as cmp(self.penalty, other.penalty) would give
        return (self.penalty > other.penalty) - \
               (self.penalty < other.penalty)

    def __lt__(self, other):
        return self.penalty < other.penalty

    def __le__(self, other):
        return self.penalty <= other.penalty

    def __gt__(self, other):
        return self.penalty > other.penalty

    def __ge__(self, other):
        return self.penalty >= other.penalty

    def __eq__(self, other):
        return self.penalty == other.penalty

    def __ne__(self, other):
        return self.penalty != other.penalty

    def __str__(self):
        return str(self.penalty) + " " + str(self.lattice)

def Cell(A, B, C):
    '''Convert three cell vectors into lengths and angles.

    Returns (a, b, c, alpha, beta, gamma) where a, b, c are the square
    roots of A.A, B.B, C.C and alpha, beta, gamma are the angles, in
    degrees, between B and C, A and C, and A and B.'''

    rtod = 180.0 / (4.0 * math.atan(1.0))

    a = math.sqrt(Vector.Dot(A, A))
    b = math.sqrt(Vector.Dot(B, B))
    c = math.sqrt(Vector.Dot(C, C))

    def angle(first, second, first_length, second_length):
        cosine = Vector.Dot(first, second) / (first_length * second_length)
        return rtod * math.acos(cosine)

    alpha = angle(B, C, b, c)
    beta = angle(A, C, a, c)
    gamma = angle(A, B, a, b)

    return a, b, c, alpha, beta, gamma

def Generate(a_star, b_star, c_star):
    '''Score a cell against all 44 lattices and print the outcome.

    a_star, b_star, c_star are the reciprocal lattice vectors. They
    are passed to ReduceCell (which does not reduce anything yet) and
    then inverted. A Solution is built for every lattice number from
    1 to 44, the list is sorted by penalty, the solution chosen by
    Pick is printed between two rules, and finally every solution is
    printed with the largest penalty first. Nothing is returned.

    Each line is printed from one formatted string, which gives the
    same text as printing the items one after another but is also
    valid on Python versions where print is a function.'''

    ReduceCell(a_star, b_star, c_star)

    a, b, c = Invert(a_star, b_star, c_star)

    solutions = []

    # NOTE(review): Lattice is called with the lattice number here
    for number in range(1, 45):
        p = Penalty(number, a, b, c)
        solutions.append(Solution(p, Lattice(number), LType(number), number))

    solutions.sort()

    picked = Pick(solutions)

    rule = '======================================='

    print(rule)
    print('PICKED:  %s %s %s %s' %
          (picked.penalty, picked.type, picked.lattice, picked.number))
    print(Spacegroups(picked.lattice)[0])
    print(rule)

    # worst first, so the best solutions end up at the bottom
    solutions.reverse()
    for s in solutions:
        print('%s %s %s %s %s' %
              (s.penalty, s.lattice, s.type, s.number,
               Spacegroups(s.lattice)[0]))

    return

def _BodyCentredAngles(alpha, beta, gamma):
    '''Scale the cosines of three angles (in degrees) so that they add
    up to -1 and return the resulting angles, in degrees. When the
    cosines add up to less than 0.1 in size all three angles are set
    to acos(-1/3) instead.'''

    pi = 4.0 * math.atan(1.0)
    dtor = pi / 180.0
    rtod = 180.0 / pi

    cos_alpha = math.cos(dtor * alpha)
    cos_beta = math.cos(dtor * beta)
    cos_gamma = math.cos(dtor * gamma)

    total = cos_alpha + cos_beta + cos_gamma

    if math.fabs(total) > 0.1:
        # NOTE(review): a scaled cosine outside -1 to 1 still gives a
        # math domain error (ValueError) here
        scale = -1.0 / total
        return (rtod * math.acos(scale * cos_alpha),
                rtod * math.acos(scale * cos_beta),
                rtod * math.acos(scale * cos_gamma))

    angle = rtod * math.acos(-1.0 / 3.0)

    return angle, angle, angle

def Constrain(lattice, cell):
    '''Apply the constraints of the named lattice to a cell.

    cell is defined as a, b, c, alpha, beta, gamma with the angles in
    degrees, and the constrained cell is returned as a tuple in the
    same order. lattice is one of 'aP', 'mP', 'mC', 'oP', 'oC', 'oI',
    'oF', 'tP', 'tI', 'hR', 'hP', 'cP', 'cI' or 'cF'; anything else
    raises RuntimeError.

    Three things differ from the earlier version: gamma is now set
    along with alpha and beta in the oI and tI fallback (it was
    assigned to a misspelt name and so left alone), averages of three
    values divide by 3.0 rather than multiplying by 0.3333333333, and
    the raise statement is written in the form all Python versions
    accept.'''

    a, b, c, alpha, beta, gamma = cell

    rtod = 180.0 / (4.0 * math.atan(1.0))

    if lattice == 'aP':
        return a, b, c, alpha, beta, gamma

    if lattice == 'mP':
        return a, b, c, 90.0, beta, 90.0

    if lattice == 'mC':
        angle = 0.5 * (alpha + beta)
        length = 0.5 * (a + b)
        return length, length, c, angle, angle, gamma

    if lattice == 'oP':
        return a, b, c, 90.0, 90.0, 90.0

    if lattice == 'oC':
        length = 0.5 * (a + b)
        return length, length, c, 90.0, 90.0, gamma

    if lattice == 'oI':
        length = (a + b + c) / 3.0
        # now have to worry about cos(alpha) + cos(beta) + cos(gamma) = -1
        alpha, beta, gamma = _BodyCentredAngles(alpha, beta, gamma)
        return length, length, length, alpha, beta, gamma

    if lattice == 'oF':
        # this one is a worry - I feel there may be a typo in the
        # international tables volume A
        alpha = rtod * math.acos((- a * a + b * b + c * c) / (2 * b * c))
        beta = rtod * math.acos((a * a - b * b + c * c) / (2 * a * c))
        gamma = rtod * math.acos((a * a + b * b - c * c) / (2 * a * b))
        return a, b, c, alpha, beta, gamma

    if lattice == 'tP':
        # At this stage we should worry about permutations of the
        # a, b, c axes, since we could have a, b = c not a = b, c
        # BUG!
        length = 0.5 * (a + b)
        return length, length, c, 90.0, 90.0, 90.0

    if lattice == 'tI':
        length = (a + b + c) / 3.0
        # alpha and beta are made equal before the cosines are scaled
        angle = 0.5 * (alpha + beta)
        alpha, beta, gamma = _BodyCentredAngles(angle, angle, gamma)
        return length, length, length, alpha, beta, gamma

    if lattice == 'hR':
        length = (a + b + c) / 3.0
        angle = (alpha + beta + gamma) / 3.0
        return length, length, length, angle, angle, angle

    if lattice == 'hP':
        length = 0.5 * (a + b)
        return length, length, c, 90.0, 90.0, 120.0

    if lattice == 'cP':
        length = (a + b + c) / 3.0
        return length, length, length, 90.0, 90.0, 90.0

    if lattice == 'cI':
        angle = rtod * math.acos(-1.0 / 3.0)
        length = (a + b + c) / 3.0
        return length, length, length, angle, angle, angle

    if lattice == 'cF':
        length = (a + b + c) / 3.0
        return length, length, length, 60.0, 60.0, 60.0

    raise RuntimeError('lattice ' + str(lattice) + ' not known')
    
def BMatrix(a_star, b_star, c_star, alpha_star, beta_star, gamma_star):
    '''Build a matrix whose three rows are axes with the given lengths
    and angles (angles in degrees).

    Row 0 lies along x, row 1 lies in the x-y plane at gamma_star to
    row 0, and row 2 is placed so that it makes beta_star with row 0
    and alpha_star with row 1. The matrix comes from Vector.Matrix()
    and is filled in element by element.

    Row 2 used to be written as though alpha_star were the angle of
    rotation about x, which only gives the right angle between rows 1
    and 2 when the other angles are 90 degrees. Its y component is now
    (cos(alpha) - cos(beta) cos(gamma)) / sin(gamma) and its z
    component is whatever is left of the length.

    Raises ZeroDivisionError when sin(gamma_star) is exactly zero.'''

    rtod = 180.0 / (4.0 * math.atan(1.0))

    dtor = 1.0 / rtod

    cos_alpha = math.cos(dtor * alpha_star)
    cos_beta = math.cos(dtor * beta_star)
    cos_gamma = math.cos(dtor * gamma_star)
    sin_gamma = math.sin(dtor * gamma_star)

    # components of the third axis for unit length
    along_x = cos_beta
    along_y = (cos_alpha - cos_beta * cos_gamma) / sin_gamma

    height = 1.0 - along_x * along_x - along_y * along_y

    # rounding can push this just below zero for a nearly flat cell;
    # angles which cannot make a cell at all end up flat as well
    if height < 0.0:
        height = 0.0

    along_z = math.sqrt(height)

    B = Vector.Matrix()
    B[0][0] = a_star
    B[0][1] = 0.0
    B[0][2] = 0.0

    B[1][0] = b_star * cos_gamma
    B[1][1] = b_star * sin_gamma
    B[1][2] = 0.0

    B[2][0] = c_star * along_x
    B[2][1] = c_star * along_y
    B[2][2] = c_star * along_z

    return B
    

def Pick(solutions):
    '''Pick a solution from a list sorted by increasing penalty.

    Each solution is compared with the one after it. The first one to
    be clearly separated from its successor is returned: either its
    penalty is zero and the next is above 0.08, or its penalty is
    between 0 and 0.02 and the next is more than five times larger
    and more than 0.08 greater. If none qualifies the first solution
    in the list is returned.'''

    for current, following in zip(solutions, solutions[1:]):

        if current.penalty == 0 and following.penalty > 0.08:
            return current

        # the positive test comes first so the ratio never divides
        # by zero
        if 0.0 < current.penalty < 0.02 and \
               following.penalty / current.penalty > 5 and \
               following.penalty - current.penalty > 0.08:
            return current

    return solutions[0]

if __name__ == '__main__':
    # a small demonstration: rank the lattices for a real A matrix,
    # then constrain the cell to one lattice and print its B matrix

    a_star = Vector.Vector()
    b_star = Vector.Vector()
    c_star = Vector.Vector()

    # which of the two sets of data below to use
    spacegroup = 'P4'

    if spacegroup == 'C2':
        # this data comes from an A matrix for a real crystal,
        # with wavelength 0.870 angstroms and should be monoclinic centred
        # i.e. mC bravais lattice

        a_star[0] = 0.005719
        b_star[0] = 0.000938
        c_star[0] = 0.009607

        a_star[1] = 0.012574
        b_star[1] = -0.002257
        c_star[1] = -0.004409

        a_star[2] = -0.004825
        b_star[2] = 0.014405
        c_star[2] = -0.001058

    elif spacegroup == 'P4':
        # this one is from lysozyme at a different wavelength (0.933 angstroms)

        a_star[0] = 0.016396
        b_star[0] = 0.000727
        c_star[0] = -0.010073

        a_star[1] = 0.017327
        b_star[1] = -0.007960
        c_star[1] = 0.005623

        a_star[2] = 0.012917
        b_star[2] = 0.009917
        c_star[2] = 0.005289

    Generate(a_star, b_star, c_star)

    A, B, C = Invert(a_star, b_star, c_star)

    a, b, c, alpha, beta, gamma = Cell(A, B, C)

    # the lattices which Constrain knows about
    lattices = ['aP', 'mP', 'mC', 'oP', 'oC', 'oI', 'oF',
                'tP', 'tI', 'hR', 'hP', 'cP', 'cI', 'cF']

    lattice = 'tP'

    a, b, c, alpha, beta, gamma = Constrain(lattice, \
                                            [a, b, c, alpha, beta, gamma])

    cell = InvertCell(a, b, c, alpha, beta, gamma)

    B = BMatrix(cell[0], cell[1], cell[2], cell[3], cell[4], cell[5])

    # one argument in brackets prints the same on every Python version
    print(B)

    
