---
title: "Tidy, Type-Safe 'prediction()' Methods"
output: github_document
---
<img src="man/figures/logo.png" align="right" />

The **prediction** and **margins** packages are a combined effort to port the functionality of Stata's (closed source) [`margins`](https://www.stata.com/help.cgi?margins) command to (open source) R. **prediction** is focused on one function, `prediction()`, which provides type-safe methods for generating predictions from fitted regression models. `prediction()` is an S3 generic that always returns a `"data.frame"` class object, rather than the mix of vectors, lists, etc. returned by the `predict()` methods for various model types. It provides a key piece of underlying infrastructure for the **margins** package. Users interested in generating marginal (partial) effects, like those generated by Stata's `margins, dydx(*)` command, should consider using `margins()` from the sibling project, [**margins**](https://cran.r-project.org/package=margins).
In addition to `prediction()`, this package provides a number of utility functions for generating useful predictions:
- `find_data()`, an S3 generic with methods that find the data frame used to estimate a regression model. This is a wrapper around `get_all_vars()` that attempts to locate data as well as modify it according to `subset` and `na.action` arguments used in the original modelling call.
- `mean_or_mode()` and `median_or_mode()`, which provide a convenient way to compute the data needed for predicted values *at means* (or *at medians*), respecting the differences between factor and numeric variables.
- `seq_range()`, which generates a vector of *n* evenly spaced values spanning the range of a variable.
- `build_datalist()`, which generates a list of data frames from an input data frame and a specified set of replacement `at` values (mimicking the `atlist` option of Stata's `margins` command).
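As a rough illustration of what `mean_or_mode()`, `median_or_mode()` and `seq_range()` compute, the base-R sketch below defines two hypothetical stand-ins, `center_or_mode()` and `seq_range_sketch()`. These are not the package's implementations; in particular, breaking ties for the mode by taking the first most frequent level is an assumption.

```r
# Hypothetical stand-in for mean_or_mode() / median_or_mode():
# numeric vectors are summarised with `center`, factors by their most frequent level
center_or_mode <- function(x, center = mean) {
  if (is.numeric(x)) {
    return(center(x, na.rm = TRUE))
  }
  counts <- table(x)
  # ties are broken by taking the first most frequent level (an assumption)
  mode_level <- names(counts)[which.max(counts)]
  if (is.factor(x)) factor(mode_level, levels = levels(x)) else mode_level
}

# Hypothetical stand-in for seq_range(): n evenly spaced values spanning range(x)
seq_range_sketch <- function(x, n = 2) {
  r <- range(x, na.rm = TRUE)
  seq(r[1], r[2], length.out = n)
}

# One-row "typical case" data: mean for the numeric column, mode for the factor
fit_iris <- lm(Sepal.Length ~ Petal.Length + Species, data = iris)
vars <- iris[c("Petal.Length", "Species")]
typical <- as.data.frame(lapply(vars, center_or_mode))
typical_med <- as.data.frame(lapply(vars, center_or_mode, center = median))
predict(fit_iris, newdata = typical)      # prediction "at means"
predict(fit_iris, newdata = typical_med)  # prediction "at medians"

seq_range_sketch(mtcars$hp, 5)  # 52.00 122.75 193.50 264.25 335.00
```

The one-row `typical` data frame is the kind of input needed for a prediction *at means*: numeric columns collapse to a single summary value while the factor keeps its full set of levels, so `predict()` can still build the model matrix.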
## Simple code examples
A major downside of the `predict()` methods for common modelling classes is that the result is not type-safe: the class of the returned object depends on the arguments supplied. Consider the following simple example:
```r
library("stats")
library("datasets")
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(predict(x))
```
```
## [1] "numeric"
```
```r
class(predict(x, se.fit = TRUE))
```
```
## [1] "list"
```
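To see what a type-safe wrapper has to smooth over, here is a minimal base-R sketch. `tidy_predict()` is a hypothetical helper, not the implementation of `prediction()`; it assumes a model whose `predict()` method accepts `newdata` and `se.fit` (as `predict.lm()` does) and simply unpacks the returned list into columns appended to the data.

```r
# Hypothetical wrapper: always returns a data.frame, whatever predict() returns
tidy_predict <- function(model, data) {
  # predict.lm(se.fit = TRUE) returns a list with $fit and $se.fit
  p <- predict(model, newdata = data, se.fit = TRUE)
  out <- data
  out$fitted <- unname(p$fit)
  out$se.fitted <- unname(p$se.fit)
  rownames(out) <- NULL
  out
}

fit_lm <- lm(mpg ~ cyl * hp + wt, data = mtcars)
tidy <- tidy_predict(fit_lm, mtcars)
class(tidy)  # "data.frame"
head(tidy[c("mpg", "fitted", "se.fitted")])
```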
**prediction** solves this issue by providing a wrapper around `predict()`, called `prediction()`, that always returns a tidy data frame with a very simple `print()` method:
```r
library("prediction")
(p <- prediction(x))
```
```
## Data frame with 32 predictions from
## lm(formula = mpg ~ cyl * hp + wt, data = mtcars)
## with average prediction: 20.0906
```
```r
class(p)
```
```
## [1] "prediction" "data.frame"
```
```r
head(p)
```
```
##    mpg cyl disp  hp drat    wt  qsec vs am gear carb   fitted se.fitted
## 1 21.0   6  160 110 3.90 2.620 16.46  0  1    4    4 21.90488 0.6927034
## 2 21.0   6  160 110 3.90 2.875 17.02  0  1    4    4 21.10933 0.6266557
## 3 22.8   4  108  93 3.85 2.320 18.61  1  1    4    1 25.64753 0.6652076
## 4 21.4   6  258 110 3.08 3.215 19.44  1  0    3    1 20.04859 0.6041400
## 5 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2 17.25445 0.7436172
## 6 18.1   6  225 105 2.76 3.460 20.22  1  0    3    1 19.53360 0.6436862
```
The output always contains the original data (i.e., either data found using the `find_data()` function or passed to the `data` argument to `prediction()`). This makes it much simpler to pass predictions to, e.g., further summary or plotting functions.
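The data-recovery step attributed to `find_data()` can be approximated as follows. `find_model_data()` is a hypothetical helper, not the package's code; it assumes the model stores its original `call` and an `na.action` component (as `lm` and `glm` fits do) and that the object passed as `data` can still be found in the formula's environment.

```r
# Hypothetical helper approximating find_data() for lm-like models
find_model_data <- function(model, env = environment(formula(model))) {
  # re-evaluate the `data` argument of the original modelling call
  dat <- eval(model$call$data, envir = env)
  # re-apply any `subset` expression, evaluated inside the data
  if (!is.null(model$call$subset)) {
    keep <- eval(model$call$subset, envir = dat, enclos = env)
    if (is.logical(keep)) keep <- keep & !is.na(keep)
    dat <- dat[keep, , drop = FALSE]
  }
  # drop the rows removed by na.action, matched by row name
  if (!is.null(model$na.action)) {
    dropped <- names(model$na.action)
    dat <- dat[!(rownames(dat) %in% dropped), , drop = FALSE]
  }
  dat
}

d <- data.frame(
  y = c(1, 2, NA, 4, 5, 7),
  x = 1:6,
  g = c("a", "a", "b", "b", "a", "b")
)
m <- lm(y ~ x, data = d, subset = x > 1)
found <- find_model_data(m)
nrow(found) == nobs(m)  # TRUE: row 1 fails the subset, row 3 is dropped for its NA
```

Matching dropped rows by row name rather than by position sidesteps the question of whether the stored `na.action` indices refer to the full or the subsetted data. Note that all columns of `d`, including the unused `g`, are kept.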
Additionally, most methods accept an `at` argument, which can be used to obtain predicted values from modified versions of `data` in which specified variables are held at specific values:
```r
prediction(x, at = list(hp = seq_range(mtcars$hp, 5)))
```
```
## Data frame with 160 predictions from
## lm(formula = mpg ~ cyl * hp + wt, data = mtcars)
## with average predictions:
##     hp      x
##   52.0 22.605
##  122.8 19.328
##  193.5 16.051
##  264.2 12.774
##  335.0  9.497
```
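The `at` mechanism can be reproduced by hand to show where the 160 predictions and the five averages come from: each `hp` value is substituted into a copy of `mtcars`, every row of every copy is predicted, and the predictions are averaged within each copy. `build_datalist_sketch()` is a hypothetical stand-in for `build_datalist()`; using `expand.grid()` so that several `at` variables yield every combination of their values is an assumption of this sketch.

```r
# Hypothetical stand-in for build_datalist(): one modified copy of `data`
# per combination of the values supplied in `at`
build_datalist_sketch <- function(data, at) {
  grid <- expand.grid(at, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  lapply(seq_len(nrow(grid)), function(i) {
    d <- data
    for (v in names(grid)) d[[v]] <- grid[[v]][i]  # scalar recycled to every row
    d
  })
}

fit_at <- lm(mpg ~ cyl * hp + wt, data = mtcars)
hp_values <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 5)
datalist <- build_datalist_sketch(mtcars, list(hp = hp_values))

# predict every row of every copy, then average within each copy
preds <- lapply(datalist, function(d) predict(fit_at, newdata = d))
sum(lengths(preds))  # 160 = 5 copies x 32 rows
avg <- vapply(preds, mean, numeric(1))
data.frame(hp = hp_values, x = avg)
```

Because this model is linear in `hp` once the other covariates are fixed, the five averages lie on a straight line, matching the evenly stepping averages printed above.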
Together with the `at` argument, `prediction()` serves, more or less, as a direct R port of the subset of Stata's `margins` functionality that calculates predictive marginal means and similar quantities. For calculation of marginal or partial effects, see the [**margins**](https://cran.r-project.org/package=margins) package.
## Supported model classes
The currently supported model classes are:
- "lm" from `stats::lm()`
- "glm" from `stats::glm()`, `MASS::glm.nb()`, `glmx::glmx()`, `glmx::hetglm()`, `brglm::brglm()`
- "ar" from `stats::ar()`
- "Arima" from `stats::arima()`
- "arima0" from `stats::arima0()`
- "biglm" from `biglm::biglm()` (including `"ffdf"`-backed models)
- "betareg" from `betareg::betareg()`
- "bruto" from `mda::bruto()`
- "clm" from `ordinal::clm()`
- "coxph" from `survival::coxph()`
- "crch" from `crch::crch()`
- "earth" from `earth::earth()`
- "fda" from `mda::fda()`
- "Gam" from `gam::gam()`
- "gausspr" from `kernlab::gausspr()`
- "gee" from `gee::gee()`
- "glimML" from `aod::betabin()`, `aod::negbin()`
- "glimQL" from `aod::quasibin()`, `aod::quasipois()`
- "glmnet" from `glmnet::glmnet()`
- "gls" from `nlme::gls()`
- "hurdle" from `pscl::hurdle()`
- "hxlr" from `crch::hxlr()`
- "ivreg" from `AER::ivreg()`
- "knnreg" from `caret::knnreg()`
- "kqr" from `kernlab::kqr()`
- "ksvm" from `kernlab::ksvm()`
- "lda" from `MASS::lda()`
- "lme" from `nlme::lme()`
- "loess" from `stats::loess()`
- "lqs" from `MASS::lqs()`
- "mars" from `mda::mars()`
- "mca" from `MASS::mca()`
- "mclogit" from `mclogit::mclogit()`
- "mda" from `mda::mda()`
- "merMod" from `lme4::lmer()` and `lme4::glmer()`
- "mnlogit" from `mnlogit::mnlogit()`
- "mnp" from `MNP::mnp()`
- "naiveBayes" from `e1071::naiveBayes()`
- "nlme" from `nlme::nlme()`
- "nls" from `stats::nls()`
- "nnet" from `nnet::nnet()`, `nnet::multinom()`
- "plm" from `plm::plm()`
- "polr" from `MASS::polr()`
- "ppr" from `stats::ppr()`
- "princomp" from `stats::princomp()`
- "qda" from `MASS::qda()`
- "rlm" from `MASS::rlm()`
- "rpart" from `rpart::rpart()`
- "rq" from `quantreg::rq()`
- "selection" from `sampleSelection::selection()`
- "speedglm" from `speedglm::speedglm()`
- "speedlm" from `speedglm::speedlm()`
- "survreg" from `survival::survreg()`
- "svm" from `e1071::svm()`
- "svyglm" from `survey::svyglm()`
- "tobit" from `AER::tobit()`
- "train" from `caret::train()`
- "truncreg" from `truncreg::truncreg()`
- "zeroinfl" from `pscl::zeroinfl()`
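Several entries list more than one fitting function because S3 dispatch uses an object's whole class vector. For example, `MASS::glm.nb()` returns an object of class `c("negbin", "glm", "lm")`, so a generic that has a `glm` method but no `negbin` method dispatches to the `glm` method. The toy generic `describe_fit()` below is hypothetical and only demonstrates that lookup; it does not inspect the methods actually registered by **prediction**.

```r
library("MASS")  # recommended package shipped with R; provides glm.nb() and quine

fit_nb <- glm.nb(Days ~ Sex/(Age + Eth*Lrn), data = quine)
class(fit_nb)  # "negbin" "glm" "lm"

# Hypothetical toy generic with methods only for "glm", "lm" and a default
describe_fit <- function(model, ...) UseMethod("describe_fit")
describe_fit.glm <- function(model, ...) "glm method"
describe_fit.lm <- function(model, ...) "lm method"
describe_fit.default <- function(model, ...) "no method"

describe_fit(fit_nb)                        # "glm method": no negbin method exists
describe_fit(glm(mpg ~ wt, data = mtcars))  # "glm method"
describe_fit(lm(mpg ~ wt, data = mtcars))   # "lm method"
describe_fit(1:3)                           # "no method"
```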
## Requirements and Installation
The development version of this package can be installed directly from GitHub using `remotes`:
```r
# install remotes only if it is not already available
if (!requireNamespace("remotes", quietly = TRUE)) {
  install.packages("remotes")
}
remotes::install_github("leeper/prediction")
```