# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2008-2015 : Univ. Joseph Fourier (Grenoble 1), CEA/INAC/SP2M
#   (c) 2013-2014 : Fondation Nanosciences, Grenoble
#   (c) 2016-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
#         Ondrej Mandula

from __future__ import division

import numpy as np


def calc_obj_shape(posy, posx, probe_shape, margin=16, multiple=4):
    """
    Compute the (ny, nx) size of the object array needed to contain every probe position.

    Along each axis the size is twice the largest absolute (ceiled) scan position plus one
    pixel, plus the probe size along that axis. The optional margin is then added, and the
    result is rounded up to the next multiple of `multiple`.

    :param posy: array of the y scan positions
    :param posx: array of the x scan positions
    :param probe_shape: shape of the probe, indexed as (ny, nx)
    :param margin: margin to extend the object area, in case the positions will change
                   (optimization). None disables it.
    :param multiple: the shape must be a multiple of that number. None disables the rounding.
    :return: a tuple (ny, nx) with the object shape
    """
    def _extent(positions, probe_size):
        # Twice the furthest (ceiled) excursion from the center, with one pixel of slack,
        # plus the room taken by the probe itself.
        half_span = (abs(np.ceil(positions)) + 1).max()
        return int(2 * half_span + probe_size)

    def _round_up(size):
        # Only adjust when the size is not already a multiple.
        remainder = size % multiple
        return size + (multiple - remainder) if remainder else size

    sizes = [_extent(posy, probe_shape[0]), _extent(posx, probe_shape[1])]

    if margin is not None:
        sizes = [n + margin for n in sizes]

    if multiple is not None:
        sizes = [_round_up(n) for n in sizes]

    ny, nx = sizes
    return ny, nx


def get_view_coord(obj_shape, probe_shape, dy, dx):
    """
    Get integer pixel coordinates of the corner of the part of the object illuminated by a probe,
    given the shift of the probe relative to the object center. Object, probe and shift
    correspond to 2D coordinates.

    The probe is first centered in the object (corner at (obj - probe) // 2 along each axis),
    then shifted by (dy, dx). The corner is the floor of that shifted position. If the probe
    would extend beyond the object along an axis, the corner is clamped so that the probe fits
    inside the object, and a single warning is printed.

    Args:
        obj_shape: the shape of the object, indexed as (ny, nx)
        probe_shape: the shape of the probe, indexed as (ny, nx)
        dy: the shift of the probe relative to the center of the object, along y (pixels)
        dx: the shift of the probe relative to the center of the object, along x (pixels)

    Returns:
        (cy, cx): integer coordinates of the top-left corner of the illuminated area, such that
        obj[cy:cy + probe_shape[0], cx:cx + probe_shape[1]] is the illuminated region.
    """
    max_cy = obj_shape[0] - probe_shape[0]
    max_cx = obj_shape[1] - probe_shape[1]
    # floor() rather than int(): int() truncates toward zero, which would hide a slightly
    # negative (i.e. outside) position such as -0.5.
    cy = int(np.floor(max_cy // 2 + dy))
    cx = int(np.floor(max_cx // 2 + dx))
    # Keep the unclamped values for the warning message.
    cy_raw, cx_raw = cy, cx

    # A corner exactly at 0 or at (obj - probe) still fits inside the object, so only
    # strictly-outside positions are clamped and reported.
    outside = False
    if cx < 0:
        cx = 0
        outside = True
    elif cx > max_cx:
        cx = max_cx
        outside = True
    if cy < 0:
        cy = 0
        outside = True
    elif cy > max_cy:
        cy = max_cy
        outside = True

    if outside:
        print('Getting outside of the object (dy=%d,dx=%d)(cy=%d,cx=%d). Consider increasing the object size.'
              % (dy, dx, cy_raw, cx_raw))
    return cy, cx
