File: gbm.backcompat.R

package info (click to toggle)
r-cran-plotmo 3.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,400 kB
  • sloc: sh: 13; makefile: 2
file content (152 lines) | stat: -rw-r--r-- 4,941 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# gbm.backcompat.R:
#
# TODO: rename this module?  Despite its name, it holds functions for the
# new gbm models, not back-compatibility functions.
#
# The following functions were added in Oct 2016 to support
# Paul Metcalfe's changes to gbm (version 2.2 and higher).
#
# The idea is that we work with both the old and the new gbm models, and
# give error messages appropriate to the actual object (not to an object
# converted by to_old_gbm).

plotmo.prolog.GBMFit <- function(object, ...)
{
    # plotmo prolog for GBMFit models (gbm 2.2 and higher, built by gbmt).
    # The training data must have been kept in the model object, because
    # plotmo.x.GBMFit and plotmo.y.GBMFit read it from object$gbm_data_obj.
    data.was.kept <- !is.null(object$gbm_data_obj)
    if(!data.was.kept)
        stop0("use keep_gbm_data=TRUE in the call to gbmt ",
              "(object$gbm_data_obj is NULL)")

    # Variable indices (column numbers in x), most important first, with
    # negligible variables removed (order.GBMFit.vars.on.importance drops
    # those with normalized relative influence below .01).  The result is
    # cached as an attribute so the expensive summary() call runs only once.
    var.order <- order.GBMFit.vars.on.importance(object)
    attr(object, "plotmo.importance") <- var.order

    object
}
order.GBMFit.vars.on.importance <- function(object)
{
    # Return the indices of the variables (column numbers in x), most
    # important first, omitting variables whose normalized relative
    # influence is below .01.
    #
    # order=FALSE keeps rel_inf in the original variable order, so element i
    # of rel_inf refers to variable i.  For a GBMFit this dispatches to
    # summary.GBMFit.
    gbm.summary <- summary(object, plot_it=FALSE,
                           order=FALSE, normalize=TRUE)
    importance <- gbm.summary$rel_inf
    stopifnot(!is.null(importance))

    # Mark negligible variables as NA; na.last=NA below makes order() drop them.
    is.negligible <- importance < .01
    importance[is.negligible] <- NA

    # order() with na.last=NA never returns NA, so no further filtering is needed
    order(importance, decreasing=TRUE, na.last=NA)
}
plotmo.singles.GBMFit <- function(object, x, nresponse, trace, all1, ...)
{
    # plotmo "singles" method for GBMFit models (gbm 2.2 and higher).
    # Delegates to plotmo.singles.gbm, forwarding every argument unchanged,
    # so GBMFit and old-style gbm models share the same logic.
    plotmo.singles.gbm(object, x, nresponse, trace, all1, ...)
}
plotmo.pairs.GBMFit <- function(object, ...)
{
    # plotmo "pairs" method for GBMFit models (gbm 2.2 and higher).
    # Delegates to plotmo.pairs.gbm, forwarding all arguments unchanged.
    plotmo.pairs.gbm(object, ...)
}
plotmo.x.GBMFit <- function(object, ...)
{
    # Training x for a GBMFit model.  The data is taken from the copy kept in
    # the model object (which requires keep_gbm_data=TRUE in gbmt), together
    # with its row order and the variable levels recorded in the model.
    data.obj   <- object$gbm_data_obj
    var.levels <- object$variables$var_levels
    plotmo_x_gbm_aux(data.obj$x, data.obj$x_order, var.levels)
}
plotmo.y.GBMFit <- function(object, ...)
{
    # Training response for a GBMFit model.  It is taken from the data kept in
    # the model object (keep_gbm_data=TRUE in gbmt), with the same row order
    # used for x.
    data.obj <- object$gbm_data_obj
    plotmo_y_gbm_aux(data.obj$y, data.obj$x_order)
}
plotmo.predict.GBMFit <- function(object, newdata, type, ..., TRACE)
{
    # plotmo predict method for GBMFit models: delegates to plotmo.predict.gbm.
    # TRACE comes after ... in this signature, so it is forwarded by name
    # (presumably it also follows ... in plotmo.predict.gbm -- confirm there).
    plotmo.predict.gbm(object, newdata, type, ..., TRACE=TRACE)
}
gbm.short.distribution.name <- function(obj)
{
    # First two letters of obj$distribution$name, in lower case,
    # e.g. "Bernoulli" -> "be".
    dist.name <- tolower(obj$distribution$name)
    substr(dist.name, start=1, stop=2)
}
gbm.n.trees <- function(obj)
{
    # Number of trees per column of obj$fit: the total number of stored trees
    # divided by the number of columns of fit.  If the object also records
    # n.trees explicitly, the two values must agree.
    ncol.fit <- NCOL(obj[["fit"]])
    stopifnot(ncol.fit >= 1) # paranoia
    n.trees <- length(obj$trees) / ncol.fit
    if(is.null(obj$n.trees))
        return(n.trees)
    stopifnot(obj$n.trees == n.trees) # paranoia
    n.trees
}
gbm.train.fraction <- function(obj)
{
    # Fraction of the observations used for training, checked to be in [0,1].
    # If obj$train.fraction exists (old-style naming), it is used directly.
    train.fraction <- obj$train.fraction
    if(is.null(train.fraction)) {
        # Otherwise (e.g. a model built by gbmt), derive it from the kept data.
        # TODO obj$params$train_fraction returns the wrong results, so work
        # around it by dividing the number of training rows by the number
        # of rows in the original data.
        if(is.null(obj$gbm_data_obj))
            stop0("use keep_gbm_data=TRUE in the call to gbmt ",
                  "(obj$gbm_data_obj is NULL)")
        stopifnot(!is.null(obj$gbm_data_obj$original_data))
        train.fraction <- obj$params$num_train /
                          NROW(obj$gbm_data_obj$original_data)
    }
    check.numeric.scalar(train.fraction, min=0, max=1)
    train.fraction
}
gbm.bag.fraction <- function(obj)
{
    # Bagging fraction, checked to be in [0,1].  Prefer the old-style
    # obj$bag.fraction; fall back to obj$params$bag_fraction.
    bag.fraction <- obj$bag.fraction
    if(is.null(bag.fraction))
        bag.fraction <- obj$params$bag_fraction
    check.numeric.scalar(bag.fraction, min=0, max=1)
    bag.fraction
}
gbm.cv.folds <- function(obj)
{
    # Number of cross-validation folds.  Prefer the old-style obj$cv.folds and
    # fall back to obj$cv_folds.  null.ok=TRUE presumably lets
    # check.numeric.scalar accept NULL, in which case NULL is returned.
    cv.folds <- obj$cv.folds
    if(is.null(cv.folds))
        cv.folds <- obj$cv_folds
    check.numeric.scalar(cv.folds, min=1, null.ok=TRUE)
    cv.folds
}
gbm.train.error <- function(obj)
{
    # Return obj$train.error after sanity checks.  Unlike gbm.valid.error,
    # gbm.oobag.improve and gbm.cv.error, the training error is mandatory:
    # it must be non-NULL, numeric, and have exactly gbm.n.trees(obj) entries.
    train.error <- obj$train.error
    stopifnot(!is.null(train.error))
    stopifnot(is.numeric(train.error))
    stopifnot(length(train.error) == gbm.n.trees(obj))
    train.error
}
gbm.valid.error <- function(obj)
{
    # Validation error, or NULL if the model has none.  When present it must
    # be numeric with exactly gbm.n.trees(obj) entries.
    valid.error <- obj$valid.error
    if(is.null(valid.error))
        return(NULL)
    stopifnot(is.numeric(valid.error))
    stopifnot(length(valid.error) == gbm.n.trees(obj))
    valid.error
}
gbm.oobag.improve <- function(obj)
{
    # Out-of-bag improvement, or NULL if the model has none.  When present it
    # must be numeric with exactly gbm.n.trees(obj) entries.
    oobag.improve <- obj$oobag.improve
    if(is.null(oobag.improve))
        return(NULL)
    stopifnot(is.numeric(oobag.improve))
    stopifnot(length(oobag.improve) == gbm.n.trees(obj))
    oobag.improve
}
gbm.cv.error <- function(obj)
{
    # Cross-validation error, or NULL if the model has none.  Prefer the
    # old-style obj$cv.error and fall back to obj$cv_error.  When present it
    # must be numeric with exactly gbm.n.trees(obj) entries.
    cv.error <- obj$cv.error
    if(is.null(cv.error))
        cv.error <- obj$cv_error
    if(is.null(cv.error))
        return(NULL)
    stopifnot(is.numeric(cv.error))
    stopifnot(length(cv.error) == gbm.n.trees(obj))
    cv.error
}