#' Covariate Alignment Package
#'
#' Weighting estimators (propensity score, entropy balancing and related
#' methods), outcome regression and their doubly-robust combination for
#' estimating counterfactual means and causal effects (ATE / ATT).
#'
#' @docType package
#' @name covalign-package
NULL

#' A new version of PS.weights
#' @description Returns unit weights in the general framework
#'              described in Kang & Schafer (2007).
#' @export
#' @import glmnet gbm ebal CBPS
#' @param label labels, 0 for control and 1 for treatment
#' @param features covariates to be used in propensity scoring
#' @param type either "pop" (full population) or "nr" (non-respondents).
#'             If the input is a factor level,
#'             it will be automatically transformed into a character string.
#' @param method weighting method: "glm" (logistic regression), "glmnet"
#'               (cross-validated penalized logistic regression), "gbm"
#'               (boosting), "cbps" (CBPS package), "ebal" (entropy
#'               balancing), "unif" (uniform weights) or "mmcb"
#' @param ... additional parameters for glmnet and gbm.
#' @return a list of \describe{
#'   \item{model}{the model used (NULL for "unif" and "mmcb").}
#'   \item{weights}{estimated weights, normalized to sum to one within the
#'                  treated units and within the control units}
#'   \item{propensity.score}{estimated propensity score (NULL when the
#'                           method does not produce one)}
#'   \item{type}{weight type used}
#'   \item{method}{method used}
#' }
#' @details With type = "pop", both groups are reweighted towards the full
#'          population (propensity-score methods use 1/ps for treated and
#'          1/(1-ps) for control units).
#'          With type = "nr", control units are reweighted towards the
#'          treated units (ps/(1-ps)) and treated units towards the control
#'          units ((1-ps)/ps).
#' @references
#'   Lunceford & Davidian, Stratification and Weighting
#'   Via the Propensity Score in Estimation of
#'   Causal Treatment Effects: A Comparative Study,
#'   2004;
#'   Kang & Schafer, Demystifying Double Robustness:
#'   A Comparison of Alternative Strategies
#'   for Estimating a Population Mean from Incomplete Data,
#'   2007
get.weights <- function(label,
                        features,
                        type = c("pop", "nr"),
                        method =
                            c("glm", "glmnet", "gbm", "cbps", "ebal", "unif", "mmcb"),
                        ...) {
    # check arguments
    type <- match.arg(as.character(type), c("pop", "nr"))
    method <- match.arg(as.character(method),
                        c("glm", "glmnet", "gbm", "cbps", "ebal", "unif", "mmcb"))

    n <- length(label)
    data <- data.frame(label = label, features = features)
    treated <- label == 1
    control <- label == 0
    # stays NULL for methods that do not estimate a propensity score
    propensity.score <- NULL

    if (method %in% c("glm", "glmnet", "gbm", "cbps")) {

        # First estimate propensity scores
        if (method == "glm") {
            model <- glm(label ~ ., family = "binomial", data = data)
            propensity.score <- fitted(model)
        } else if (method == "glmnet") {
            model <- glmnet::cv.glmnet(x = features, y = label,
                                       family = "binomial", ...)
            # predict() gives the linear predictor; plogis maps it to [0, 1]
            propensity.score <-
                plogis(predict(model, features, s = "lambda.min"))
        } else if (method == "gbm") {
            model <- gbm::gbm(label ~ ., data = data,
                              distribution = "bernoulli", ...)
            best.iter.oob <- gbm::gbm.perf(model, method = "OOB",
                                           plot.it = FALSE)
            propensity.score <- predict(model, data, n.trees = best.iter.oob,
                                        type = "response")
        } else {
            # CBPS requires a value for ATT; only the fitted propensity
            # score is used below.
            # NOTE(review): confirm ATT = FALSE is the intended fit when
            # type = "nr".
            model <- CBPS::CBPS(label ~ ., data = data, ATT = FALSE)
            propensity.score <- model$fitted.values
        }

        # get weights by estimated propensity score
        ps.t <- propensity.score[treated]
        ps.c <- propensity.score[control]
        weights <- rep(0, n)
        if (type == "pop") {
            weights[treated] <- 1 / ps.t
            weights[control] <- 1 / (1 - ps.c)
        } else {
            weights[treated] <- (1 - ps.t) / ps.t
            weights[control] <- ps.c / (1 - ps.c)
        }
    } else if (method == "unif") {
        # uniform weight
        model <- NULL
        weights <- rep(1, n)
    } else if (method == "ebal") {
        # get weights by entropy balancing.
        # ebalance prints a message on every call; capture.output discards
        # it portably and, unlike a manual sink() pair, cannot leave the
        # output diverted if ebalance throws.
        if (type == "nr") {
            invisible(utils::capture.output(
                model.c <- ebal::ebalance(label, features)))
            invisible(utils::capture.output(
                model.t <- ebal::ebalance(1 - label, features)))
            weights <- rep(0, n)
            weights[control] <- model.c$w
            weights[treated] <- model.t$w
            propensity.score <-
                plogis(cbind(rep(1, n), features) %*% model.c$coefs)
        } else {
            warning(paste("pop weights are not correctly implemented for ebal",
                          "This option should only be used",
                          "in simulations demonstrating this."))
            # Each group is balanced against a stacked copy of the full
            # sample, which plays the role of the "treated" group.
            invisible(utils::capture.output(
                model.c <- ebal::ebalance(
                    c(rep(1, n), label[control]),
                    rbind(features, features[control, , drop = FALSE]))))
            invisible(utils::capture.output(
                model.t <- ebal::ebalance(
                    c(rep(1, n), 1 - label[treated]),
                    rbind(features, features[treated, , drop = FALSE]))))
            weights <- rep(0, n)
            weights[control] <- model.c$w
            weights[treated] <- model.t$w
        }
        model <- list(model.c = model.c, model.t = model.t)
    } else {
        # method == "mmcb"
        model <- NULL

        basis <- get.basis(features, nbasis = 11)
        X <- basis$features
        alpha <- basis$alpha

        # Simplex weights (sum to one, non-negative) over the units in
        # `index` that minimize the squared distance between the weighted
        # mean of the basis features and `target`, with feature k scaled
        # by 1 / alpha[k].
        mmcb.weights <- function(target, index) {
            X.sub <- X[index, , drop = FALSE]
            X.scaled <- t(t(X.sub) / alpha^2)
            nn <- length(index)
            result <- quadprog::solve.QP(
                Dmat = X.scaled %*% t(X.sub),
                dvec = as.vector(X.scaled %*% target),
                Amat = cbind(rep(1, nn), diag(nn)),
                bvec = c(1, rep(0, nn)),
                meq = 1
            )
            result$solution
        }

        if (type == "pop") {
            target <- colMeans(X)
            weights.t <- mmcb.weights(target, which(treated))
            weights.c <- mmcb.weights(target, which(control))
        } else {
            weights.t <- mmcb.weights(colMeans(X[control, , drop = FALSE]),
                                      which(treated))
            weights.c <- mmcb.weights(colMeans(X[treated, , drop = FALSE]),
                                      which(control))
        }
        weights <- rep(0, n)
        weights[control] <- weights.c
        weights[treated] <- weights.t
    }

    # normalize weights; tiny negative values are numerical noise from the
    # solvers and are clipped, anything larger is an error
    if (any(weights < -0.0001)) {
        stop("Can weights be negative?")
    }
    weights[weights < 0] <- 0
    weights[treated] <- weights[treated] / sum(weights[treated])
    weights[control] <- weights[control] / sum(weights[control])

    list(model = model,
         weights = weights,
         propensity.score = propensity.score,
         type = type,
         method = method)
}


#' Estimate mean
#' @description estimate E[Y(0)] or E[Y(0)|T=1]
#' @export
#' @param T treatment (1) or control (0)
#' @param Y outcome
#' @param features features
#' @param features.T features to be used in T.model (optional)
#' @param features.Y features to be used in Y.model (optional)
#' @param over either "treated" (E[Y(0)|T=1]) or "all" (E[Y(0)])
#'             If the input is a factor level,
#'             it will be automatically transformed into a character string.
#' @param T.method weighting method (usually based on propensity score),
#'                 passed to \code{get.weights} as its \code{method}
#' @param weight.type either "pop" or "nr", passed to \code{get.weights} as
#'                    its \code{type}. Ignored when \code{T.model} is
#'                    supplied; \code{T.model$type} is used instead.
#' @param Y.method outcome regression method: "lm" (least squares on the
#'                 control units), "wls" (least squares on the control units
#'                 weighted by the T.model weights) or "none"
#' @param combine.method combining method: "dr" (outcome regression plus a
#'                       weighted-residual bias correction) or "none"
#' @param T.model a fitted propensity model (optional), as returned by
#'                \code{get.weights}
#' @param Y.model a fitted outcome model on the control (optional)
#' @details The weights are always normalized.
#'          In the usual causal effect model, Y = T * Y(1) + (1-T) * Y(0).
#'          \code{get.mean} only estimates the mean of Y(0), so it treats
#'          Y[T==1] as unobserved response.
#'          With combine.method = "none", the outcome model is used alone
#'          when there is one, otherwise the weights are used alone.
#' @return \describe{
#'   \item{mean.est}{the mean estimate}
#'   \item{T.model}{the propensity score model}
#'   \item{Y.model}{the outcome model}
#' }
get.mean <- function(T,
                     Y,
                     features,
                     features.T = features, features.Y = features,
                     over = c("treated", "all"),
                     T.method = c("glm", "unif", "glmnet", "cbps", "ebal", "mmcb"),
                     weight.type = c("pop", "nr"),
                     Y.method = c("lm", "none", "wls"),
                     combine.method = c("dr", "none"),
                     T.model = NULL, Y.model = NULL) {

    # check arguments
    over <- match.arg(as.character(over), c("treated", "all"))
    T.method <- match.arg(as.character(T.method),
                          c("glm", "unif", "glmnet", "cbps", "ebal", "mmcb"))
    Y.method <- match.arg(as.character(Y.method),
                          c("lm", "none", "wls"))
    weight.type <- match.arg(as.character(weight.type),
                             c("pop", "nr"))
    combine.method <- match.arg(as.character(combine.method),
                                c("dr", "none"))
    if (combine.method == "dr" && (T.method == "unif" || Y.method == "none")) {
        stop("Needs nontrivial T and Y method if combine.method = dr.")
    }

    # construct a data frame to facilitate model fitting
    data <- data.frame(Y = Y, features = features.Y)

    # Fit T model, or adopt the weight type of the supplied one
    if (is.null(T.model)) {
        T.model <- get.weights(T, features.T, weight.type, T.method)
    } else {
        weight.type <- T.model$type
    }

    # Fit Y model on the control units (Y(0) is observed only there).
    # lm() looks up `weights` and `T` in the calling frame.
    if (is.null(Y.model) && Y.method != "none") {
        if (Y.method == "wls") {
            weights <- T.model$weights
            Y.model <- lm(Y ~ ., data[T == 0, ],
                          weights = weights[T == 0])
        } else {
            Y.model <- lm(Y ~ ., data[T == 0, ])
        }
    }

    p1 <- mean(T)
    p0 <- 1 - p1
    w0 <- T.model$weights[T == 0]
    mean.Y0 <- mean(Y[T == 0])

    # Both branches below first get an estimate for the weight type's own
    # target -- E[Y(0)|T=1] for "nr", E[Y(0)] for "pop" -- and then convert
    # it to the requested `over` using
    #   E[Y(0)] = p1 * E[Y(0)|T=1] + p0 * E[Y|T=0].
    if (combine.method == "none") {
        if (is.null(Y.model)) {
            # weighting only
            ipw <- sum(w0 * Y[T == 0])
            mean.est <- if (over == "treated") {
                switch(weight.type,
                       nr = ipw,
                       pop = (ipw - p0 * mean.Y0) / p1)
            } else {
                switch(weight.type,
                       pop = ipw,
                       nr = p1 * ipw + p0 * mean.Y0)
            }
        } else {
            # outcome regression only
            mean.est <- switch(
                over,
                treated = mean(predict(Y.model, data[T == 1, ])),
                all = mean(predict(Y.model, data))
            )
        }
    } else {
        # combine.method == "dr": unaugmented regression mean over the
        # weight type's target ...
        m <- switch(weight.type,
                    nr = mean(predict(Y.model, data[T == 1, ])),
                    pop = mean(predict(Y.model, data)))
        # ... corrected for bias of the Y.model by the weighted residuals.
        # m.bc estimates E[Y(0)] if type = "pop", E[Y(0)|T=1] if type = "nr"
        m.bc <- m + sum(w0 * Y.model$residuals)
        mean.est <- if (over == "all") {
            switch(weight.type,
                   pop = m.bc,
                   nr = p1 * m.bc + p0 * mean.Y0)
        } else {
            switch(weight.type,
                   nr = m.bc,
                   pop = (m.bc - p0 * mean.Y0) / p1)
        }
    }

    list(mean.est = mean.est,
         T.model = T.model,
         Y.model = Y.model)
}


#' Estimate causal effect
#' @description estimate E[Y(1)] - E[Y(0)] or E[Y(1)|T=1] - E[Y(0)|T=1]
#' @export
#' @param T treatment (1) or control (0)
#' @param Y outcome
#' @param features features
#' @param effect ATT or ATE
#' @param weight.type either "pop" or "nr", passed to \code{get.weights}
#'                    as its \code{type}
#' @inheritParams get.mean
#' @details The weighting model is fitted once and shared by both mean
#'          estimates. The returned list has class "obeffect" and also
#'          stores the inputs, as needed by \code{get.effect.se}.
#' @return  \describe{
#'   \item{effect.est}{the causal effect estimate}
#'   \item{mean.est}{the mean estimates, c(control = ..., treatment = ...)}
#'   \item{m0}{\code{get.mean} output for the control response}
#'   \item{m1}{output for the treatment response; for ATT only its
#'             \code{mean.est}, the mean of Y over the treated units}
#'   \item{T.model}{propensity score model}
#' }
get.effect <- function(T,
                       Y,
                       features,
                       features.T = features, features.Y = features,
                       effect = c("ATT", "ATE"),
                       T.method = c("glm", "unif", "glmnet", "cbps", "ebal", "mmcb"),
                       weight.type = c("pop", "nr"),
                       Y.method = c("lm", "none", "wls"),
                       combine.method = c("dr", "none")
                       ) {

    # check arguments
    effect <- match.arg(as.character(effect), c("ATT", "ATE"))
    T.method <- match.arg(as.character(T.method),
                          c("glm", "unif", "glmnet", "cbps", "ebal", "mmcb"))
    Y.method <- match.arg(as.character(Y.method),
                          c("lm", "none", "wls"))
    weight.type <- match.arg(as.character(weight.type),
                             c("pop", "nr"))
    combine.method <- match.arg(as.character(combine.method),
                                c("dr", "none"))
    if (combine.method == "dr" && (T.method == "unif" || Y.method == "none")) {
        stop("Needs nontrivial T and Y method if combine.method = dr.")
    }

    # Fit T model once; both mean estimates reuse it
    T.model <- get.weights(T, features.T, weight.type, T.method)

    if (effect == "ATT") {
        # m0 estimates E[Y(0)|T=1], m1 estimates E[Y(1)|T=1]
        m0 <- get.mean(T, Y, features, features.T, features.Y,
                       over = "treated",
                       T.model = T.model, Y.method = Y.method,
                       combine.method = combine.method)
        # Y(1) is observed on every treated unit, so no model is needed
        m1 <- list(mean.est = mean(Y[T == 1]))
    } else {
        # m0 estimates E[Y(0)], m1 estimates E[Y(1)].
        # Flipping T makes get.mean treat the treated units as the
        # "controls", so it uses the treated weights and fits the outcome
        # model on the treated units.
        m0 <- get.mean(T, Y, features, features.T, features.Y,
                       over = "all",
                       T.model = T.model, Y.method = Y.method,
                       combine.method = combine.method)
        m1 <- get.mean(1 - T, Y, features, features.T, features.Y,
                       over = "all",
                       T.model = T.model, Y.method = Y.method,
                       combine.method = combine.method)
    }
    effect.est <- m1$mean.est - m0$mean.est

    return.list <-
        list(effect.est = effect.est,
             mean.est = c(control = m0$mean.est, treatment = m1$mean.est),
             m0 = m0,
             m1 = m1,
             T.model = T.model,
             T = T,
             Y = Y,
             effect = effect,
             T.method = T.method,
             weight.type = weight.type,
             Y.method = Y.method,
             combine.method = combine.method,
             features.T = features.T,
             features.Y = features.Y)
    class(return.list) <- "obeffect"
    return.list
}

#' Get standard error of a causal effect estimate
#' @description Sandwich (M-estimation) standard error for some of the
#'              estimators produced by \code{get.effect}. Supported:
#'              \itemize{
#'                \item ATE with combine.method = "dr" and
#'                      weight.type = "pop", when the weighting model
#'                      produced propensity scores;
#'                \item ATT with combine.method = "none",
#'                      weight.type = "nr" and T.method = "ebal" or "glm".
#'              }
#'              Every other combination returns NA.
#' @export
#' @param object an object returned by \code{get.effect}
#' @return the estimated standard error of \code{object$effect.est},
#'         or NA when it is not available
get.effect.se <- function(object) {

    if (!inherits(object, "obeffect")) {
        stop("object should be the list returned by get.effect.")
    }

    effect.se <- NA

    if (object$effect == "ATE") {
        ps <- object$T.model$propensity.score
        # e.g. ebal with weight.type = "pop" yields no propensity score
        if (object$combine.method == "dr" && object$weight.type == "pop" &&
                !is.null(ps)) {
            T <- object$T
            Y <- object$Y
            ps <- as.vector(ps)
            # Lay the covariates out exactly as get.mean did when fitting
            # the outcome models, so predict() finds the same columns.
            newdata <- data.frame(features = object$features.Y)
            m1.hat <- predict(object$m1$Y.model, newdata)
            m0.hat <- predict(object$m0$Y.model, newdata)
            # augmented IPW terms (Lunceford & Davidian 2004)
            I1 <- (T * Y - m1.hat * (T - ps)) / ps
            I2 <- ((1 - T) * Y + m0.hat * (T - ps)) / (1 - ps)
            infl <- I1 - I2
            est <- object$effect.est
            # relative check, valid for negative or zero estimates as well
            mismatch <- abs(mean(infl) - est) > 0.01 * abs(est)
            if (is.na(mismatch) || mismatch) {
                message(paste("mean of I is", mean(infl)))
                message(paste("object$effect.est is", est))
                message("something is wrong.")
            } else {
                effect.se <- sqrt(sum((infl - est)^2) / (length(infl)^2))
            }
        }
    } else if (object$combine.method == "none" &&
               object$weight.type == "nr" &&
               object$T.method %in% c("ebal", "glm")) {
        # object$effect == "ATT": weighted control mean estimator
        T <- object$T
        Y <- object$Y
        X <- as.matrix(object$features.T)
        n <- length(T)
        mu.11 <- mean(Y[T == 1])        # E[Y(1)|T=1]
        mu.01 <- object$m0$mean.est     # estimate of E[Y(0)|T=1]
        ps <- as.vector(object$T.model$propensity.score)
        # unnormalized ATT weights: odds on the controls, 0 on the treated
        w0 <- (1 - T) * ps / (1 - ps)
        # The last estimating equation, w0 * (Y - mu.11 + tau), is written
        # in terms of the effect tau = mu.11 - mu.01, so the effect is the
        # last parameter and its variance the last diagonal element of the
        # sandwich. A is minus the derivative of the estimating equations.
        if (object$T.method == "ebal") {
            # parameters: (treated covariate means, balancing coefficients,
            #              mu.11, tau)
            p <- ncol(X)
            X.centered <- sweep(X, 2, colMeans(X[T == 1, , drop = FALSE]))
            A <- rbind(
                cbind(sum(T) * diag(p), matrix(0, p, p + 2)),
                cbind(sum(w0) * diag(p),
                      -crossprod(w0 * X.centered, X),
                      matrix(0, p, 2)),
                cbind(matrix(0, 1, 2 * p), sum(T), 0),
                cbind(matrix(0, 1, p),
                      -crossprod(w0 * (Y - mu.01), X),
                      sum(w0),
                      -sum(w0))
            ) / n
            zeta <- cbind(T * X.centered,
                          w0 * X.centered,
                          T * (Y - mu.11),
                          w0 * (Y - mu.01))
        } else {
            # glm: parameters (logistic coefficients, mu.11, tau); the
            # propensity model glm(label ~ .) includes an intercept
            X <- cbind(1, X)
            p <- ncol(X)
            A <- rbind(
                cbind(crossprod(ps * (1 - ps) * X, X), matrix(0, p, 2)),
                cbind(matrix(0, 1, p), sum(T), 0),
                cbind(-crossprod(w0 * (Y - mu.01), X), sum(w0), -sum(w0))
            ) / n
            zeta <- cbind((T - ps) * X,
                          T * (Y - mu.11),
                          w0 * (Y - mu.01))
        }
        B <- crossprod(zeta) / n
        A.inv <- solve(A)
        V <- A.inv %*% B %*% t(A.inv)
        effect.se <- sqrt(V[nrow(V), ncol(V)] / n)
    }

    effect.se
}


#' Estimate means on an "obdata" object
#' @description For each stratum and each method in \code{method.list},
#'              estimate E[Y(0)] or E[Y(0)|T=1] with \code{get.mean}.
#' @export
#' @param data a "obdata" object: a list with fields T, Y, features (one row
#'             per unit) and, when \code{stratified = TRUE}, a factor
#'             \code{stratum}
#' @param stratified if TRUE, stratify the data by data$stratum
#' @param print.level print level (0, 1 or 2)
#' @param method.list a vector or list of methods, paste together
#'                    T.method, weight.type, Y.method, combine.method
#'                    with sep = ","
#' @param ... parameters to be passed in \code{get.mean} (e.g. \code{over})
#' @return a list with \describe{
#'   \item{mean.est}{matrix of estimates, one row per stratum and one column
#'                   per method; NA where the fit failed}
#'   \item{output}{the \code{get.mean} results indexed by stratum and then
#'                 method (by method only when not stratified)}
#'   \item{method.list}{the methods used}
#' }
get.mean.obdata <- function(data,
                            stratified = FALSE,
                            print.level = 0,
                            method.list = c("ebal,nr,none,none"),
                            ...) {
    if (stratified) {
        if (!("stratum" %in% names(data))) {
            stop("The data list doesn't have a field named stratum.")
        }
    } else {
        data$stratum <- factor(rep("0", length(data$T)))
    }

    strata <- levels(data$stratum)
    mean.est <- matrix(0, length(strata), length(method.list))
    output <- list()

    for (i in seq_along(strata)) {
        stratum <- strata[i]
        if (print.level >= 1) {
            print(paste("Stratum: ", stratum))
        }

        ind <- which(data$stratum == stratum)
        T <- data$T[ind]
        Y <- data$Y[ind]
        # drop = FALSE keeps a single-column feature matrix a matrix
        features <- data$features[ind, , drop = FALSE]

        for (j in seq_along(method.list)) {
            method <- method.list[[j]]
            if (print.level >= 2) {
                print(paste("Method: ", method))
            }

            # unpack "T.method,weight.type,Y.method,combine.method"
            method.split <- unlist(strsplit(method, ","))

            # A failing method yields NULL (reported, not fatal) so the
            # remaining strata and methods still run.
            fit <- tryCatch(
                get.mean(T,
                         Y,
                         features,
                         T.method = method.split[1],
                         weight.type = method.split[2],
                         Y.method = method.split[3],
                         combine.method = method.split[4],
                         ...),
                error = function(e) {
                    message(paste("An error occured in fitting stratum",
                                  stratum,
                                  "by method",
                                  method))
                    message(conditionMessage(e))
                    # in case the failed fit left an output sink open
                    if (sink.number() > 0) {
                        sink()
                    }
                    NULL
                })

            if (is.null(fit)) {
                mean.est[i, j] <- NA
            } else {
                output[[stratum]][[method]] <- fit
                mean.est[i, j] <- fit$mean.est
            }
        }
    }

    rownames(mean.est) <- strata
    colnames(mean.est) <- method.list

    if (!stratified) {
        output <- output[["0"]]
    }

    list(mean.est = mean.est,
         output = output,
         method.list = method.list)
}



#' Estimate causal effect
#' @description estimate E[Y(1)] - E[Y(0)] or E[Y(1)|T=1] - E[Y(0)|T=1]
#'              with \code{get.effect}, for each stratum and each method in
#'              \code{method.list}
#' @export
#' @param data a "obdata" object: a list with fields T, Y, features (one row
#'             per unit) and, when \code{stratified = TRUE}, a factor
#'             \code{stratum}
#' @param stratified if TRUE, stratify the data by data$stratum
#' @param print.level print level (0, 1 or 2)
#' @param method.list a vector or list of methods, paste together
#'                    T.method, weight.type, Y.method, combine.method
#'                    with sep = ","
#' @param ... further arguments passed to \code{get.effect}
#'            (e.g. \code{effect})
#' @return estimated causal effect: a list with \describe{
#'   \item{effect.est}{matrix of effect estimates, one row per stratum and
#'                     one column per method; NA where the fit failed}
#'   \item{mean0.est}{matching matrix of control mean estimates}
#'   \item{mean1.est}{matching matrix of treatment mean estimates}
#'   \item{output}{the \code{get.effect} results indexed by stratum and then
#'                 method (by method only when not stratified)}
#'   \item{method.list}{the methods used}
#' }
get.effect.obdata <- function(data, stratified = FALSE, print.level = 0,
                              method.list = c("ebal,nr,none,none"),
                              ...) {
    if (stratified) {
        if (!("stratum" %in% names(data))) {
            stop("The data list doesn't have a field named stratum.")
        }
    } else {
        data$stratum <- factor(rep("0", length(data$T)))
    }

    strata <- levels(data$stratum)
    n.strata <- length(strata)
    n.methods <- length(method.list)
    effect.est <- matrix(0, n.strata, n.methods)
    mean0.est <- matrix(0, n.strata, n.methods)
    mean1.est <- matrix(0, n.strata, n.methods)
    output <- list()

    for (i in seq_len(n.strata)) {
        stratum <- strata[i]
        if (print.level >= 1) {
            print(paste("Stratum: ", stratum))
        }

        ind <- which(data$stratum == stratum)
        T <- data$T[ind]
        Y <- data$Y[ind]
        # drop = FALSE keeps a single-column feature matrix a matrix
        features <- data$features[ind, , drop = FALSE]

        for (j in seq_len(n.methods)) {
            method <- method.list[[j]]
            if (print.level >= 2) {
                print(paste("Method: ", method))
            }

            # unpack "T.method,weight.type,Y.method,combine.method"
            method.split <- unlist(strsplit(method, ","))

            # A failing method yields NULL (reported, not fatal) so the
            # remaining strata and methods still run.
            fit <- tryCatch(
                get.effect(T,
                           Y,
                           features,
                           T.method = method.split[1],
                           weight.type = method.split[2],
                           Y.method = method.split[3],
                           combine.method = method.split[4],
                           ...),
                error = function(e) {
                    message(paste("An error occured in fitting stratum",
                                  stratum,
                                  "by method",
                                  method))
                    message(conditionMessage(e))
                    # in case the failed fit left an output sink open
                    if (sink.number() > 0) {
                        sink()
                    }
                    NULL
                })

            if (is.null(fit)) {
                # every estimate of a failed fit is missing, not 0
                effect.est[i, j] <- NA
                mean0.est[i, j] <- NA
                mean1.est[i, j] <- NA
            } else {
                output[[stratum]][[method]] <- fit
                mean0.est[i, j] <- fit$mean.est[1]
                mean1.est[i, j] <- fit$mean.est[2]
                effect.est[i, j] <- fit$effect.est
            }
        }
    }

    rownames(mean0.est) <- strata
    colnames(mean0.est) <- method.list
    rownames(mean1.est) <- strata
    colnames(mean1.est) <- method.list
    rownames(effect.est) <- strata
    colnames(effect.est) <- method.list

    if (!stratified) {
        output <- output[["0"]]
    }

    list(effect.est = effect.est,
         mean0.est = mean0.est,
         mean1.est = mean1.est,
         output = output,
         method.list = method.list)
}
