# Copyright (c) 2004 LOGILAB S.A. (Paris, FRANCE).
# http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# Copyright (c) 2004 DoCoMo Euro-Labs GmbH (Munich, Germany).
# http://www.docomolab-euro.com/ -- mailto:tarlano@docomolab-euro.com
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
"""helper classes to ask a user for information required to fill information
according to an XML template

:version: $Revision:$  
:author: Logilab

:copyright:
  2004 LOGILAB S.A. (Paris, FRANCE)
  
  2004 DoCoMo Euro-Labs GmbH (Munich, Germany)
  
:contact:
  http://www.logilab.fr/ -- mailto:contact@logilab.fr
  
  http://www.docomolab-euro.com/ -- mailto:tarlano@docomolab-euro.com
"""

__revision__ = '$Id$'

from cStringIO import StringIO
from xml.sax import ContentHandler, make_parser
from xml.sax.saxutils import escape

def dump_xml_node(node):
    """return a new XMLNode duplicating the qname, attributes, content and
    namespace prefixes of the given node

    Only the attributes dictionary is copied: the prefixes mapping is handed
    over as is to the new node, and the children of ``node`` are not
    duplicated.
    """
    qname, content = node.qname, node.content
    attributes = node.attrs.copy()
    return XMLNode(qname, attributes, content, node.prefixes)

def xml_attributes(dictionary):
    """given a dictionary return it formatted as a XML attributes string

    Attributes are sorted by name so the output is deterministic. Values are
    escaped so that '&', '<', '>' and '"' (e.g. coming from user answers)
    cannot break the generated markup.

    :param dictionary: mapping of attribute qnames to values
    :return: '' for an empty (or None) mapping, else the attributes preceded
      by a single space, e.g. ' a="1" b="2"'
    """
    if not dictionary:
        return ''
    result = []
    # sorted() instead of keys().sort(): works whatever keys() returns
    for key in sorted(dictionary):
        value = escape('%s' % (dictionary[key],), {'"': '&quot;'})
        result.append('%s="%s"' % (key, value))
    return ' %s' % ' '.join(result)
    
def xml_prefixes(dictionary):
    """given a dictionary of prefix mapping return it formatted as a XML
    attributes string for xmlns declaration

    Prefixes are sorted so the output is deterministic. Namespace values are
    escaped so that '&', '<', '>' and '"' cannot break the generated markup.

    :param dictionary: mapping of namespace prefixes to namespace URIs
    :return: '' for an empty (or None) mapping, else the declarations
      preceded by a single space, e.g. ' xmlns:a="urn:a" xmlns:b="urn:b"'
    """
    if not dictionary:
        return ''
    result = []
    # sorted() instead of keys().sort(): works whatever keys() returns
    for prefix in sorted(dictionary):
        value = escape('%s' % (dictionary[prefix],), {'"': '&quot;'})
        result.append('xmlns:%s="%s"' % (prefix, value))
    return ' %s' % ' '.join(result)
    
class EnterChildren(Exception):
    """raised when the current template node is filled and its children
    should be visited next"""


class DumpNode(Exception):
    """raised when the current data node should be duplicated to create a
    new sibling node"""


class NextStep(Exception):
    """raised when the walk should go on with the next sibling node"""


class NoMoreQuestion(Exception):
    """raised when the whole template has been walked"""


class XMLNode(list):
    """a base xml node, handling a qualified name, qualified attributes
    and any children or a textual content

    Children are the list items; ``attrs`` maps attribute qnames to values
    and ``prefixes`` maps namespace prefixes to their URI.
    """

    def __init__(self, qname, attrs, content='', prefixes=None):
        self.qname = qname
        self.attrs = attrs
        self.content = content
        self.prefixes = prefixes or {}

    def __repr__(self):
        return '<%s %s with %s children>' % (self.qname, self.attrs, len(self))

    def as_xml(self, encoding, indent=''):
        """return the XML serialization of this node and its descendants

        :param encoding: passed down to the children, not used otherwise
        :param indent: leading whitespace of this node; children are
          indented with two more spaces

        A node with a textual content is written on a single line and its
        children, if any, are not written. The content is escaped since it
        holds plain text (as reported by SAX or answered by the user).
        """
        attrstr = xml_attributes(self.attrs) + xml_prefixes(self.prefixes)
        if self.content:
            return '%s<%s%s>%s</%s>' % (indent, self.qname, attrstr,
                                        escape('%s' % (self.content,)),
                                        self.qname)
        if not self:
            return '%s<%s%s/>' % (indent, self.qname, attrstr)
        child_indent = indent + '  '
        res = ['%s<%s%s>' % (indent, self.qname, attrstr)]
        for child in self:
            res.append(child.as_xml(encoding, child_indent))
        res.append('%s</%s>' % (indent, self.qname))
        return '\n'.join(res)


class XMLTemplateNode(XMLNode):
    """a special xml node to be used in xml templates tree, not xml data
    trees

    A template node drives the questions asked to fill the data node
    generated from it, through a small state machine held in ``_in_step``:

    - 0: not started (a repeatable node asks whether to add a new one)
    - 1: asking for the '%s' placeholders of the data node content, then of
      its attributes
    - 2: placeholders filled, the children must be visited
    - 3: children visited (a repeatable node asks whether to add another)
    """

    def __init__(self, qname, attrs, content='', prefixes=None):
        XMLNode.__init__(self, qname, attrs, content, prefixes)
        # '' for a single element, '+' or '*' for a repeatable one
        self.cardinality = ''
        self._in_step = 0

    def set_cardinality(self, card):
        """set the node cardinality: '' (single) or '+' / '*' (repeatable)"""
        assert card in ('', '+', '*')
        self.cardinality = card

    def question(self, path, node):
        """get the current question for this node

        :param path: the current context path (sequence of qnames)
        :param node: the currently filling data node
        :return: the question to ask to the user

        :raise EnterChildren: if the given node is filled and its children
          should be visited
        :raise NextStep: if we should go to the next sibling node
        """
        if self._in_step in (0, 3):
            if self.cardinality in ('+', '*'):
                if len(path) > 1:
                    return 'add a new %s to %s ?' % (path[-1], '/'.join(path[:-1]))
                return 'add a new %s ?' % path[-1]
            if self._in_step == 3:
                # single node done: reset for a later visit and move on
                self._in_step = 0
                raise NextStep()
            self._in_step = 1
        if self._in_step == 1:
            if '%s' in node.content:
                return '%s%s%s :' % (self.qname,
                                     path[:-1] and ' for ' or '',
                                     '/'.join(path[:-1]))
            for attr, value in node.attrs.items():
                if '%s' in value:
                    return '%s%s%s :' % (attr,
                                            path and ' for ' or '',
                                            '/'.join(path))
            self._in_step = 2
        if self._in_step == 2:
            self._in_step = 3
            raise EnterChildren()
        raise RuntimeError('should not be there !')

    def answer(self, answer, node):
        """get the answer for the latest question

        :param answer: the answer given by the user
        :param node: the currently filling data node

        :raise DumpNode:
          if the current node should be duplicated to create a new
          sibling node
        :raise NextStep: if we should go to the next sibling node
        """
        in_step = self._in_step
        if in_step in (0, 3):
            if answer in ('y', 'yes', 'oui', 'o'):
                self._in_step = 1
                if in_step == 3:
                    raise DumpNode()
            else:
                # declined: reset the state so that the node starts over from
                # "add a new ... ?" the next time it is visited (e.g. after a
                # repeatable ancestor has been duplicated)
                self._in_step = 0
                raise NextStep()
        elif in_step == 1:
            if '%s' in node.content:
                node.content %= answer
            else:
                # fill the first attribute placeholder, using the same
                # "'%s' in value" test as question() so that values such as
                # "id-%s" get filled instead of being skipped
                for attr, value in node.attrs.items():
                    if '%s' in value:
                        node.attrs[attr] = value.replace('%s', answer, 1)
                        break
                else:
                    self._in_step = 2

    def __eq__(self, other):
        # identity comparison: template nodes are searched with list.index()
        # (e.g. in Generator.next_step) and distinct nodes may look equal
        return id(other) == id(self)

    def __ne__(self, other):
        # list.__ne__ compares contents and is not derived from __eq__
        return not self.__eq__(other)



class Generator:
    """stateful class to generate xml data from a xml template

    The template is walked depth first. ``_path`` is the stack of template
    nodes from the root down to the current one and ``_actual`` the parallel
    stack of data nodes being filled in the result document.
    """

    def __init__(self, template):
        # True while moving to another template node, reset each time a
        # question is handed out
        self._poping = False
        self._result_doc = dump_xml_node(template)
        self._path = [template]
        self._actual = [self._result_doc]

    def path(self):
        """return the current qnames path in the template document,
        skipping the root
        """
        return [node.qname for node in self._path[1:]]

    def next_question(self):
        """return the next question

        :raise NoMoreQuestion: when the whole template has been walked
        """
        current = self._path[-1]
        try:
            question = current.question(self.path(), self._actual[-1])
            self._poping = False
            return question
        except EnterChildren:
            # the current node is filled: descend into its first child. This
            # must not depend on ``_poping``: a sibling reached through
            # next_step() that has no placeholder of its own raises
            # EnterChildren before any question resets the flag, and its
            # children would otherwise be silently skipped
            if len(current):
                self._path.append(current[0])
                self._generate_node(current[0])
        except NextStep:
            self.next_step()
        return self.next_question()

    def push_answer(self, answer):
        """push the latest answer

        :param answer: the answer given by the user to the latest question
        """
        current = self._path[-1]
        data = self._actual[-1]
        # an answer given while the template node has not started yet replies
        # to "add a new ... ?" for a freshly generated data node
        offered = getattr(current, '_in_step', None) == 0
        try:
            current.answer(answer, data)
        except DumpNode:
            # the current data node is complete: start a new sibling
            self._actual.pop()
            self._generate_node(current)
        except NextStep:
            if offered and len(self._actual) > 1:
                # the user declined the offered node: drop the unfilled data
                # node _generate_node() appended to its parent
                parent = self._actual[-2]
                if parent and parent[-1] is data:
                    del parent[-1]
            self.next_step()

    def next_step(self):
        """go to the next node in the template, while syncing the data tree

        :raise NoMoreQuestion: when the root node has been left
        """
        self._poping = True
        self._actual.pop()
        current = self._path.pop()
        try:
            head = self._path[-1]
        except IndexError:
            raise NoMoreQuestion
        try:
            # index() relies on template nodes comparing by identity
            current = head[head.index(current) + 1]
        except IndexError:
            # no next sibling: next_question() resumes on ``head``
            pass
        else:
            self._path.append(current)
            self._generate_node(current)

    def as_xml(self, encoding='UTF-8'):
        """return the xml representation of the resulting data"""
        return '''<?xml version="1.0" encoding="%s"?>
%s''' % (encoding, self._result_doc.as_xml(encoding).strip())

    def _generate_node(self, template_node):
        """generate a data node from the given template node and update
        the resulting data structure
        """
        generated = dump_xml_node(template_node)
        self._actual[-1].append(generated)
        self._actual.append(generated)
        
class XMLTemplateSAXHandler(ContentHandler):
    """SAX content handler building a tree of XMLTemplateNode from a
    template document

    Only the first occurrence of an element among its siblings becomes a
    template node. Any later sibling with the same qname marks that node as
    repeatable (cardinality '*') and is skipped with all its descendants.
    """

    def __init__(self):
        ContentHandler.__init__(self)
        # root XMLTemplateNode, set by the first startElement call
        self.root = None
        # stack of the template nodes currently open
        self._path = []
        # for each open template node, its children indexed by qname
        self._all_path = []
        # one flag per open element: True if it produced a template node
        # (which endElement must pop), False if it is skipped
        self._pop = []

    def _push(self, node):
        """make ``node`` the innermost open template node"""
        self._path.append(node)
        self._all_path.append({})
        self._pop.append(True)

    def startElement(self, qname, attrs):
        if self.root is None:
            attrs, prefixes = ns_split(attrs)
            self.root = XMLTemplateNode(qname, attrs, prefixes=prefixes)
            self._push(self.root)
        elif not self._pop[-1]:
            # inside a skipped element: skip its descendants as well
            self._pop.append(False)
        elif qname in self._all_path[-1]:
            # repeated sibling: flag the first occurrence as repeatable
            self._all_path[-1][qname].set_cardinality('*')
            self._pop.append(False)
        else:
            attrs, prefixes = ns_split(attrs)
            node = XMLTemplateNode(qname, attrs, prefixes=prefixes)
            self._path[-1].append(node)
            self._all_path[-1][qname] = node
            self._push(node)

    def endElement(self, qname):
        if self._pop.pop():
            self._path.pop()
            self._all_path.pop()

    def characters(self, ch):
        # whitespace-only chunks are ignored; text of skipped elements too
        if self._pop and self._pop[-1] and ch.strip():
            self._path[-1].content += ch


class XMLTemplateReader:
    """read XML template documents into XMLTemplateNode trees"""

    def from_string(self, string):
        """parse the given XML template string and return its root node"""
        stream = StringIO(string)
        return self.from_stream(stream)

    def from_stream(self, stream):
        """parse the XML template read from the given stream (anything the
        SAX parser accepts) and return its root template node
        """
        handler = XMLTemplateSAXHandler()
        sax_parser = make_parser()
        sax_parser.setContentHandler(handler)
        sax_parser.parse(stream)
        return handler.root

def ns_split(attrs):
    """separate namespace prefix declarations from regular attributes

    :param attrs: mapping of attribute qnames to values (e.g. SAX attributes)
    :return: a ``(attributes, prefixes)`` tuple of dictionaries: ``prefixes``
      maps each prefix declared with ``xmlns:<prefix>`` to its value, and
      ``attributes`` holds every other attribute
    """
    marker = 'xmlns:'
    regular = {}
    namespaces = {}
    for name, value in attrs.items():
        if not name.startswith(marker):
            regular[name] = value
            continue
        namespaces[name[len(marker):]] = value
    return regular, namespaces
