# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2018-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr

__all__ = ['rotation_matrix', 'rotate']

import numpy as np


def rotation_matrix(axis, angle):
    """
    Creates a rotation matrix as a numpy array. The convention is the NeXus one so that a positive rotation of +pi/2:
    - with axis='x', transforms +y into z
    - with axis='y', transforms +z into x
    - with axis='z', transforms +x into y
    :param axis: the rotation axis, either 'x', 'y' or 'z' (case-insensitive)
    :param angle: the rotation angle in radians
    :return: the rotation matrix, as a 3x3 numpy array of dtype float32
    :raises ValueError: if axis is not one of 'x', 'y' or 'z'
    """
    c, s = np.cos(angle), np.sin(angle)
    ax = axis.lower()
    if ax == 'x':
        rows = [[1, 0, 0], [0, c, -s], [0, s, c]]
    elif ax == 'y':
        rows = [[c, 0, s], [0, 1, 0], [-s, 0, c]]
    elif ax == 'z':
        rows = [[c, -s, 0], [s, c, 0], [0, 0, 1]]
    else:
        # Reject unknown axes explicitly rather than silently falling back to a rotation around z
        raise ValueError("rotation_matrix(): axis must be 'x', 'y' or 'z', got %r" % (axis,))
    return np.array(rows, dtype=np.float32)


def rotate(m, x, y, z):
    """
    Apply a 3x3 rotation matrix to coordinates supplied as separate x, y, z components.
    Output component i is computed as m[i, 0] * x + m[i, 1] * y + m[i, 2] * z, so x, y and z
    can be scalars or arrays.
    :param m: the rotation matrix, indexed as m[row, column]
    :param x: the x coordinate(s)
    :param y: the y coordinate(s)
    :param z: the z coordinate(s)
    :return: a tuple of x, y, z coordinates after rotation
    """
    # One output component per matrix row, in x, y, z order
    return tuple(m[row, 0] * x + m[row, 1] * y + m[row, 2] * z for row in range(3))
