import BaseHTTPServer
import os, string
import logging
import XSD
import ISPyB
import pprint
from XML_utils import *

###
### Processes the HTTP requests, and passes them to the request handler
###
class DbHTTPServer:
    """Serve database requests over HTTP until shutdown() is called.

    The constructor does not return until the request loop stops. It opens
    the ISPyB database accessor and injects it, the host name and the
    shutdown callback into DbServerRequestHandler. It then binds the
    listening socket and handles requests one at a time.
    """

    def __init__(self, hostname, port, dbhost, dbuser, dbpass, db, **args):
        """Start the server and block while handling requests.

        hostname, port             -- address the HTTP server listens on
        dbhost, dbuser, dbpass, db -- passed to ISPyB.ISPyB
        args                       -- accepted but not used
        """
        self.hostname = hostname
        self.port = port
        self.is_shutdown = False

        # BaseHTTPServer creates a new handler instance per request, so
        # shared state is attached to the handler class itself.
        handler_class = DbServerRequestHandler
        handler_class.dbclass = ISPyB.ISPyB(dbhost, dbuser, dbpass, db)
        handler_class.hostname = hostname
        handler_class.shutdown = self.shutdown

        server_address = (self.hostname, self.port)
        try:
            self.httpd = BaseHTTPServer.HTTPServer(server_address, handler_class)
        except Exception:
            logging.getLogger().exception("Exception running DB HTTP Server")
            raise

        print("DB HTTP Server running on %s:%d" % (hostname, port))
        logging.getLogger().debug("DB HTTP Server running on %s:%d", hostname, port)
        try:
            while not self.is_shutdown:
                self.httpd.handle_request()
        finally:
            # Release the listening socket whether the loop ended through
            # shutdown() or through an exception.
            self.httpd.server_close()

    def shutdown(self):
        """Ask the request loop to stop after the current request."""
        logging.getLogger().warning("DB HTTP Server shutting down...")
        self.is_shutdown = True

###
### Request handler
###
class DbServerRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
    # Handles the request message, executes the requested method,
    # and returns the xml response
    def do_POST(self):
        db_status = XSD.Dbstatus()
        xml_response = ""

        try:
            content_length_string = self.headers.getheader("Content-length")
            if content_length_string is None:
                logging.getLogger().error("DBServer: Received request but content length is None")
                raise "BadRequestError"

            xml_message_length = string.atoi(content_length_string)
            xml_message = self.rfile.read(xml_message_length)
            print xml_message

            # Check the requested method
            if self.path == "/proposal_request":
                logging.getLogger().debug("DBServer: Proposal_request=%r" % xml_message)
                proposal_request = XSD.Proposal()
                proposal_request.unmarshal(xml_message)
                xsd_response = self.db_get_proposal_request(proposal_request)

            elif self.path == "/loaded_samples_request":
                #logging.getLogger().debug("DBServer: loaded_samples_request=%r" % xml_message)
                request = XSD.Loaded_samples_request()
                request.unmarshal(xml_message)
                xsd_response = self.db_get_loaded_samples_request(request)
                
            elif self.path == "/store_datacollection_request":
                logging.getLogger().debug("DBServer: store_datacollection_request=%r" % xml_message)
                request = XSD.Datacollection()
                request.unmarshal(xml_message)
                xsd_response = self.db_store_collect_request(request)

            elif self.path == "/store_object_request":
                logging.getLogger().debug("DBServer: store_object_request=%r" % xml_message)
                objectdict = XML_utils.XMLStringToDict(xml_message)
                #objectdict = request.unmarshal(xml_message)
                objectxsdtype = objectdict.keys()[0]
                objecttype = objectdict.keys()[0]
                exec("request = XSD.%s()"%objectxsdtype)
                request.unmarshal(xml_message)
                xsd_response = self.db_store_object(objecttype,objectdict)

            elif self.path == "/retrieve_object_request":
                logging.getLogger().debug("DBServer: retrieve_object_request=%r" % xml_message)
                objectdict = XML_utils.XMLStringToDict(xml_message)
                #objectdict = request.unmarshal(xml_message)
                objectxsdtype = objectdict.keys()[0]
                objecttype = objectdict.items()[0][1].keys()[0]
                exec("request = XSD.%s()"%objectxsdtype)
                request.unmarshal(xml_message)
                xsd_response = self.db_get_object(objecttype,objectdict)

            else:
                xsd_response = XSD.Dbstatus()
                xsd_response.setCode('error')
                xsd_response.setMessage('Internal server error')
                raise "InternalServerError"

            xml_response=xsd_response.marshal()
            logging.getLogger().debug("DBServer: xml_response=%r" % xml_response)
            self.wfile.write("HTTP/1.1 200 OK\n")
            
        except "InternalServerError":
            db_status.setCode("error")
            db_status.setMessage("Internal server error")
            self.wfile.write("HTTP/1.1 500 Internal Server Error\n")

        except "BadRequestError":
            db_status.setCode("error")
            db_status.setMessage("Bad request error")
            self.wfile.write("HTTP/1.1 400 Bad Request\n")

        if self.path != "/shutdown":
            server_name = self.hostname
            self.wfile.write("Host: %s\n"%(server_name))
            self.wfile.write("Content-type: text/xml\n")
            self.wfile.write("Content-length: %d\n\n"%(len(xml_response)))
            self.wfile.write(xml_response)
        else:
            self.shutdown()

    # Description    : Stores a data collection
    # Argument type  : XSD.Datacollection
    # Tables modified: DataCollection
    # Return type    : XSD.Dbstatus
    def db_store_collect_request(self,request):
        inputParDict = request.toDict()

        # Create a new entry in the DataCollection table        
        status = XSD.Dbstatus()
        try:
            result = DbServerRequestHandler.dbclass.insertRecord('DataCollection',inputParDict)
        except Exception,ex:
            status.setCode('error')
            msg="DbHTTPServer: %s" % str(ex)
            status.setMessage(msg)
            logging.getLogger().error(msg)
        else:
            status.setCode('ok')
            status.setMessage('New entry in DataCollection')
            status.setDatacollectionid(int(result))

        return status

    def db_store_object(self,objecttype,requestdict):
        pprint.pprint(requestdict)
        # Create a new entry in the DataCollection table        
        status = XSD.Dbstatus()
        try:
            result = DbServerRequestHandler.dbclass.insertRecord(objecttype,requestdict[objecttype])
        except Exception,ex:
            status.setCode('error')
            msg=DbServerRequestHandler.dbclass.ermsg
            status.setMessage(msg)
            logging.getLogger().error(msg)
        else:
            status.setCode('ok')
            status.setMessage('New entry inserted into %s, %s' % (objecttype,result))
            #print dir(status)
            #print result
            exec("status.set%s(int(%d))"%(self.capitalizeFirstLetter( result[0] ),int(result[1])))

        return status

    def db_get_object(self,objecttype,requestdict):
        pprint.pprint(requestdict)
        status = XSD.Dbstatus()
        # remove the _object to get the real database object name
        dbobjecttype = objecttype.split('_object')[0]
        # get the name of the primary key (it should be the only element given as a reference
        object_id = requestdict['dbobject'].items()[0][1].keys()[0]
        try:
            result = DbServerRequestHandler.dbclass.getRecord(dbobjecttype,object_id,int(requestdict['dbobject'][objecttype][object_id]))
        except Exception,ex:
            print "************* retrieve object error", ex
            status.setCode('error')
            msg=DbServerRequestHandler.dbclass.ermsg
            status.setMessage(msg)
            logging.getLogger().error(msg)
        else:
            status.setCode('ok')
            status.setMessage('Got results from query.')
        resultObject = XSD.Dbobject()
        for obj in result:
            exec('object = XSD.%s()' % dbobjecttype)
            for element in obj.keys():
                exec('object.set%s("%s")' % (element,obj[element]))
            exec('resultObject.add%s(object)' % objecttype)

        return resultObject

    def db_get_proposal_request(self,request):
    	dict = DbServerRequestHandler.dbclass.getProposals(request.getCode(),request.getNumber())
        proposal_response = XSD.Proposal_response()
        status = XSD.Dbstatus()
        if not dict is None:
            proposal = XSD.Proposal()
            #for el in dict.keys():
            #    exec("proposal.set%s(dict['%s'])"%(el.capitalize(),el))
            dict = dict[0]
            #print dict
            proposal.fromDict(dict)        

            person_dict = DbServerRequestHandler.dbclass.getRecord('Person','personId',dict['personId'])[0]
            person = XSD.Person()
            person.fromDict(person_dict)

            lab_dict = DbServerRequestHandler.dbclass.getRecord('Laboratory','laboratoryId',person_dict['laboratoryId'])[0]
            lab = XSD.Laboratory()
            lab.fromDict(lab_dict)

            session_dict_list = DbServerRequestHandler.dbclass.getRecord('Session', 'proposalId', dict["proposalId"])
            for session_dict in session_dict_list:
                session = XSD.Session()
                session.fromDict(session_dict)
                proposal_response.addSession(session)

            #    #for el in session_dict.keys():
            #    #    exec("session.set%s(session_dict['%s'])"%(el.capitalize(),el))
            #    #proposal_response.addSession(session)
            #
            #    persons_dict_list = DbServerRequestHandler.dbclass.getRecord('Session_has_Person', 'sessionId', session_dict["sessionId"])
            #    session_resp=XSD.Session()
            #    session_resp.setSession(session)
            #    if len(persons_dict_list):
            #        for p in persons_dict_list:
            #            if p['role']=="Local Contact":
            #                contact_dict=DbServerRequestHandler.dbclass.getRecord('Person','personId',p['personId'])[0]
            #                contact=XSD.Person()
            #                contact.fromDict(contact_dict)
            #                session_resp.setLocalcontact(contact)
            #                break


            status.setCode('ok')
            status.setMessage('testmessage')
            proposal_response.setProposal(proposal)
            #proposal_response.setPerson(person)
            #proposal_response.setLaboratory(lab)
        else:
            status.setCode('error')
            status.setMessage('No proposals in database available for %s%d' % (request.getCode(),request.getNumber()))
        proposal_response.setStatus(status)
        return proposal_response

    def capitalizeFirstLetter( self, _ostr ):
        return (_ostr[0].capitalize() + _ostr[1:])


    def db_get_loaded_samples_request(self,request):
        iProposalId              = request.getProposal().getProposalId()
        listSampleReferences     = request.getSample_reference()
        functionToDict           = lambda x: x.toDict()
        listDictSampleReferences = map( functionToDict, listSampleReferences )
        sampleList = DbServerRequestHandler.dbclass.getLoadedSamples(
            iProposalId, listDictSampleReferences )
        response = XSD.Loaded_samples_response()
        status = XSD.Dbstatus()
        status.setCode( "ok" )
        response.setStatus( status )
        for sample in sampleList:
            #pprint.pprint( sample )
            loaded_sample = XSD.Loaded_sample()

            protein = XSD.Protein()
            crystal = XSD.Crystal()
            blsample = XSD.BLSample()
            for el in sample['protein'].keys():
                exec("protein.set%s(sample['protein']['%s'])"   % ( self.capitalizeFirstLetter( el ), el ) )
            for el in sample['crystal'].keys():
                exec("crystal.set%s(sample['crystal']['%s'])"   % ( self.capitalizeFirstLetter( el ), el ) )
            for el in sample['blsample'].keys():
                exec("blsample.set%s(sample['blsample']['%s'])" % ( self.capitalizeFirstLetter( el ), el ) )
            
            """ Below do the translation for each element in the table...
            ISPyB needs changing, easier to create 3 subelement dictionaries
            for easier conversion here."""
            loaded_sample.setProtein(protein)
            loaded_sample.setCrystal(crystal)
            loaded_sample.setBLSample(blsample)
            sample_reference = XSD.Sample_reference()
            sample_reference.setCode(sample['blsample']['code'])
            sample_reference.setSample_location(sample['blsample']['location'])
            sample_reference.setContainer_reference(sample['container']['sampleChangerLocation'])
            loaded_sample.setSample_reference(sample_reference)
            response.addLoaded_sample(loaded_sample)

        return response


if __name__ == '__main__':
    import sys

    if len(sys.argv) < 8:
        sys.stderr.write("Server requires the following arguments: 'Database server host machine', 'Database server port number','Pxweb database host', 'Pxweb database account name', 'Pxweb database account pwd', 'Database name', 'logfile'\n")
        sys.exit(1)

    # argv[7] is the 'logfile' argument from the usage message. Nothing else
    # in this script configures logging, so send the log records there.
    logging.basicConfig(filename=sys.argv[7],
                        level=logging.DEBUG,
                        format="%(asctime)s %(levelname)s %(message)s")

    # The constructor serves requests until a /shutdown request arrives.
    dbhttpserver = DbHTTPServer(sys.argv[1], int(sys.argv[2]), sys.argv[3],
                                sys.argv[4], sys.argv[5], sys.argv[6])

