from numpy import identity,array,zeros,dot,transpose,argwhere,ndarray
from numpy import linalg,sqrt,floor,cumsum,sum,arange,argmin,argmax,min,median
from scipy.sparse import issparse,csc_matrix
from numpy.fft import fft2

def pcgls(A,b,k,d):
    """
     PCGLS "Precond." CG applied implicitly to the normal equations.

        X, rho = pcgls(A,b,k,d)

        A = coefficient matrix; a non-sparse matrix is converted to a
            scipy.sparse CSC matrix (a notice is printed)
        b = right-hand side (measured data), expected to be a 1-D array
            of length A.shape[0]
        k = maximum iteration number (an integer), or a sequence of
            0-based iteration indices whose LAST entry is the total number
            of iterations, e.g. [0,4,9,14,20] means make 20 iterations and
            return the solutions after iterations 0, 4, 9, 14 and the last
            one, 19. The sequence is expected to be ascending; entries
            beyond the last iteration are never filled in (left as zeros).
        d = 1 selects the first-derivative operator; any other value
            selects the second derivative.

      Performs the 'preconditioned' conjugate gradient algorithm applied
      implicitly to the normal equations
         (A*L_p)'*(A*L_p)*x = (A*L_p)'*b ,
      where L_p is the inverse of L, starting from x = 0.  L is applied as
      a Kronecker product along all three grid axes of a square N-by-N
      matrix with zero boundary conditions:
        - d != 1: the second-derivative matrix L2 (tridiagonal 1,-2,1);
        - d == 1: L1 = -G', where G*G' = -L2 is the Cholesky factorization,
          so that L1'*L1 = -L2 (a square stand-in for the first derivative).

      It is assumed that the solution is on an N-by-N-by-N grid, i.e.
      A.shape[1] must equal N**3.

      Returns (X, rho):
        X   = n-by-len(k) array; column i holds the iterate requested by
              the i'th entry of k (for an integer k: iterates 0..k-1).
        rho = len(k)-by-1 array of the corresponding residual norms
              || b - A*X[:,i] ||_2, taken from the recursively updated
              residual vector.

      References: A. Bjorck, "Numerical Methods for Least Squares Problems",
      SIAM, Philadelphia, 1996.
      P. C. Hansen, "Rank-Deficient and Discrete Ill-Posed Problems.
      Numerical Aspects of Linear Inversion", SIAM, Philadelphia, 1997.

      Per Christian Hansen, IMM, DTU, Nov. 14, 2007.

      Implemented in python by Henning Osholm Sorensen, Riso-DTU, June, 2008
    """
    import numbers

    # Build the array of (0-based) iteration indices whose iterates are
    # stored. array() copies, so the caller's k is never modified.
    if isinstance(k, numbers.Integral):  # also covers numpy integer types
        assert k > 0, 'Number of steps k must be positive'
        kc = arange(k)
    else:
        kc = array(k)
        assert kc.ndim == 1 and kc.size > 0, \
            'k must be a positive integer or a non-empty sequence'
        # The last entry counts iterations; store the final (0-based) one.
        kc[-1] = kc[-1] - 1
        assert kc[-1] >= 0, 'Number of steps k must be positive'

    # Initialization
    if not issparse(A):
        print('Matrix not sparse - converting to sparse matrix')
        A = csc_matrix(A)
    At = A.transpose()  # loop invariant, reused every iteration
    n = A.shape[1]
    X = zeros((n, len(kc)))
    rho = zeros((len(kc), 1))

    # Prepare for computations with L_p.
    nn = int(round(n ** (1. / 3.)))
    assert (n == nn**3), 'Length of solution must be a cube %i %i ' % (n, nn**3)
    x_0 = zeros((n))

    # Square nn-by-nn second-derivative matrix: dropping the first and last
    # columns of the (nn)-by-(nn+2) operator enforces zero boundary values.
    L = get_l(nn + 2, 2)[:, 1:-1]
    if (d == 1):
        # trick to make a square matrix for d=1 (see docstring)
        L = transpose(-linalg.cholesky(-L))
    Li = linalg.inv(L)
    LiT = transpose(Li)

    # Prepare for CG iteration.
    x = x_0
    r = b - A.dot(x_0)
    s = At.dot(r)
    # q = L_p * L_p' * s, with L_p = kron(Li, Li, Li) applied implicitly.
    q1 = Lkronsolve(LiT, s)
    q = Lkronsolve(Li, q1)
    z = q
    dq = dot(s, q)

    # Iterate.
    for j in range(kc[-1] + 1):
        # Update x and r vectors.
        Az = A.dot(z)
        alpha = dq / dot(Az, Az)
        x = x + alpha * z
        r = r - alpha * Az
        s = At.dot(r)

        # Apply the preconditioner and update the z vector.
        q1 = Lkronsolve(LiT, s)
        q = Lkronsolve(Li, q1)
        dq2 = dot(s, q)
        beta = dq2 / dq
        dq = dq2
        z = q + beta * z

        if j in kc:
            col = argwhere(kc == j)[0, 0]
            X[:, col] = x
            rho[col] = linalg.norm(r)

    return (X, rho)

def get_l(n, d):
    """
    Compute a discrete derivative operator (port of P. C. Hansen's GET_L).

       L = get_l(n, d)

    Returns the discrete approximation L to the derivative operator of
    order d on a regular grid with n points, i.e. L is (n-d)-by-n.  Row j
    holds the d'th forward-difference stencil starting at column j, e.g.
    [-1, 1] for d=1, [1, -2, 1] for d=2 and [-1, 3, -3, 1] for d=3.
    For d=0, L is the n-by-n identity.

    Note: despite the MATLAB docstring, L is returned as a dense numpy
    array, not a sparse matrix.

    Raises AssertionError if d is negative and ValueError if n < d.

    Per Christian Hansen, IMM, 02/05/98.

    Implemented in python by Henning Osholm Sorensen, Riso-DTU, June, 2008
    """
    from numpy import diff

    assert d >= 0, 'Order d must be nonnegative'
    # Zero'th derivative.
    if d == 0:
        return identity(n)
    if n < d:
        raise ValueError('Grid size n=%i is smaller than the order d=%i' % (n, d))
    # Differencing the rows of the identity d times yields exactly the
    # d'th forward-difference stencil on every row.
    return diff(identity(n), n=d, axis=0)


def Lkronsolve(Li,x):
    """
    Apply Li along every axis of a vector viewed as an n-by-n-by-n cube.

    x (length n**3) is reshaped in row-major (C) order to an n-by-n-by-n
    array; the n-by-n matrix Li is applied along axes 0 and 1 slice by
    slice, then along axis 2.  The result equals kron(Li, kron(Li, Li)) @ x,
    computed without forming the Kronecker product, and is returned as a
    flat length-n**3 array.  In pcgls, Li is the inverse of the derivative
    matrix, which is why the routine is called a 'solve'.
    """
    total = x.shape[0]
    side = int(round(total ** (1. / 3.)))
    cube = x.reshape(side, side, side)
    LiT = transpose(Li)

    # Pass 1: for each last-axis slice, apply Li on both sides.
    partial = zeros((side, side, side))
    for c in range(side):
        partial[:, :, c] = dot(Li, dot(cube[:, :, c], LiT))

    # Pass 2: apply Li along the remaining (last) axis.
    result = zeros((side, side, side))
    for b in range(side):
        result[:, b, :] = dot(partial[:, b, :], LiT)

    return result.reshape(total)



def stoppingrule(A,b,x,uv_size,no_uv):
    """
    STOPPINGRULE  Stopping rule for iter. regul. based on residual statistics

     k_stop = stoppingrule(A,b,x,uv_size,no_uv)

       A = coefficient matrix (dense array or scipy.sparse matrix)
       b = right-hand side (measured data), expected to be a 1-D array of
           length A.shape[0]
       x = matrix of solution vectors stored as columns
       uv_size = size of uv chart (it is uv_size-times-uv_size)
       no_uv = number of uv charts; the residual is split into consecutive
               chunks of uv_size**2 entries, one per chart

       k_stop = estimate of the number of required iterations to
                fit the data such that the residuals are white noise.
                It is a 0-based column index into x: for each chart the
                column whose normalized cumulative periodogram is closest
                (2-norm) to a straight line is found, and the median over
                all charts is rounded half up (as MATLAB's round does).

       Per Christian Hansen, IMM, Januar 3, 2007.

    Implemented in python by Henning Osholm Sorensen, Riso-DTU, June, 2008
    """
    nit = x.shape[1]

    # Residual of every iterate, one column per iterate.
    R = zeros((A.shape[0], nit))
    sparse_A = issparse(A)
    for j in range(nit):
        Ax = A.dot(x[:, j]) if sparse_A else dot(A, x[:, j])
        R[:, j] = b - Ax

    chart_len = uv_size ** 2
    # Loop over all uv charts.
    II = zeros((no_uv))
    for k in range(no_uv):
        # Extract the residual for the k'th uv chart.
        Rk = R[k * chart_len:(k + 1) * chart_len, :]
        # Compute the normalized cumulative periodograms for each iteration.
        # Treat each residual as a 2D signal.
        cp = ncp2D(Rk)
        # Find the iteration number for which the noise is "most white",
        # i.e. whose periodogram is closest to the straight line z.
        nz = cp.shape[0] * 1.0
        z = arange(1, nz + 1) / nz
        ncp = zeros((nit))
        for j in range(nit):
            ncp[j] = linalg.norm(cp[:, j] - z)
        # Save the found it. number for the k'th chart.
        II[k] = argmin(ncp)

    # Median of the found it. numbers, rounded half up. Python 3's round()
    # uses banker's rounding (round(2.5) == 2), which would disagree with
    # the MATLAB reference implementation for an even number of charts.
    k_stop = int(floor(median(II) + 0.5))
    return k_stop



# INCLUDED FUNCTION BELOW THIS LINE ----------------------------------

def ncp2D(XX):
    """
    Normalized cumulative periodograms of 2D signals.

    Each column of XX (shape (n**2, p)) is treated as an n-by-n image
    stored in row-major (C) order.  For every column the power spectrum
    |fft2|**2 is restricted to the q-by-q block of lowest non-negative
    frequency indices, q = floor(n/2) + 1.  Its entries are ordered by
    increasing (i+1)**2 + (j+1)**2, with ties kept in row-major order.
    The DC term is then dropped, and the cumulative sum is normalized by
    its total.

    Returns cp with shape (q**2 - 1, p).  Each column is non-decreasing and
    ends at 1.  A column with no power outside DC yields NaNs (0/0).
    stoppingrule compares each column with the straight line
    (1..N)/N, N = q**2 - 1.
    """
    (nn, p) = XX.shape
    n = int(round(sqrt(nn)))
    q = int(floor(n / 2.) + 1)

    # Squared radial index of every entry of the q-by-q frequency block.
    idx = arange(1, q + 1)
    R = idx[:, None] ** 2 + idx[None, :] ** 2
    # Stable sort: tied radii ((i,j) vs (j,i)) keep row-major order on every
    # numpy version / CPU, so the periodograms are reproducible.
    perm = R.ravel().argsort(kind='mergesort')

    D = zeros((q * q, p))
    for k in range(p):
        Z = abs(fft2(XX[:, k].reshape(n, n))) ** 2
        D[:, k] = Z[0:q, 0:q].reshape(q * q)[perm]

    D = D[1:, :]  # Get rid of DC component (perm[0] is the (0,0) entry).
    return cumsum(D, axis=0) / D.sum(axis=0)
