# Module revision string; the "$Revision: ... $" form is presumably expanded
# by CVS/RCS keyword substitution — confirm against the repository history.
__revision__ = '$Revision: 1.5 $'

import sys
import MySQLdb
import string
import logging

class ISPyB:
    """Minimal data-access helper for an ISPyB MySQL database.

    Statements are assembled as SQL text.  Values are rendered as SQL
    literals by _formatFieldValue(), which looks up each column's type via
    ``DESCRIBE`` once per table and caches it at class level (the caches
    are shared by all instances).
    """
    # Leading part of a MySQL column type (as shown by DESCRIBE) -> how
    # values of that column are rendered: 'string' values are quoted and
    # escaped, every other kind is written as-is.  Matching is by prefix;
    # keys that are prefixes of one another ('date'/'datetime',
    # 'time'/'timestamp') map to the same kind, so lookup order is irrelevant.
    _typeDict={'int':'integer', 'varchar':'string', 'datetime':'string',
        'tinyint':'boolean', 'float':'double', 'timestamp':'string',
        'longtext':'string', 'enum':'string', 'date':'string',
        'double':'double', 'tinytext':'string', 'char':'string',
        'text':'string', 'mediumtext':'string', 'time':'string',
        'smallint':'integer', 'mediumint':'integer', 'bigint':'integer',
        'decimal':'double'}

    # Class-level caches shared by every instance:
    #   _fieldType[table] -> {column name: kind from _typeDict or 'unknown'}
    #   _attList[table]   -> (column names, column types, primary key name)
    _fieldType={}
    _attList={}

    # Backslash escapes applied by _escapeString(); the same character set
    # that MySQL's mysql_real_escape_string() handles.  Assumes the server
    # is not running with the NO_BACKSLASH_ESCAPES SQL mode.
    _SQL_ESCAPES={'\\':'\\\\', "'":"\\'", '"':'\\"', '\0':'\\0',
        '\n':'\\n', '\r':'\\r', '\x1a':'\\Z'}

    # Status matched against Container.containerStatus and
    # BLSample.blSampleStatus when selecting samples, and written into the
    # containers/samples that getSessionSamples() creates.
    CONTAINER_STATUS='Processing'

    def __init__(self,host,user,passwd,db):
        """Open the connection (via Connect) and the cursor used by every query."""
        self.db = self.Connect(host=host,user=user,passwd=passwd,db=db)
        self.cursor = self.db.cursor()

    # ------------------------------------------------------------------
    # SQL literal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _escapeString(value):
        """Return value as text with SQL-special characters backslash-escaped.

        The result is meant to be placed between single quotes in a
        statement.  Non-string values are converted with ``%s`` first.
        """
        text = "%s" % (value,)
        return ''.join([ISPyB._SQL_ESCAPES.get(c, c) for c in text])

    @staticmethod
    def _parseDescribeRows(rs):
        """Turn the rows of a DESCRIBE result into (names, types, pkey).

        Each row is read as (Field, Type, Null, Key, Default, Extra).  The
        primary key reported is the first 'PRI' column, overridden by a
        'PRI' column whose Extra is 'auto_increment'; it is None when the
        table has no 'PRI' column.
        """
        names = []
        types = []
        pkeyname = None
        for row in rs:
            names.append(row[0])
            types.append(row[1])
            if row[3] == 'PRI':
                if pkeyname is None or row[5] == 'auto_increment':
                    pkeyname = row[0]
        return (names, types, pkeyname)

    def _getAttList(self,tName):
        """Return (column names, column types, primary key) for tName, cached."""
        logging.getLogger().debug("ISPyB: _getAttList %s", tName)
        try:
            return ISPyB._attList[tName]
        except KeyError:
            pass
        att_list = self._parseDescribeRows(self.doQuery("DESCRIBE " + tName))
        ISPyB._attList[tName] = att_list
        return att_list

    def _getPrimaryKey(self,tName):
        """Return the primary key column name of tName (None if it has none)."""
        logging.getLogger().debug("ISPyB: _getPrimaryKey %s", tName)
        return self._getAttList(tName)[2]

    def _buildFieldTypes(self,tName):
        """Map every column of tName to its rendering kind from _typeDict.

        Columns whose type matches no _typeDict prefix are mapped to
        'unknown' (rendered unquoted) and reported at debug level.
        """
        names, types = self._getAttList(tName)[:2]
        field_dict = {}
        for f_name, f_type in zip(names, types):
            new_type = None
            for prefix, kind in ISPyB._typeDict.items():
                if f_type.startswith(prefix):
                    new_type = kind
                    break
            if new_type is None:
                logging.getLogger().debug("ISPyB: unknown type convertion for field '%s' in table '%s'", f_name, tName)
                new_type = "unknown"
            field_dict[f_name] = new_type
        return field_dict

    def _formatFieldValue(self,tName,field,value):
        """Render value as an SQL literal for column `field` of table tName.

        None becomes NULL; 'string' columns are single-quoted with their
        content escaped; anything else is written with ``%s``.

        Raises:
            KeyError: field is not a column of tName.
        """
        logging.getLogger().debug("ISPyB: _formatFieldValue %s %s %s", tName, field, value)
        field_types = ISPyB._fieldType.get(tName)
        if field_types is None or field not in field_types:
            field_types = self._buildFieldTypes(tName)
            ISPyB._fieldType[tName] = field_types
        field_type = field_types[field]
        if value is None:
            return "NULL"
        if field_type == "string":
            return "'%s'" % self._escapeString(value)
        return "%s" % (value,)

    def _formatCondition(self,tName,field,value):
        """Render an equality test on `field`; a None value becomes IS NULL
        (``field=NULL`` never matches in SQL)."""
        if value is None:
            return "%s IS NULL" % field
        return "%s=%s" % (field, self._formatFieldValue(tName, field, value))

    # ------------------------------------------------------------------
    # Query execution
    # ------------------------------------------------------------------

    def doQuery(self,querystr):
        """Execute querystr on the shared cursor and return fetchall() rows.

        Any exception raised by execute() is logged and re-raised unchanged.
        """
        logging.getLogger().debug("ISPyB: %s", querystr)
        try:
            self.cursor.execute(querystr)
        except Exception as msg:
            logging.getLogger().error("ISPyB: exception running sql statement (%s); type=%s",
                str(msg), type(msg))
            raise
        rows=self.cursor.fetchall()
        logging.getLogger().debug("ISPyB: after cursor.fetchall")
        return rows

    def _queryResultToDictList(self,table,queryresult):
        """Convert rows of ``SELECT * FROM table`` into {column: value} dicts.

        Only the first len(columns) values of each row are used, so rows
        from joins that start with `table`'s columns also work.
        """
        keylist = self._getAttList(table)[0]
        return [dict(zip(keylist, row)) for row in queryresult]

    @staticmethod
    def _splitJoinedRow(row, keylists):
        """Split one row of a ``SELECT * ... JOIN ...`` result into one dict
        per table, consuming columns left to right in keylists order.
        Columns beyond the listed tables are ignored."""
        parts = []
        start = 0
        for keys in keylists:
            parts.append(dict(zip(keys, row[start:start + len(keys)])))
            start += len(keys)
        return parts

    def _buildInsertStatement(self,table,keyvalues):
        """Return the INSERT statement for the column->value dict keyvalues."""
        fields = list(keyvalues.keys())
        values = [self._formatFieldValue(table, key, keyvalues[key]) for key in fields]
        return "INSERT into %s (%s) values (%s)" % (table, ",".join(fields), ",".join(values))

    def _doInsertQuery(self,table,keyvalues):
        """Insert one row under a table write lock and return insert_id().

        The statement is built before locking, because building it may run
        DESCRIBE on the table.  The lock is always released, even when the
        insert or the insert-id lookup fails.
        """
        stmt = self._buildInsertStatement(table, keyvalues)
        logging.getLogger().debug("ISPyB: LOCK TABLE %s write;", table)
        self.cursor.execute("LOCK TABLE %s write;" % table)
        try:
            self.doQuery(stmt)
            try:
                insert_id = self.db.insert_id()
            except AttributeError:
                # Connection wrapper without insert_id(): ask the raw handle.
                insert_id = self.db._db.insert_id()
        finally:
            logging.getLogger().debug("ISPyB: UNLOCK TABLES;")
            self.cursor.execute("UNLOCK TABLES;")
        return insert_id

    def _buildUpdateStatement(self,table,keyvalues,condvalues):
        """Return ``UPDATE table SET ... WHERE ...``; conditions are AND-ed."""
        assignments = ["%s=%s" % (key, self._formatFieldValue(table, key, keyvalues[key]))
                       for key in keyvalues]
        conditions = [self._formatCondition(table, key, condvalues[key]) for key in condvalues]
        return "UPDATE %s SET %s WHERE %s" % (table, ",".join(assignments), " AND ".join(conditions))

    def _doUpdateQuery(self,table,keyvalues,condvalues):
        """Update the rows of table matching every condvalues entry."""
        self.doQuery(self._buildUpdateStatement(table, keyvalues, condvalues))

    def Connect(self,host,user,passwd,db):
        """Open and return a MySQLdb connection; called by __init__."""
        return MySQLdb.Connect(host,user,passwd,db)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def insertRecord(self,table,dict):
        """Insert `dict` (column -> value) into table.

        Returns:
            (primary key column name, id reported by insert_id()).
        """
        return (self._getPrimaryKey(table),self._doInsertQuery(table,dict))

    def updateRecord(self,table,fields_dict,cond_dict):
        """Set fields_dict on the rows matching all of cond_dict; returns None."""
        return self._doUpdateQuery(table,fields_dict,cond_dict)

    def getRecord(self,table,key,value):
        """Return rows of table whose `key` column equals the integer value."""
        stmt = "SELECT * from %s where %s = %d" % (table,key,value)
        queryResult=self.doQuery(stmt)
        return self._queryResultToDictList(table,queryResult)

    def getProposals(self,code,number):
        """Return Proposal rows with the given Code and integer Number."""
        stmt = "SELECT * from Proposal where Code = '%s' and Number = %d" % (
            self._escapeString(code), number)
        queryResult=self.doQuery(stmt)
        return self._queryResultToDictList('Proposal',queryResult)

    def _getBLSampleDiffractionPlan(self,blsample,keylist):
        """Return the DiffractionPlan row referenced by blsample's
        diffractionPlanId as a dict, or {} when the id is missing/not
        positive or no such row exists."""
        plan_id = blsample.get('diffractionPlanId')
        if plan_id is None or not plan_id > 0:
            return {}
        rows = self.doQuery("SELECT * FROM `DiffractionPlan` WHERE "
            "DiffractionPlan.diffractionPlanId = %d" % plan_id)
        if not rows:
            return {}
        return dict(zip(keylist, rows[0]))

    def getLoadedSamples(self,proposal_key,sample_references):
        """Return joined sample information for each referenced sample code.

        For every sample_reference (a mapping with a 'code' entry) all
        BLSample rows of the proposal with that code whose container or
        sample status matches CONTAINER_STATUS are returned, each as a dict
        with keys 'blsample', 'crystal', 'protein',
        'diffractionplan_crystaltype' (the crystal's plan),
        'diffractionplan_blsample' (the sample's own plan, or {}) and
        'container'.
        """
        # Column lists in the order the tables appear in the SELECT * below.
        blsamplekeylist = self._getAttList('BLSample')[0]
        crystalkeylist = self._getAttList('Crystal')[0]
        proteinkeylist = self._getAttList('Protein')[0]
        diffractionplankeylist = self._getAttList('DiffractionPlan')[0]
        containerkeylist = self._getAttList('Container')[0]
        keylists = (blsamplekeylist, crystalkeylist, proteinkeylist,
                    diffractionplankeylist, containerkeylist)
        status = self._escapeString(ISPyB.CONTAINER_STATUS)

        responseList = []
        for sample_reference in sample_references:
            code = sample_reference['code'].strip()
            stmt = ("SELECT * FROM `BLSample` "
                "LEFT JOIN Crystal ON BLSample.CrystalId=Crystal.CrystalId "
                "LEFT JOIN Protein ON Crystal.ProteinId=Protein.ProteinId "
                "LEFT JOIN DiffractionPlan ON Crystal.diffractionPlanId=DiffractionPlan.diffractionPlanId "
                "LEFT JOIN Container ON BLSample.containerId=Container.containerId "
                "WHERE Protein.proposalId = %d AND "
                "BLSample.code='%s' AND "
                "(Container.containerStatus LIKE '%s' OR BLSample.blSampleStatus LIKE '%s')"
                % (proposal_key, self._escapeString(code), status, status))
            rows = self.doQuery(stmt)
            # Normally one row; more means duplicate entries for this code.
            if len(rows) > 1:
                logging.getLogger().debug("ISPyB: sample %s has multiple entries in proposal %d",
                    code, proposal_key)

            for sample in rows:
                blsample, crystal, protein, plan, container = self._splitJoinedRow(sample, keylists)
                responseList.append({
                    'diffractionplan_blsample':
                        self._getBLSampleDiffractionPlan(blsample, diffractionplankeylist),
                    'blsample': blsample,
                    'crystal': crystal,
                    'protein': protein,
                    'diffractionplan_crystaltype': plan,
                    'container': container,
                })
        return responseList

    def _getOrCreateContainer(self,proposal_key):
        """Return the id of a proposal container whose status is
        CONTAINER_STATUS, inserting such a container if none exists."""
        container_list = self.doQuery("SELECT Container.containerId FROM Container WHERE "
            "Container.proposalId=%d AND Container.containerStatus='%s'"
            % (proposal_key, self._escapeString(ISPyB.CONTAINER_STATUS)))
        if container_list:
            return container_list[0][0]
        temp_dict = {'proposalId':proposal_key, 'containerStatus':ISPyB.CONTAINER_STATUS}
        return self.insertRecord('Container',temp_dict)[1]

    def _ensureCrystalsHaveSamples(self,proposal_key):
        """Insert a BLSample (status CONTAINER_STATUS) for every crystal of the
        proposal that has none, placing it in a CONTAINER_STATUS container."""
        crystal_list = self.doQuery("SELECT Crystal.CrystalId,Crystal.Name FROM Crystal,Protein WHERE "
            "Crystal.ProteinId=Protein.ProteinId AND Protein.proposalId=%d" % proposal_key)
        for crystal in crystal_list:
            crystal_id = int(crystal[0])
            crystal_name = crystal[1]
            # crystal_id already belongs to the proposal, so filtering BLSample
            # by it is sufficient (no Crystal/Protein cross join needed).
            blsample_list = self.doQuery("SELECT BLSample.BLSampleId FROM BLSample "
                "WHERE BLSample.CrystalId=%d LIMIT 1" % crystal_id)
            if blsample_list:
                continue
            container_id = self._getOrCreateContainer(proposal_key)
            temp_dict = {'name':crystal_name, 'crystalId':crystal_id,
                'containerId':container_id, 'blSampleStatus':ISPyB.CONTAINER_STATUS}
            self.insertRecord('BLSample',temp_dict)

    def getSessionSamples(self,proposal_key):
        """Return the proposal's samples in CONTAINER_STATUS containers.

        First makes sure every crystal of the proposal has a BLSample (see
        _ensureCrystalsHaveSamples).  Each result is a dict with keys
        'blsample', 'crystal' and 'protein'.
        """
        self._ensureCrystalsHaveSamples(proposal_key)

        stmt = ("SELECT * FROM `BLSample` "
            "left join Crystal on BLSample.CrystalId=Crystal.CrystalId "
            "left join Protein on Crystal.ProteinId=Protein.ProteinId "
            "left join Container on BLSample.containerId=Container.containerId "
            "where Protein.proposalId = %d and "
            "Container.containerStatus like '%s'"
            % (proposal_key, self._escapeString(ISPyB.CONTAINER_STATUS)))
        result = self.doQuery(stmt)

        # Column lists in join order; the trailing Container columns are ignored.
        keylists = (self._getAttList('BLSample')[0],
                    self._getAttList('Crystal')[0],
                    self._getAttList('Protein')[0])
        responseList = []
        for sample in result:
            blsample, crystal, protein = self._splitJoinedRow(sample, keylists)
            responseList.append({'blsample':blsample, 'crystal':crystal, 'protein':protein})
        return responseList

    def getPersonsInSession(self,session_id,name=None,role=None):
        """Return Person rows linked to session_id, optionally filtered by
        Session_has_Person.role and/or Person.givenName."""
        stmt2=""
        if role is not None:
            stmt2=" AND Session_has_Person.role='%s'" % self._escapeString(role)
        stmt3=""
        if name is not None:
            stmt3=" AND Person.givenName='%s'" % self._escapeString(name)
        stmt = ("SELECT * FROM Person JOIN Session_has_Person WHERE "
            "Person.personId=Session_has_Person.personId AND "
            "Session_has_Person.sessionId=%d %s %s" % (session_id,stmt2,stmt3))
        queryResult=self.doQuery(stmt)
        return self._queryResultToDictList("Person",queryResult)

    def getPersonFromLab(self,person_name,lab_name,lab_country):
        """Return Person rows with the given givenName in the named laboratory."""
        stmt = ("SELECT * FROM Person JOIN Laboratory WHERE "
            "Person.givenName='%s' AND "
            "Laboratory.name='%s' AND "
            "Laboratory.country='%s' AND "
            "Person.laboratoryId=Laboratory.laboratoryId"
            % (self._escapeString(person_name), self._escapeString(lab_name),
               self._escapeString(lab_country)))
        queryResult=self.doQuery(stmt)
        return self._queryResultToDictList("Person",queryResult)

    def getRecordByFields(self,table,fields,values,order=None,asc=True):
        """Return rows of table where fields[i] equals values[i] for every i.

        Args:
            order: optional column name for ORDER BY.
            asc: when False, sort descending.
        With no fields, all rows are returned (no WHERE clause is emitted).
        """
        conditions = [self._formatCondition(table, f, values[i]) for i, f in enumerate(fields)]
        stmt = "SELECT * FROM %s" % table
        if conditions:
            stmt += " WHERE " + " AND ".join(conditions)

        if order is not None:
            stmt += " ORDER BY %s" % order
            if not asc:
                stmt += " DESC"

        queryResult=self.doQuery(stmt)
        return self._queryResultToDictList(table,queryResult)