library(devtools)
load_all("../covalign")

library(cccp)
EigenPrism <- function(y,X,invsqrtSig=NULL,alpha=0.05,target='beta2',zero.ind=c(),diagnostics=FALSE){
    # Author: Lucas Janson (statweb.stanford.edu/~ljanson)
    # Runs the EigenPrism procedure for estimating, and generating confidence
    # intervals for, variance components in the high-dimensional linear model
    #       y = X %*% beta + e,   rows of X iid ~ N(0, Sig),   e iid ~ N(0, sigma^2)
    # Requires the cccp package for solving the second-order cone program.
    # Note: confidence-interval endpoints may lie outside the parameter domain,
    # so it may be appropriate to clip them after the fact.
    #
    # Inputs:
    #  y: response vector of length n (centered internally)
    #  X: n by p design matrix; columns are centered and scaled internally
    #     (scale() to unit sd, then multiplied by n/(n-1)); it should not
    #     contain an intercept column, since both y and X are centered
    #  invsqrtSig: if the columns of X are not independent, a p by p positive
    #     definite matrix equal to the square root of the inverse of Sig, where
    #     Sig is the *correlation* matrix of X (default NULL = identity)
    #  alpha: significance level of the confidence interval (default 0.05)
    #  target: target of estimation/inference, one of
    #     'beta2' (default): squared 2-norm of the coefficients, sum(beta^2)
    #     'sigma2': the noise variance sigma^2
    #     'heritability': fraction of the variance of y explained by
    #                     X %*% beta, t(beta) %*% Sig %*% beta / var(y)
    #  zero.ind: indices of the weight vector w constrained to zero
    #     (default: none)
    #  diagnostics: logical (default FALSE); if TRUE, draw diagnostic plots
    #     for the V_i, lambda_i and w_i (graphics par() is restored on exit)
    #
    # Outputs (a list):
    #  estimate: unbiased estimate of the target (only approximately unbiased
    #     for heritability)
    #  CI: 100*(1-alpha)% confidence interval for the target

    # Reject unknown targets instead of silently treating them as 'beta2'
    stopifnot(length(target) == 1,
              target %in% c('beta2', 'sigma2', 'heritability'))

    # Get dimensionality of problem
    n <- nrow(X)
    p <- ncol(X)

    # Transform y and X to proper form
    y <- y - mean(y)
    X <- scale(X, TRUE, TRUE) * n / (n - 1)
    if (!is.null(invsqrtSig)) X <- X %*% invsqrtSig

    # Full SVD (nu = n left, nv = p right singular vectors) so that
    # t(U) %*% y has length n even when n > p; the squared singular values
    # are divided by p and zero-padded up to length n
    dec <- svd(X, nu = n, nv = p)
    lambda <- dec$d^2 / p
    if (length(lambda) < n) {
        lambda <- c(lambda, rep(0, n - length(lambda)))
    }

    # Cone-constrained linear problem for the weights; the optimization
    # variable is [v; w] with w of length n
    q <- c(1, rep(0, n))                       # objective coefficients (minimize v)
    A <- rbind(c(0, rep(1, n)), c(0, lambda))  # linear constraints on sum(w), sum(w * lambda)
    b <- if (target == 'sigma2') c(1, 0) else c(0, 1)

    # Optionally constrain some weights to be zero. drop = FALSE keeps a
    # single index as a 1 x n row; without it the row collapsed to a vector
    # and cbind() produced an n x 2 matrix that rbind() rejected.
    if (length(zero.ind) > 0) {
        A <- rbind(A, cbind(rep(0, length(zero.ind)),
                            diag(rep(1, n))[zero.ind, , drop = FALSE]))
        b <- c(b, rep(0, length(zero.ind)))
    }

    # Define second-order cone constraints
    soc1 <- socc(diag(c(1/4, rep(1, n))), c(-1/2, rep(0, n)), c(1/4, rep(0, n)), 1/2)
    soc2 <- socc(diag(c(1/4, lambda)), c(-1/2, rep(0, n)), c(1/4, rep(0, n)), 1/2)
    prob <- dlp(as.vector(q), A, as.vector(b), list(soc1, soc2))

    # Solve optimization problem and extract variables
    opt <- cps(prob, ctrl(trace = FALSE))
    sol <- getx(opt)
    v <- sol[1]
    w <- sol[-1]

    # Estimate: weighted sum of squared projections of y onto the left
    # singular vectors; yvar is the 1/n sample variance of the centered y
    est <- sum(w * (t(dec$u) %*% y)^2)
    yvar <- sum(y^2) / n

    # Confidence interval; v sets its half-width via yvar * sqrt(v) * z
    CI <- est + yvar * sqrt(v) * qnorm(1 - alpha / 2) * c(-1, 1)
    if (target == 'heritability') {
        est <- est / yvar
        CI <- CI / yvar
    }

    result <- list(estimate = est, CI = CI)

    # Generate diagnostic plots
    if (diagnostics) {
        old.par <- par(mfrow = c(1, 3))
        on.exit(par(old.par), add = TRUE)

        # Check that eigenvectors V_1, V_10, V_100, ... are approximately
        # Gaussian. Indices are capped at min(n, p): V has only p columns,
        # so the old floor(log10(n)) cap overflowed when n >= 10 * p.
        v.idx <- 10^(0:floor(log10(min(n, p))))
        srtV <- dec$v[, v.idx, drop = FALSE]
        for (i in seq_along(v.idx)) srtV[, i] <- sort(srtV[, i])
        labs <- lapply(v.idx, function(ind) bquote(V[.(ind)]))
        matplot(qnorm(seq_len(p) / (p + 1)), srtV, type = "l", lwd = 2,
                ylab = "Quantiles of Eigenvectors", xlab = "Gaussian Quantiles",
                main = expression(paste("Check Gaussianity of Eigenvectors ", V[i])))
        legend("topleft", as.expression(labs), col = seq_along(v.idx),
               lty = seq_along(v.idx), lwd = 2)

        # Check that there are no outliers in the eigenvalues
        hist(lambda, main = expression(paste("Histogram of Normalized Eigenvalues ", lambda[i])),
             xlab = expression(lambda[i]))

        # Check that the weights are not dominated by just a few values
        srtw <- sort(abs(w), decreasing = TRUE)
        plot(seq_len(n), cumsum(srtw) / sum(srtw), type = "l", lwd = 2,
             main = expression(paste("Fraction of Total Weight in Largest k ", w[i])),
             xlab = "k", ylab = "Fraction of Total Weight")
    }

    result
}


one.sim <- function(t.scenario, y.scenario, t.strength) {
    # One replicate of the simulation for a single setting.
    #
    # Inputs (character scalars):
    #  t.scenario: sparsity of the treatment-model coefficients beta.t:
    #     "ultra.sparse" (5 nonzero), "sparse" (20 nonzero) or "dense" (all p)
    #  y.scenario: sparsity of the outcome-model coefficients beta.y
    #     (same choices)
    #  t.strength: "weak" or "strong"; beta.t is rescaled to norm 0.5 or 2
    #
    # Output: data frame with one row per (method, lambda) pair for the
    #  methods "ipw", "drl" (lasso-augmented) and "drr" (ridge-augmented):
    #  est, worst-case bias bounds max.bias / max.bias.cp (EigenPrism norm
    #  estimate / upper CI endpoint times the norm of the covariate
    #  imbalance), the realised conditional bias, the setting, lambda, the
    #  weights' coefficient of variation, and true/estimated standard errors.
    #  Y does not depend on the treatment, so the true effect is 0.
    n <- 500
    p <- 100
    sigma <- 2
    # AR(1) correlation 0.5^|i - j|. The old abs(0.5^(i - j)) applied abs()
    # to the wrong term and was asymmetric above the diagonal; the lower
    # triangle is unchanged.
    Sigma <- 0.5^abs(outer(1:p, 1:p, "-"))

    library(MASS)
    X <- mvrnorm(n, rep(0, p), Sigma)

    # Unit-norm coefficient vector whose first k entries are equal
    make.beta <- function(scenario) {
        beta <- switch(scenario,
                       ultra.sparse = c(rep(1, 5), rep(0, p - 5)),
                       sparse = c(rep(1, 20), rep(0, p - 20)),
                       dense = rep(1, p),
                       stop("unknown scenario: ", scenario))
        beta / sqrt(sum(beta^2))
    }
    beta.t <- make.beta(t.scenario) *
        switch(t.strength, weak = 0.5, strong = 2,
               stop("unknown t.strength: ", t.strength))
    beta.y <- make.beta(y.scenario)

    # Logistic treatment assignment and a treatment-free outcome
    treat <- as.numeric(runif(n) < plogis(X %*% beta.t))
    Y <- X %*% beta.y + sigma * rnorm(n)

    # Balancing weights along a lambda path; assumes compute.weights returns
    # a length-n vector, so `weights` is n x length(lambda.seq)
    lambda.seq <- 10^seq(8, -8, length.out = 33)
    result <- kernel.balance.path(treat, X, d = rep(1, p), alpha = 0, beta = -1,
                                  lambda = lambda.seq)
    weights <- sapply(seq_along(lambda.seq), function(j)
        compute.weights(treat, result[[j]]$p,
                        alpha = 0, beta = -1,
                        normalize = TRUE))
    cv <- function(w) sd(w) / mean(w)
    cv.path <- apply(weights, 2, cv)

    # Random half split of the controls: the first half fits the outcome
    # models, the second half supplies residuals for EigenPrism
    X.c <- X[treat == 0, , drop = FALSE]
    Y.c <- Y[treat == 0]
    s <- sample(seq_len(sum(treat == 0)))
    s1 <- s[seq_len(floor(length(s) / 2))]
    library(glmnet)
    fit.lasso <- cv.glmnet(X.c[s1, ], Y.c[s1])
    fit.ridge <- cv.glmnet(X.c[s1, ], Y.c[s1], alpha = 0)

    Y.res.lasso <- Y.c[-s1] - predict(fit.lasso, X.c[-s1, ], s = fit.lasso$lambda.min)
    Y.res.ridge <- Y.c[-s1] - predict(fit.ridge, X.c[-s1, ], s = fit.ridge$lambda.min)
    Y0.pred.lasso <- predict(fit.lasso, X, s = fit.lasso$lambda.min)
    Y0.pred.ridge <- predict(fit.ridge, X, s = fit.ridge$lambda.min)

    # Estimated norm of the (residual) outcome coefficients: sqrt of the
    # EigenPrism estimate of sum(beta^2) and of its upper CI endpoint,
    # clipped at 0. Each EigenPrism problem is solved once and reused.
    ep.plain <- EigenPrism(y = Y.c, X = X.c, alpha = 0.05, diagnostics = FALSE)
    ep.lasso <- EigenPrism(y = Y.res.lasso, X = X.c[-s1, ], alpha = 0.05, diagnostics = FALSE)
    ep.ridge <- EigenPrism(y = Y.res.ridge, X = X.c[-s1, ], alpha = 0.05, diagnostics = FALSE)
    theta.plain <- sqrt(max(ep.plain$estimate, 0))
    theta.lasso <- sqrt(max(ep.lasso$estimate, 0))
    theta.ridge <- sqrt(max(ep.ridge$estimate, 0))
    print(c(theta.plain, theta.lasso, theta.ridge))
    theta.plain.cp <- sqrt(max(ep.plain$CI[2], 0))
    theta.lasso.cp <- sqrt(max(ep.lasso$CI[2], 0))
    theta.ridge.cp <- sqrt(max(ep.ridge$CI[2], 0))

    # Signed weights (+w treated, -w control) and the covariate imbalance
    # they leave, one row per lambda; computed once for all methods
    signed.w <- weights * (2 * treat - 1)
    imbalance <- t(signed.w) %*% X / n
    imbalance.norm <- sqrt(rowSums(imbalance^2))

    # Rows for one method: y.adj is the (possibly residualised) outcome,
    # coef.err the error in the outcome coefficients left after adjustment
    method.rows <- function(method, y.adj, theta, theta.cp, coef.err) {
        data.frame(
            method = method,
            est = c(t(signed.w) %*% y.adj / n),
            max.bias = theta * imbalance.norm,
            max.bias.cp = theta.cp * imbalance.norm,
            bias = imbalance %*% coef.err)
    }

    output <- rbind(
        method.rows("ipw", Y, theta.plain, theta.plain.cp, beta.y),
        method.rows("drl", Y - Y0.pred.lasso, theta.lasso, theta.lasso.cp,
                    beta.y - coef(fit.lasso, s = fit.lasso$lambda.min)[-1]),
        method.rows("drr", Y - Y0.pred.ridge, theta.ridge, theta.ridge.cp,
                    beta.y - coef(fit.ridge, s = fit.ridge$lambda.min)[-1]))
    output$t.scenario <- t.scenario
    output$y.scenario <- y.scenario
    output$t.strength <- t.strength
    output$lambda <- lambda.seq   # recycled across the three methods
    output$cv <- cv.path
    # True and estimated standard errors; the estimate uses the lasso
    # residual variance for every method
    w.scale <- sqrt(colSums((weights / n)^2))
    output$se <- w.scale * sigma ## standard error
    output$se.est <- w.scale * sqrt(mean(Y.res.lasso^2)) ## estimated standard error

    output
}


## Simulation grid. The full design is kept for reference; only the
## dense/dense/weak setting is currently run.
## settings <- expand.grid(t.scenario = c("ultra.sparse", "sparse", "dense"),
##                         y.scenario = c("ultra.sparse", "dense"),
##                         t.strength = c("weak", "strong"))
settings <- expand.grid(t.scenario = c("dense"),
                        y.scenario = c("dense"),
                        t.strength = c("weak"),
                        stringsAsFactors = FALSE)

nsim <- 100
out.file <- "linear_oct25.rda"

library(parallel)
## Seed the run so the saved results can be regenerated. L'Ecuyer-CMRG is
## the generator mclapply uses to give workers independent streams if
## mc.cores is raised above 1.
RNGkind("L'Ecuyer-CMRG")
set.seed(20171025)

output <- mclapply(seq_len(nsim),
                   function(sim) {
                       print(sim)
                       res <- do.call(rbind, lapply(seq_len(nrow(settings)), function(i) {
                           one.sim(settings$t.scenario[i],
                                   settings$y.scenario[i],
                                   settings$t.strength[i])
                       }))
                       res$sim <- sim
                       res
                   },
                   mc.cores = 1)
output <- do.call(rbind, output)

## Some max.bias values come out NaN; every NA/NaN in the results is replaced with 0
output <- replace(output, is.na(output), 0)

save(output, file = out.file)

## Reload, so the analysis below can also be rerun on its own from the saved file.
## load("linear.rda")
load(out.file)

library(plyr)

## Per-replicate summary. For every setting/replicate/method, keep the last
## lambda index s whose weights' coefficient of variation is below
## max(5 * cv at the first lambda, 1), then build two intervals around est:
##  ci.*    : est + qnorm(0.025 / 0.975) * se.est
##  cinew.* : widened by the bias bound max.bias.cp, with quantiles at
##            0.025 / 2 and 1 - 0.025 / 2
## and record whether each interval contains 0.
output1 <- ddply(
    output, c("t.scenario", "y.scenario", "t.strength", "sim", "method"),
    function(d) {
        s <- max(which(d$cv < max(d$cv[1] * 5, 1)))
        est <- d$est[s]
        max.bias.cp <- d$max.bias.cp[s]
        se.est <- d$se.est[s]
        ci.low <- est + qnorm(0.025) * se.est
        ci.up <- est + qnorm(1 - 0.025) * se.est
        cinew.low <- est - max.bias.cp + qnorm(0.025 / 2) * se.est
        cinew.up <- est + max.bias.cp + qnorm(1 - 0.025 / 2) * se.est
        data.frame(s = s,
                   est = est,
                   max.bias = d$max.bias[s],
                   max.bias.cp = max.bias.cp,
                   bias = d$bias[s],
                   cv = d$cv[s],
                   se = d$se[s],
                   se.est = se.est,
                   ci.low = ci.low,
                   ci.up = ci.up,
                   ci.cover = (ci.low < 0 & ci.up > 0),
                   cinew.low = cinew.low,
                   cinew.up = cinew.up,
                   cinew.cover = (cinew.low < 0 & cinew.up > 0))
    })

## Per-setting summary across replicates: bias metrics, coverage of 0 and
## mean interval length for both interval types.
output2 <- ddply(
    output1, c("t.scenario", "y.scenario", "t.strength", "method"),
    function(d) {
        data.frame(rmse = sqrt(mean(d$bias^2)),
                   mean.bias = mean(abs(d$bias)),
                   mean.max.bias = mean(d$max.bias),
                   ci.cover = mean(d$ci.cover),
                   ci.len = mean(d$ci.up - d$ci.low),
                   cinew.cover = mean(d$cinew.cover),
                   cinew.len = mean(d$cinew.up - d$cinew.low))
    })

## Prepare output2 for the LaTeX table: drop the "sparse" treatment scenario,
## round the reported metrics, and relabel each setting *by name*.
## The previous positional `levels<-` followed alphabetical level order and
## whichever levels happened to be present, so "drl" was labelled "IPW" and
## "ipw" "AIPW-R", "ultra.sparse" outcomes became "20" in the full grid, and
## unused phantom levels were appended. Only levels present are kept.
relabel <- function(x, map) {
    droplevels(factor(as.character(x), levels = names(map), labels = unname(map)))
}

output2 <- subset(output2, t.scenario != "sparse")

round.cols <- c("rmse", "mean.bias", "mean.max.bias", "ci.cover", "cinew.cover")
output2[round.cols] <- lapply(output2[round.cols], round, 2)
output2$ci.ratio <- round(output2$cinew.len / output2$ci.len, 2)

## Sparsity labels are the number of nonzero coefficients in each scenario
sparsity.labels <- c(dense = "100", sparse = "20", ultra.sparse = "5")
output2$y.scenario <- relabel(output2$y.scenario, sparsity.labels)
output2$t.scenario <- relabel(output2$t.scenario, sparsity.labels)
## NOTE(review): "1"/"2" assumed to denote weak/strong (increasing strength);
## the old positional code mapped weak -> "1" for the one-setting run but
## strong -> "1" for the full grid. Confirm the intended labels.
output2$t.strength <- relabel(output2$t.strength, c(weak = "1", strong = "2"))
output2$method <- relabel(output2$method, c(ipw = "IPW", drl = "AIPW-L", drr = "AIPW-R"))

library(tables)
latex(tabular(y.scenario * t.strength * t.scenario * method ~ (rmse + mean.bias + mean.max.bias + ci.cover + cinew.cover + ci.ratio) * Heading() * identity, data = output2))
