# -*- coding: utf-8 -*-
"""
Created on Mon Feb 23 15:51:57 2015

@author: chris
"""

#==============================================================================
# Build sparse finite-difference (FD) approximations D1, D2 to the first and
# second derivatives on a (not necessarily equally spaced) grid of points on
# the line.  Calling syntax:
#     D1, D2 = MakeD(xx)
def MakeD(xx):
    """Build sparse finite-difference first- and second-derivative operators.

    Every row uses a three-point stencil made of the point itself and two
    other grid points:

    - interior row i uses points (i-1, i, i+1);
    - row 0 uses points (0, 1, 2), a one-sided stencil;
    - row N-1 uses points (N-3, N-2, N-1), a one-sided stencil.

    The weights are obtained by matching a second-order Taylor expansion.
    Both operators are therefore exact for polynomials of degree <= 2, on
    any (non-uniform) grid.

    Parameters
    ----------
    xx : array_like, shape (N,)
        Grid points; need not be equally spaced.  The three points of each
        stencil must be distinct, otherwise the weights divide by zero.

    Returns
    -------
    D1, D2 : scipy.sparse.csr_matrix, shape (N, N)
        ``D1 * f`` and ``D2 * f`` approximate f' and f'' at the grid points.

    Raises
    ------
    ValueError
        If fewer than three grid points are supplied (no 3-point stencil).
    """
    import numpy as np

    # Float conversion keeps the 1.0/... below a true division even for
    # integer grids, and lets plain sequences be passed in.
    xx = np.asarray(xx, dtype=float)
    Nxx = xx.size
    if Nxx < 3:
        raise ValueError("MakeD needs at least 3 grid points, got %d" % Nxx)

    ii = np.arange(Nxx)
    # "Plus" and "minus" neighbour of every point.  The end rows borrow an
    # interior point (ii[2] for the first row, ii[-3] for the last) to build
    # their one-sided stencils.
    jp = np.r_[ii[1:], ii[-3]]
    jm = np.r_[ii[2], ii[:-1]]
    hp = xx[jp] - xx          # signed offset to the plus neighbour
    hm = xx[jm] - xx          # signed offset to the minus neighbour

    rows = np.r_[ii, ii, ii]
    cols = np.r_[jm, ii, jp]

    # First derivative weights (am, b, ap) solve
    #   ap*hp + am*hm = 1,  ap*hp**2 + am*hm**2 = 0,  am + b + ap = 0.
    lam = 1.0 / (hp * hm * (hm - hp))
    ap = lam * hm * hm
    am = -lam * hp * hp
    b = lam * (hp * hp - hm * hm)
    D1 = sparse.coo_matrix((np.r_[am, b, ap], (rows, cols)), shape=(Nxx, Nxx))

    # Second derivative weights solve
    #   ap*hp + am*hm = 0,  ap*hp**2 + am*hm**2 = 2,  am + b + ap = 0.
    lam2 = 2.0 / (hp * hm * (hp - hm))
    ap = lam2 * hm
    am = -lam2 * hp
    b = -(ap + am)
    D2 = sparse.coo_matrix((np.r_[am, b, ap], (rows, cols)), shape=(Nxx, Nxx))

    return D1.tocsr(), D2.tocsr()
    
###############################################################################

def U(x,R,eps=None):
    """CRRA utility with a quadratic extension below a small threshold.

    For ``x >= eps`` this returns ``x**(1-R)/(1-R)``, or ``log(x)`` in the
    limiting case ``R == 1``.  For ``x < eps`` it returns the second-order
    Taylor expansion of that function about ``eps``.  The result therefore
    stays finite for small, zero or negative ``x``.  Value, first derivative
    and second derivative all match at ``eps``.

    Parameters
    ----------
    x : array_like
        Point(s) at which to evaluate the utility.
    R : float
        Relative risk-aversion coefficient (presumably R > 0 -- TODO confirm).
    eps : float, optional
        Threshold below which the quadratic extension is used.  Defaults to
        the module-level ``EPS`` (read at call time).

    Returns
    -------
    ndarray or scalar
        Utility evaluated elementwise at ``x``.
    """
    import numpy as np

    if eps is None:
        eps = EPS
    xp = np.maximum(x, eps)           # x clamped from below at eps
    xm = np.maximum(eps - x, 0.0)     # distance below eps (0 where x >= eps)

    # R == 1 is the log-utility limit of x**(1-R)/(1-R); the general formula
    # would divide by zero there.
    if np.ndim(R) == 0 and R == 1:
        z = np.log(xp)
    else:
        z = xp**(1.0-R) / (1.0-R)

    # First and second derivatives of the power/log branch at eps, used for
    # the Taylor extension (named so they do not shadow the dU function).
    dU_eps = eps**(-R)
    d2U_eps = -R * eps**(-1.0-R)
    return z - dU_eps*xm + 0.5*d2U_eps*xm*xm

def dU(x,R,eps=None):
    """Marginal utility: derivative with respect to x of the extended CRRA utility.

    For ``x >= eps`` this is ``x**(-R)``.  For ``x < eps`` it is the
    derivative of the quadratic extension, i.e. the linear function
    ``eps**(-R) + R*eps**(-1-R)*(eps - x)``.

    Parameters
    ----------
    x : array_like
        Point(s) at which to evaluate the marginal utility.
    R : float
        Relative risk-aversion coefficient.
    eps : float, optional
        Threshold below which the linear extension is used.  Defaults to the
        module-level ``EPS`` (read at call time).

    Returns
    -------
    ndarray or scalar
        Marginal utility evaluated elementwise at ``x``.
    """
    import numpy as np

    if eps is None:
        eps = EPS
    xp = np.maximum(x, eps)           # x clamped from below at eps
    xm = np.maximum(eps - x, 0.0)     # distance below eps (0 where x >= eps)
    d2U_eps = -R * eps**(-1.0-R)      # second derivative of the power branch at eps
    z = xp**(-R)
    z = z - d2U_eps*xm
    return z
    
def I(y,R,eps=None):
    """Inverse marginal utility: the x for which dU(x, R) == y.

    For ``y <= eps**(-R)`` (i.e. x >= eps) this is ``y**(-1/R)``.  Above that
    kink it inverts the linear extension of the marginal utility.

    Parameters
    ----------
    y : array_like
        Marginal-utility value(s).  ``y == 0`` maps to ``inf``.  ``y < 0``
        has no preimage; numpy then returns ``nan`` (with a RuntimeWarning).
    R : float
        Relative risk-aversion coefficient.
    eps : float, optional
        Threshold of the utility extension.  Defaults to the module-level
        ``EPS`` (read at call time).

    Returns
    -------
    ndarray or scalar
        The inverse evaluated elementwise at ``y``.
    """
    import numpy as np

    if eps is None:
        eps = EPS
    y_kink = eps**(-R)                # marginal utility at x == eps
    yp = np.minimum(y, y_kink)
    ym = np.maximum(y - y_kink, 0)    # how far y lies above the kink
    sl = - eps**(1.0+R) / R           # 1/U''(eps): inverse slope of the linear branch
    return yp**(-1.0/R) + sl*ym

def Ut(y,R):
    """Dual (conjugate) utility  max_x [ U(x,R) - y*x ].

    The inner maximum is taken at the point where the marginal utility equals
    ``y``, which is ``I(y, R)``.  The result is the utility there minus the
    linear cost term.  (This relies on U being concave, which presumably
    holds for R > 0 -- TODO confirm.)
    """
    optimum = I(y,R)
    utility_at_optimum = U(optimum,R)
    return utility_at_optimum - y*optimum
    
##############################################################################

def phi(t):
    """Time weight on the utility term used by V_from_pi and pi_from_V.

    Currently constant: every time ``t`` receives weight 1.0 (``t`` is
    accepted for interface reasons but not used).
    """
    weight = 1.0
    return weight
    
def V_from_pi(c,theta,nextV,tl,tr):
    """One backward theta-scheme step of the value function for a fixed policy.

    With ``dt = tr - tl`` and the generator

        L = diag(r*xx - c + (mu-r)*theta) * D1
            + 0.5 * diag((sig*theta)**2) * D2,

    this solves

        (I/dt - p_FD*L) V = 0.5*(phi(tl)+phi(tr))*U(c,R1)
                            + (I/dt + (1-p_FD)*L) nextV

    for V.  ``p_FD`` is the implicit weight of the scheme.

    Parameters
    ----------
    c, theta : array_like
        Consumption and risky-asset policy on the grid (presumably length
        ``Nx`` arrays -- TODO confirm).
    nextV : array_like
        Value function at time ``tr``.
    tl, tr : float
        Left and right ends of the time step.

    Returns
    -------
    ndarray
        The solution V of the linear system (presumably the value at ``tl``).

    Relies on module globals r, mu, sig, xx, Nx, D1, D2, p_FD and R1.
    """
    # Compute 1/dt once, in floating point: the original used 1/dt in one
    # place and 1.0/dt in the other, which disagree under Python 2 integer
    # division when tl and tr are ints.
    inv_dt  = 1.0/(tr - tl)

    drift   = r*xx-c+(mu-r)*theta
    var     = sig*theta
    var     = var*var
    M1      = spdiags(drift,0,Nx,Nx)
    M2      = spdiags(var,0,Nx,Nx) * 0.5
    L       = M1*D1 + M2*D2           # generator applied via the FD operators

    Id      = sparse.eye(Nx)
    # Running-utility term, with phi averaged over the step (trapezoid rule).
    RHS     = 0.5*(phi(tl)+phi(tr))*U(c,R1)
    # Explicit part of the theta scheme acting on the known value nextV.
    RHS    += ((1-p_FD)*L + inv_dt*Id)*nextV
    # Implicit part: the matrix multiplying the unknown V.
    M       = inv_dt*Id - p_FD*L
    return spsolve(M, RHS)
    
def pi_from_V(oldV,nextV,tl,tr):
    """Recover the consumption / investment policy from value-function iterates.

    The value is blended with the same theta weighting as V_from_pi:

        V = (1-p_FD)*nextV + p_FD*oldV.

    From this blend the function computes:

    - the consumption ``c`` solving ``dU(c) = V_x / phi_bar``
      (floored at 1e-12);
    - the risky position ``th = -(mu-r) * V_x / (sig**2 * V_xx)``.

    ``V_xx`` is clipped to at most -1e-12 so the ratio stays finite.

    NOTE(review): where V_x <= 0, I() returns inf/nan, and the 1e-12 floor
    does not remove nan -- confirm V_x stays positive on the grid.

    Returns
    -------
    c, th : ndarray
        Consumption and risky-asset policy on the grid.

    Relies on module globals D1, D2, p_FD, R1, mu, r, sig.
    """
    blended = (1-p_FD)*nextV + p_FD*oldV
    Vw      = D1*blended                       # first derivative in wealth
    Vww     = minimum(-1e-12, D2*blended)      # force strict concavity
    phi_bar = (phi(tl)+phi(tr))/2.0            # average time weight over the step
    c       = maximum(I(Vw/phi_bar,R1), 1e-12)
    th      = -(mu-r)*Vw/Vww/sig/sig
    return c, th
    