"""Helper functions that ease writing tests."""

__revision__ = '$Id:$'

from narval.elements.test import Element, ErrorElement
from narval.plan import PlanElement

def get_elements(elements, klass=Element):
    """Lazily yield, in order, the items of *elements* that are instances
    of *klass* (``Element`` by default; anything ``isinstance`` accepts).
    """
    for candidate in elements:
        if not isinstance(candidate, klass):
            continue
        yield candidate
    
def get_errors(elements):
    """Return a lazy iterator over the ErrorElement instances found in
    *elements*, in their original order."""
    error_class = ErrorElement
    return get_elements(elements, error_class)
    
def get_plans(elements):
    """Return a lazy iterator over the PlanElement instances found in
    *elements*, in their original order."""
    plan_class = PlanElement
    return get_elements(elements, plan_class)

def get_plan(elements):
    """Return the first PlanElement found in *elements*.

    Raise AssertionError if *elements* contains no PlanElement.
    """
    # Loop-and-return instead of ``.next()``: on an empty iterator
    # ``.next()`` leaks StopIteration rather than the documented
    # AssertionError, and the ``.next()`` method does not exist on
    # Python 3 generators.
    for plan in get_plans(elements):
        return plan
    raise AssertionError('No PlanElement found')

def assert_raises(exc_class, callable_obj, *args, **kwargs):
    """Fail unless ``callable_obj(*args, **kwargs)`` raises *exc_class*.

    *exc_class* is anything an ``except`` clause accepts (an exception
    class or a tuple of classes). Return None when it is raised.

    Raise AssertionError:
    - 'No <name> raised' if the call returns normally;
    - 'Expected <name>, got <caught>' if the call raises some other
      ``Exception`` subclass.
    Exceptions outside the ``Exception`` hierarchy (e.g. KeyboardInterrupt)
    are not caught and propagate unchanged.
    """
    # Human-readable name of the expected exception; tuples and other
    # objects without __name__ fall back to their str().
    expected_name = getattr(exc_class, '__name__', str(exc_class))
    try:
        callable_obj(*args, **kwargs)
    except exc_class:
        return
    except Exception as exc:
        raise AssertionError('Expected %s, got %s'
                             % (expected_name, exc.__class__.__name__))
    raise AssertionError('No %s raised' % expected_name)

def assert_equals(obj1, obj2, err_msg=None):
    """Assert that ``obj1 == obj2``.

    Raise AssertionError if they compare unequal. The message is
    *err_msg*, or '<obj1> != <obj2>' when *err_msg* is empty or None.
    """
    # ``not (a == b)`` rather than ``a != b``: Python 2 does not derive
    # __ne__ from __eq__, so only __eq__ is relied upon here.
    if not (obj1 == obj2):
        # Format the default message only on failure, so str() of large or
        # misbehaving objects cannot slow down or break a passing check.
        raise AssertionError(err_msg or "%s != %s" % (obj1, obj2))

def assert_non_equals(obj1, obj2, err_msg=None):
    """Assert that *obj1* and *obj2* do not compare equal.

    Raise AssertionError if ``obj1 == obj2``. The message is *err_msg*,
    or '<obj1> == <obj2>' when *err_msg* is empty or None.
    """
    if obj1 == obj2:
        # Format the default message only on failure, so str() of large or
        # misbehaving objects cannot slow down or break a passing check.
        raise AssertionError(err_msg or "%s == %s" % (obj1, obj2))

def assert_identity(obj1, obj2, err_msg=None):
    """Assert that *obj1* and *obj2* are the same object (``is``).

    Raise AssertionError otherwise. The message is *err_msg*, or
    '<obj1> is not <obj2>' when *err_msg* is empty or None.
    """
    if obj1 is not obj2:
        # Format the default message only on failure, so str() of large or
        # misbehaving objects cannot slow down or break a passing check.
        raise AssertionError(err_msg or "%s is not %s" % (obj1, obj2))

def assert_element_state(plan, element_id, state, err_msg=None):
    """Assert that element *element_id* of *plan* is in *state*.

    The element is looked up as ``plan.elements[element_id]``, and its
    ``state`` attribute is compared with *state* using ``!=``. On
    mismatch, raise AssertionError with *err_msg*, or with
    '<actual> != <expected>' when *err_msg* is empty or None. An unknown
    *element_id* raises whatever the ``plan.elements[...]`` lookup raises.
    """
    actual = plan.elements[element_id].state
    if actual != state:
        # Build the default message only when the check actually fails.
        raise AssertionError(err_msg or '%s != %s' % (actual, state))

def assert_elements_state(plan, element_ids, state):
    """Assert that every element listed in *element_ids* is in *state*.

    Elements are fetched from ``plan.elements`` in the order given. The
    first one whose ``state`` differs (``!=``) from *state* triggers an
    AssertionError reading '<actual> != <expected> (<element id>)'.
    """
    lookup = plan.elements
    for element_id in element_ids:
        current = lookup[element_id].state
        if current != state:
            details = (current, state, element_id)
            raise AssertionError('%s != %s (%s)' % details)
        
def assert_has_instances(elements, klass, nb=1):
    """Assert that *elements* contains exactly *nb* instances of *klass*.

    The count is checked with assert_equals, which raises AssertionError
    on mismatch.
    """
    found = sum(1 for _ in get_elements(elements, klass))
    assert_equals(found, nb)
