# -*- coding: utf-8 -*-
"""
Created on Mon Feb 23 15:41:28 2015

This script explores the idea that as an agent gets older he should shift
from stock to riskless investment. The proposal is that the objective is

    V(t,w) = sup E[ \int_t^T phi(s) U(c_s) ds + F(w_T) | w_t = w ]
    
where U is CRRA(R_1) and F is a multiple of CRRA(R_2) for some R_2 > R_1


@author: chris
"""

from scipy import arange, r_, linspace, exp, ones, zeros, maximum, dot, log,\
                        minimum, outer, r_
from scipy import sparse
from scipy.sparse import spdiags
from scipy.sparse.linalg import spsolve
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from time import time

# Model and discretisation parameters.  They are plain module-level names;
# NOTE(review): the helpers loaded via execfile further down presumably read
# some of them (e.g. EPS, p_FD) as globals -- confirm before renaming any.
R1      = 0.5                   # CRRA coefficient of running utility U (R_1 in the docstring)
R2      = 6.0                   # CRRA coefficient of terminal utility F (R_2 > R_1)
T       = 60.0                  # right end of the time grid, which runs over [0, T]
Nx      = 1200                   # number of x-grid points
Nt      = 61                    # number of t-grid points
# sig, r and mu enter the Merton proportions below as (mu-r)/sig/sig/R;
# presumably stock volatility, riskless rate and stock growth rate.
sig     = 0.25
r       = 0.05
mu      = 0.09
A       = 1.0                   #   F = A * U(.,R2)
EPS     = 1.0e-5                # cutoff in definition of CRRA
p_FD    = 0.5                   # controls the FD scheme. 0 is explicit, 1 is 
                                #  fully implicit

# Run the helper file inside this namespace.  MakeD, U, V_from_pi and
# pi_from_V are used below but neither defined nor imported in this file, so
# they presumably come from here.  NOTE(review): execfile is Python 2 only.
execfile("LifetimeFuns.py")

# Merton proportions for the two risk-aversion coefficients; printed at the
# end of the script next to the extremes of the computed solution.
pi_1 = (mu - r) / sig / sig / R1
pi_2 = (mu - r) / sig / sig / R2

# Grids: uniform in time on [0, T]; wealth is uniform in log-space, i.e.
# geometrically spaced between exp(-5) and exp(14).
tt = linspace(0, T, Nt)
xx = exp(linspace(-5, 14, Nx))

# Not referenced again in this file; presumably differentiation matrices
# that the helper routines read as globals -- TODO confirm.
D1, D2 = MakeD(xx)

# Result storage: one row per wealth node, one column per time point
# tt[0], ..., tt[Nt-2] (the terminal time itself is not stored).
Pi, Ga, Vall = zeros([Nx, Nt - 1]), zeros([Nx, Nt - 1]), zeros([Nx, Nt - 1])

# Terminal condition F = A * U(., R2), plus starting guesses for the
# iteration: consumption and stock holding are both 10% of wealth.
nextV = A * U(xx, R2)
oldV = nextV.copy()
oldc = 0.1 * xx
oldth = oldc.copy()

tic = time()
for ii in arange(Nt-1, 0,-1):
    print "\n *******  Time point ",tt[ii-1],"  ******* \n"
    count   = 0; check = 1.0
    while (count<75)&(check>5e-6):
        print " At count =  ", count,", check = ", check
        newV        = V_from_pi(oldc,oldth,nextV,tt[ii-1],tt[ii])
        newc,newth  = pi_from_V(newV,nextV,tt[ii-1],tt[ii])
        check       = abs(newV-oldV).max()
        count       += 1       
        oldc    = newc.copy(); oldth   = newth.copy(); oldV  = newV.copy()
    nextV   = oldV.copy()
    Pi[:,ii-1]  = oldth/xx
    Ga[:,ii-1]  = oldc/xx
    Vall[:,ii-1] = nextV

toc = time() 
print "\n ########### Time taken = ", toc -tic  ,"  ############\n"       
   
###########   END OF CALCULATION ########################################

###########    Outputs:          ########################################

J   = abs(log(xx))<=4
xx0 = xx[J]
Nx0 = len(xx0)
XX  = outer(xx0,ones(Nt-1)); TT = outer(ones(Nx0), tt[:-1])
plt.ion()
fig = plt.figure(1)
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(log(XX), TT, Pi[J,:])
ax.set_xlabel('Log wealth')
ax.set_ylabel('Time')
ax.set_zlabel('Proportion in stock')
##
fig = plt.figure(2)
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(log(XX), TT, maximum( -20, log(Ga[J,:])   ) )
ax.set_xlabel('Log wealth')
ax.set_ylabel('Time')
ax.set_zlabel('Log(consumption/wealth)')
##
fig = plt.figure(3)
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(log(XX), TT, maximum(-1e2,Vall[J,:]) )
ax.set_xlabel('Log wealth')
ax.set_ylabel('Time')
ax.set_zlabel('Value')

print "\n Merton proportion for running consumption = ", pi_1
print " Max proportion in solution                = ", Pi[J,:].max()
print "\n Merton proportion for terminal wealth     = ", pi_2  
print " Min proportion in solution                = ", Pi[J,:].min()

   