File: plot.train.R

package info (click to toggle)
r-cran-caret 7.0-1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 4,036 kB
  • sloc: ansic: 210; sh: 10; makefile: 2
file content (265 lines) | stat: -rw-r--r-- 11,654 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
#' Plot Method for the train Class
#'
#' This function takes the output of a \code{\link{train}} object and creates a
#' line or level plot using the \pkg{lattice} or \pkg{ggplot2} libraries.
#'
#' If there are no tuning parameters, or none were varied, an error is
#' produced.
#'
#' If the model has one tuning parameter with multiple candidate values, a plot
#' is produced showing the profile of the results over the parameter. Also, a
#' plot can be produced if there are multiple tuning parameters but only one is
#' varied.
#'
#' If there are two tuning parameters with different values, a plot can be
#' produced where a different line is shown for each value of the other
#' parameter. For three parameters, the same line plot is created within
#' conditioning panels/facets of the other parameter.
#'
#' Also, with two tuning parameters (with different values), a levelplot (i.e.
#' un-clustered heatmap) can be created. For more than two parameters, this
#' plot is created inside conditioning panels/facets.
#'
#' @aliases plot.train ggplot.train
#' @param x an object of class \code{\link{train}}.
#' @param metric What measure of performance to plot. Examples of possible
#' values are "RMSE", "Rsquared", "Accuracy" or "Kappa". Other values can be
#' used depending on what metrics have been calculated.
#' @param plotType a string describing the type of plot (\code{"scatter"},
#' \code{"level"} or \code{"line"} (\code{plot} only))
#' @param digits an integer specifying the number of significant digits used to
#' label the parameter value.
#' @param xTrans a function that will be used to scale the x-axis in scatter
#' plots.
#' @param data an object of class \code{\link{train}}.
#' @param output either "data", "ggplot" or "layered". The first returns a data
#' frame while the second returns a simple \code{ggplot} object with no layers.
#' The third value returns a plot with a set of layers.
#' @param nameInStrip a logical: if there are more than 2 tuning parameters,
#' should the name and value be included in the panel title?
#' @param highlight a logical: if \code{TRUE}, a diamond is placed around the
#' optimal parameter setting for models using grid search.
#' @param mapping,environment unused arguments to make consistent with
#' \pkg{ggplot2} generic method
#' @param \dots \code{plot} only: specifications to be passed to
#' \code{\link[lattice]{levelplot}}, \code{\link[lattice]{xyplot}},
#' \code{\link[lattice:xyplot]{stripplot}} (for line plots). The function
#' automatically sets some arguments (e.g. axis labels) but passing in values
#' here will over-ride the defaults
#' @author Max Kuhn
#' @seealso \code{\link{train}}, \code{\link[lattice]{levelplot}},
#' \code{\link[lattice]{xyplot}}, \code{\link[lattice:xyplot]{stripplot}},
#' \code{\link[ggplot2]{ggplot}}
#' @references Kuhn (2008), ``Building Predictive Models in R Using the caret''
#' (\doi{10.18637/jss.v028.i05})
#' @keywords hplot
#' @method plot train
#' @export
#' @examples
#'
#'
#' \dontrun{
#' library(klaR)
#' rdaFit <- train(Species ~ .,
#'                 data = iris,
#'                 method = "rda",
#'                 control = trainControl(method = "cv"))
#' plot(rdaFit)
#' plot(rdaFit, plotType = "level")
#'
#' ggplot(rdaFit) + theme_bw()
#'
#' }
#'
#' @export plot.train
"plot.train" <-  function(x,
                    plotType = "scatter",
                    metric = x$metric[1],
                    digits = getOption("digits") - 3,
                    xTrans = NULL,
                    nameInStrip = FALSE,
                    ...)
  {
    ## Lattice plot of a performance metric against the tuning parameters
    ## that were varied in a `train` object.
    ##
    ## Arguments:
    ##   x           the fitted object; this function reads x$results,
    ##               x$modelInfo$parameters, x$method, x$metric and
    ##               x$control$method.
    ##   plotType    one of "scatter", "line" or "level".
    ##   metric      name of the column of x$results to plot.
    ##   digits      significant digits used when numeric parameter values
    ##               are turned into factor labels.
    ##   xTrans      optional function applied to the x-axis parameter
    ##               (scatter plots only).
    ##   nameInStrip if TRUE and more than two parameters vary, the
    ##               parameter label is prefixed to the values of the
    ##               conditioning variables.
    ##   ...         extra arguments appended to the call to the plotting
    ##               function; ylab, xlab, type, sub, key and auto.key
    ##               given here replace the defaults set below.
    ##
    ## Value: whatever xyplot ("scatter"), stripplot ("line") or
    ## levelplot ("level") returns.
    ##
    ## Errors: invalid plotType; no tuning parameters; no tuning parameter
    ## with more than one value; "level" requested with only one varying
    ## parameter.

    ## Error trap
    if(!(plotType %in% c("level", "scatter", "line"))) stop("plotType must be either level, scatter or line")

    ## define some functions

    ## Convert a column to a factor. Numeric values are optionally prefixed
    ## with the parameter name (dots replaced by spaces) before conversion.
    prettyVal <- function(u, dig, Name = NULL)
      {
        if(is.numeric(u))
          {
            if(!is.null(Name)) u <- paste0(gsub(".", " ", Name, fixed = TRUE),
                                           ": ",
                                           format(u, digits = dig))
            return(factor(u))
          }
        if(is.factor(u)) u else as.factor(u)
      }

    ## Get tuning parameter info
    paramInfo <- x$modelInfo$parameters
    params <- as.character(paramInfo$parameter)

    ## Look up the display label of a single tuning parameter
    paramLabel <- function(nm) as.character(paramInfo$label)[paramInfo$parameter == nm]

    ## Prefix the values of the conditioning variables (every parameter
    ## after the first two) with that parameter's label. The labels are
    ## looked up by name with match() so that they follow the order of
    ## `vars`, which is sorted by number of unique values and need not
    ## agree with the row order of the parameter table.
    labelStripVars <- function(df, vars)
      {
        strip_vars <- vars[-(1:2)]
        strip_lab <- as.character(paramInfo$label)[match(strip_vars, as.character(paramInfo$parameter))]
        for(i in seq_along(strip_vars))
          df[, strip_vars[i]] <- factor(paste(strip_lab[i], df[, strip_vars[i]], sep = ": "))
        df
      }

    if(grepl("adapt", x$control$method))
      warning("When using adaptive resampling, this plot may not accurately capture the relationship between the tuning parameters and model performance.")

    plotIt <- "yes"
    if(all(params == "parameter"))
      {
        plotIt <- "There are no tuning parameters for this model."
      } else {
        dat <- x$results

        ## Some exceptions for pretty printing
        if(x$method == "nb") dat$usekernel <- factor(ifelse(dat$usekernel, "Nonparametric", "Gaussian"))
        if(x$method == "gam") dat$select <- factor(ifelse(dat$select, "Feature Selection", "No Feature Selection"))
        if(x$method == "qrnn") dat$bag <- factor(ifelse(dat$bag, "Bagging", "No Bagging"))
        if(x$method == "C5.0") dat$winnow <- factor(ifelse(dat$winnow, "Winnowing", "No Winnowing"))
##        if(x$method %in% c("M5Rules", "M5", "PART")) dat$pruned <- factor(ifelse(dat$pruned == "Yes", "Pruned", "Unpruned"))
##        if(x$method %in% c("M5Rules", "M5")) dat$smoothed <- factor(ifelse(dat$smoothed == "Yes", "Smoothed", "Unsmoothed"))
        if(x$method == "M5") dat$rules <- factor(ifelse(dat$rules == "Yes", "Rules", "Trees"))
##        if(x$method == "vbmpRadial") dat$estimateTheta <- factor(ifelse(dat$estimateTheta == "Yes", "Estimate Theta", "Do Not Estimate Theta"))

        ## Check to see which tuning parameters were varied
        paramValues <- apply(dat[,params,drop = FALSE],
                             2,
                             function(u) length(unique(u)))
        if(any(paramValues > 1))
          {
            params <- names(paramValues)[paramValues > 1]
          } else plotIt <- "There are no tuning parameters with more than 1 value."
      }

    ## Nothing to plot: report why
    if(plotIt != "yes") stop(plotIt)

    p <- length(params)
    dat <- dat[, c(metric, params)]
    if(p > 1) {
      ## Order the parameters by decreasing number of distinct values
      numUnique <- vapply(dat[, -1], function(u) length(unique(u)), integer(1))
      numUnique <- sort(numUnique, decreasing = TRUE)
      dat <- dat[, c(metric, names(numUnique))]
      params <- names(numUnique)
    }
    ## The convention is that the first parameter (in
    ## position #2 of dat) is plotted on the x-axis,
    ## the second parameter is the grouping variable
    ## and the rest are conditioning variables
    if(!is.null(xTrans) && plotType == "scatter") dat[,2] <- xTrans(dat[,2])

    resampText <- resampName(x, FALSE)

    if(plotType %in% c("line", "scatter"))
      {
        ## We need to pretty-up some of the values of grouping
        ## or conditioning variables. Scatter plots keep the x-axis
        ## variable (column 2) as is; line plots convert it as well.
        firstPretty <- if(plotType == "scatter") 3 else 2
        if(firstPretty <= ncol(dat)) for(i in firstPretty:ncol(dat))
          dat[,i] <- prettyVal(dat[,i], dig = digits, Name = if(i > 3) params[i-1] else NULL)

        for(i in 2:ncol(dat)) if(is.logical(dat[,i])) dat[,i] <- factor(dat[,i])

        if(p > 2 && nameInStrip) dat <- labelStripVars(dat, params)

        ## make formula
        form <- if(p <= 2)
          {
            as.formula(paste0(metric, "~", params[1]))
          } else as.formula(paste0(metric, "~", params[1], "|",
                                   paste(params[-(1:2)], collapse = "*")))

        defaultArgs <- list(x = form,
                            data = dat,
                            groups = if(p > 1) dat[,params[2]] else NULL)
        if(length(list(...)) > 0) defaultArgs <- c(defaultArgs, list(...))
        lNames <- names(defaultArgs)
        if(!("ylab" %in% lNames)) defaultArgs$ylab <- paste(metric, resampText)
        if(!("type" %in% lNames)) defaultArgs$type <- c("g", "o")

        if(p > 1)
          {
            ## One legend column per group level, capped at four
            groupCols <- min(length(unique(dat[,3])), 4)

            if(!(any(c("key", "auto.key") %in% lNames)))
              defaultArgs$auto.key <- list(columns = groupCols,
                                           lines = TRUE,
                                           title = paramLabel(params[2]),
                                           cex.title = 1)
          }
        if(!("xlab" %in% lNames)) defaultArgs$xlab <- paramLabel(params[1])

        out <- do.call(if(plotType == "scatter") "xyplot" else "stripplot",
                       defaultArgs)
      } else {
        ## level plot ########################
        if(p == 1) stop("There must be at least 2 tuning parameters with multiple values")

        for(i in 2:ncol(dat))
          dat[,i] <- prettyVal(dat[,i], dig = digits, Name = if(i > 3) params[i-1] else NULL)

        if(p > 2 && nameInStrip) dat <- labelStripVars(dat, params)

        ## make formula
        form <- if(p <= 2)
          {
            as.formula(paste0(metric, "~", params[1], "*", params[2]))
          } else as.formula(paste0(metric, "~", params[1], "*", params[2], "|",
                                   paste(params[-(1:2)], collapse = "*")))

        defaultArgs <- list(x = form, data = dat)
        if(length(list(...)) > 0) defaultArgs <- c(defaultArgs, list(...))
        lNames <- names(defaultArgs)
        if(!("sub" %in% lNames)) defaultArgs$sub <- paste(metric, resampText)

        if(!("xlab" %in% lNames)) defaultArgs$xlab <- paramLabel(params[1])
        if(!("ylab" %in% lNames)) defaultArgs$ylab <- paramLabel(params[2])

        out <- do.call("levelplot", defaultArgs)
      }

    out
  }