## Boosted balancing weights.
##
## Builds a per-observation score f by stagewise addition of regression
## trees.  At each stage a tree g is fitted to the signed weights
## w * (2 * T - 1), a step size eta is chosen as a root of the weighted
## imbalance sum((2 * T - 1) * w * g) along f + eta * g, and f is moved by
## eta * nu * g (optionally followed by a constant shift).
##
## Arguments:
##   T            vector of 0/1 group indicators (the code tests T == 1 and
##                T == 0); its length defines n.
##   X            covariate matrix with n rows, used as the right-hand side
##                of the rpart formula.
##   alpha, beta  exponents of the weight function: rows with T == 1 get
##                p^alpha * (1 - p)^(beta + 1), rows with T == 0 get
##                p^(alpha + 1) * (1 - p)^beta, where p = 1 / (1 + exp(-f)).
##   rpart.param  list handed to rpart() as `control`.
##   nu           shrinkage factor applied to every tree step.
##   subsample    fraction of rows drawn (without replacement) to fit each
##                tree and to search for the step size.
##   max.stage    maximum number of boosting stages.
##   intercept    if TRUE, a constant shift of f is re-solved after every
##                tree step; the flag is also passed as `normalize` to the
##                step-size search.
##   prune.tree   if TRUE, each tree is pruned using its cross-validation
##                table before being used.
##
## Value: a list with
##   trees         fitted (possibly pruned) trees, one per attempted stage;
##   f, w          final scores and the corresponding weights;
##   eta, f0       step sizes and intercept shifts per attempted stage;
##   f.history     (stages + 1) x n matrix of scores, row 1 = initial fit;
##   w.history     (stages + 1) x n matrix of weights;
##   imba.history  imbalance along each tree direction, one per stage.
## When the loop stops early, `trees`, `eta` and `f0` still contain the
## entry of the abandoned stage, whereas the history objects only cover
## completed stages.
balance.boost <- function(T, X,
                          alpha = -1,
                          beta = -1,
                          rpart.param = list(maxdepth = 3),
                          nu = 0.1,
                          subsample = 1,
                          max.stage = 100,
                          intercept = TRUE,
                          prune.tree = TRUE) {

    n <- length(T)

    if (nrow(X) != n) {
        stop("The dimensions of T and X don't match.")
    }

    ## library() fails immediately when rpart is not installed; require()
    ## would only return FALSE and let the first rpart() call fail later.
    library(rpart)

    ## Logistic link mapping a score to a probability.
    link <- function(f) {
        1 / (1 + exp(-f))
    }

    ## Average score of f for the chosen (alpha, beta).  Only referenced by
    ## the disabled diagnostic checks further down.  Note that the sum is
    ## divided by the full sample size n, not by length(f).
    S.func <- function(f, T) {
        prob <- link(f)
        S <- rep(0, length(f))
        if (alpha == -1 && beta == -1) {
            S[T == 1] <- f[T == 1] - 1 / prob[T == 1]
            S[T == 0] <- - f[T == 0] - 1 / (1 - prob[T == 0])
        } else if (alpha == -1 && beta == 0) {
            S[T == 1] <- - 1 / prob[T == 1]
            S[T == 0] <- - f[T == 0]
        } else if (alpha == 0 && beta == -1) {
            S[T == 1] <- f[T == 1]
            S[T == 0] <- - 1 / (1 - prob[T == 0])
        } else if (alpha == 0 && beta == 0) {
            S[T == 1] <- log(prob[T == 1])
            S[T == 0] <- log(1 - prob[T == 0])
        } else {
            ## No closed form: integrate numerically from p = 1/2.
            for (i in seq_along(prob)) {
                S[i] <- integrate(function(p) (T[i] - p) * p^(alpha - 1) * (1 - p)^(beta - 1), lower = 1/2, upper = prob[i])$value
            }
        }
        return.value <- sum(S) / n
        if (is.nan(return.value) || return.value == Inf) {
            stop("Can't compute the score.")
        }
        return.value
    }

    ## Weights implied by the scores f.  With normalize = TRUE the weights
    ## of each group are rescaled so that both groups sum to sum(T).
    get.weights <- function(f, T, normalize = FALSE) {
        prob <- link(f)
        w <- rep(0, length(prob))
        w[T == 1] <- prob[T == 1]^alpha * (1 - prob[T == 1])^(beta+1)
        w[T == 0] <- prob[T == 0]^(alpha+1) * (1 - prob[T == 0])^beta
        if (normalize) {
            w[T == 1] <- w[T == 1] / sum(w[T == 1]) * sum(T)
            w[T == 0] <- w[T == 0] / sum(w[T == 0]) * sum(T)
        }
        w
    }

    ## Weighted difference of g between the T == 1 and T == 0 rows.
    imbalance <- function(f, g, T, normalize = FALSE) {
        w <- get.weights(f, T, normalize)
        sum((2 * T - 1) * w * g)
    }

    ## Step size eta such that imbalance(f + eta * g, g) is zero.  The
    ## search starts on [-1, 1] and lets uniroot() widen the interval;
    ## NA is returned when no root can be found.
    update.f <- function(f, g, T, normalize = FALSE) {
        lower <- -1
        upper <- 1
        tryCatch(uniroot(function(eta) imbalance(f + eta * g, g, T, normalize = normalize), c(lower, upper), extendInt = "yes")$root, error = function(e) NA)
    }

    ## Initial fit: a constant score that balances the constant function.
    f <- rep(0, n)
    f <- rep(update.f(f, rep(1, n), T), n)

    f.history <- matrix(0, max.stage + 1, n)
    f.history[1, ] <- f
    w.history <- matrix(0, max.stage + 1, n)
    imba.history <- rep(0, max.stage)
    trees <- list()
    eta <- c()
    f0 <- c()

    ## j counts completed stages after the loop; it must exist even when
    ## max.stage is 0 and the loop body never runs.
    j <- 0
    for (j in seq_len(max.stage)) {

        if (j <= 10 || j %in% seq(20, 100, 10) || j %% 100 == 0) {
            print(paste("Step:", j))
        }
        ## print(S.func(f))

        w <- get.weights(f, T)
        w.history[j, ] <- w

        flag <- FALSE
        ## With subsample < 1 this loop redraws the subsample until a step
        ## size is found; with subsample == 1 it runs at most once.
        while (TRUE) {
            I <- sample.int(n, size = n * subsample)
            tree <- rpart(w * (2 * T - 1) ~ X, subset = I, control = rpart.param, method = "anova")
            if (prune.tree) {
                cptable <- tree$cptable
                ## rpart only reports "xerror"/"xstd" when it found at least
                ## one split and cross-validation is enabled; without them
                ## there is nothing to prune on.
                if (all(c("xerror", "xstd") %in% colnames(cptable))) {
                    candidates <- which(cptable[, "rel error"] + 2 * cptable[, "xstd"] > cptable[, "xerror"])
                    ## Largest tree passing the rule; fall back to the first
                    ## row of the table when no row passes.
                    prune.split <- if (length(candidates) > 0) max(candidates) else 1L
                    tree <- prune(tree, cp = cptable[prune.split, "CP"])
                }
            }
            trees[[j]] <- tree
            g <- predict(trees[[j]], data.frame(X))
            eta[j] <- update.f(f[I], g[I], T[I], intercept)
            if (nrow(trees[[j]]$frame) == 1) {
                print("Early stop: no split possible (either balance is satisfactory or maxdepth = 1)")
                flag <- TRUE
                break
            }
            if (subsample == 1 && is.na(eta[j])) {
                if (j == 1) {
                    ## No earlier step size exists to guess from, and refitting
                    ## on the full sample would fail in the same way forever.
                    print("Cannot update gradient at the first step. Possible reason: insufficient overlap in the dataset. Stopping.")
                    flag <- TRUE
                    break
                }
                ## rpart.param$maxdepth <- rpart.param$maxdepth - 1
                ## print("Cannot update gradient. Possible reason: insufficient overlap in the dataset. Decreasing tree depth...")
                eta[j] <- median(eta[seq_len(j - 1)]) / 2
                print("Cannot update gradient. Possible reason: insufficient overlap in the dataset. Using the best guess...")
            }
            if (!is.na(eta[j])) {
                f.new <- f + eta[j] * nu * g
                break
            }
        }
        imba.history[j] <- abs(sum((2 * T - 1) * w * g)) / sqrt(sum(g^2))
        if (flag) { # if we cannot make split or move
            j <- j-1
            break
        }

        ## if (tryCatch(S.func(f.new[I], T[I]) < S.func(f[I], T[I]), error = function(cond) { message(cond); return(TRUE) })) { # if the gradient step does not increase the score
        ##     print("break at gradient step")
        ##     j <- j - 1
        ##     break
        ## }

        if (intercept) {
            f0[j] <- update.f(f.new, rep(1, n), T)
            if (is.na(f0[j])) { # if we couldn't find an update for intercept
                print("break at intercept step 1")
                j <- j-1
                break
            }
            f.new <- f.new + f0[j]

            ## if (tryCatch(S.func(f.new, T) < S.func(f.new - f0[j], T), error = function(cond) { message(cond); return(TRUE) })) {
            ##     print("break at intercept step 2")
            ##     j <- j - 1
            ##     break
            ## }
        }

        f <- f.new
        f.history[j + 1, ] <- f
    }

    w <- get.weights(f, T)
    w.history[j + 1, ] <- w
    ## seq_len() and drop = FALSE keep the shapes right when j is 0.
    w.history <- w.history[seq_len(j + 1), , drop = FALSE]
    f.history <- f.history[seq_len(j + 1), , drop = FALSE]
    imba.history <- imba.history[seq_len(j)]

    print(c("Steps:", j))

    list(trees = trees,
         f = f,
         w = w,
         eta = eta,
         f0 = f0,
         f.history = f.history,
         w.history = w.history,
         imba.history = imba.history)

}
