---
title: "Tidy, Type-Safe 'prediction()' Methods"
output: github_document
---

<img src="man/figures/logo.png" align="right" />

The **prediction** and **margins** packages are a combined effort to port the functionality of Stata's (closed source) [`margins`](https://www.stata.com/help.cgi?margins) command to (open source) R. **prediction** is focused on one function - `prediction()` - that provides type-safe methods for generating predictions from fitted regression models. `prediction()` is an S3 generic whose methods always return a `"data.frame"` class object rather than the mix of vectors, lists, etc. that are returned by the `predict()` methods for various model types. It provides a key piece of underlying infrastructure for the **margins** package. Users interested in generating marginal (partial) effects, like those generated by Stata's `margins, dydx(*)` command, should consider using `margins()` from the sibling project, [**margins**](https://cran.r-project.org/package=margins).

In addition to `prediction()`, this package provides a number of utility functions for generating useful predictions:

 - `find_data()`, an S3 generic with methods that find the data frame used to estimate a regression model. This is a wrapper around `get_all_vars()` that attempts to locate data as well as modify it according to `subset` and `na.action` arguments used in the original modelling call.
 - `mean_or_mode()` and `median_or_mode()`, which provide a convenient way to compute the data needed for predicted values *at means* (or *at medians*), respecting the differences between factor and numeric variables.
 - `seq_range()`, which generates a vector of *n* evenly spaced values spanning the range of values in a variable.
 - `build_datalist()`, which generates a list of data frames from an input data frame and a specified set of replacement `at` values (mimicking the `atlist` option of Stata's `margins` command).
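
As a minimal sketch of how these helpers fit together (the `mtcars` data, the model formula, and the particular `at` values are chosen purely for illustration):

```r
library("prediction")
library("datasets")

# Five evenly spaced values spanning the observed range of horsepower
hp_grid <- seq_range(mtcars$hp, 5)

# For a numeric vector, mean_or_mode() gives the mean
mpg_center <- mean_or_mode(mtcars$mpg)

# One copy of the data per `at` value, with `hp` overwritten in each copy
datalist <- build_datalist(mtcars, at = list(hp = c(100, 200)))
length(datalist)
unique(datalist[[2]]$hp)

# Recover the data frame behind a fitted model
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
dim(find_data(x))
```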

## Simple code examples



A major downside of the `predict()` methods for common modelling classes is that the result is not type-safe. Consider the following simple example:


```r
library("stats")
library("datasets")
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(predict(x))
```

```
## [1] "numeric"
```

```r
class(predict(x, se.fit = TRUE))
```

```
## [1] "list"
```
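
To see why this is awkward in practice, consider what calling code has to do to get a consistent shape back. The helper below, `as_prediction_df()`, is hypothetical and not part of any package; it is only a rough sketch of the branching that each caller would otherwise need to write, and it handles only the two return shapes shown above.

```r
library("stats")
library("datasets")

# Hypothetical helper (illustration only): coerce the output of predict()
# to a data frame regardless of which shape it came back in.
as_prediction_df <- function(pred) {
    if (is.list(pred)) {
        # se.fit = TRUE returns a list with elements such as fit and se.fit
        data.frame(fitted = unname(pred$fit), se.fitted = unname(pred$se.fit))
    } else {
        # plain numeric vector: no standard errors available
        data.frame(fitted = unname(pred), se.fitted = NA_real_)
    }
}

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
plain <- as_prediction_df(predict(x))
with_se <- as_prediction_df(predict(x, se.fit = TRUE))
str(plain)
str(with_se)
```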

**prediction** solves this issue by providing a wrapper around `predict()`, called `prediction()`, that always returns a tidy data frame with a very simple `print()` method:


```r
library("prediction")
(p <- prediction(x))
```

```
## Data frame with 32 predictions from
##  lm(formula = mpg ~ cyl * hp + wt, data = mtcars)
## with average prediction: 20.0906
```

```r
class(p)
```

```
## [1] "prediction" "data.frame"
```

```r
head(p)
```

```
##    mpg cyl disp  hp drat    wt  qsec vs am gear carb   fitted se.fitted
## 1 21.0   6  160 110 3.90 2.620 16.46  0  1    4    4 21.90488 0.6927034
## 2 21.0   6  160 110 3.90 2.875 17.02  0  1    4    4 21.10933 0.6266557
## 3 22.8   4  108  93 3.85 2.320 18.61  1  1    4    1 25.64753 0.6652076
## 4 21.4   6  258 110 3.08 3.215 19.44  1  0    3    1 20.04859 0.6041400
## 5 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2 17.25445 0.7436172
## 6 18.1   6  225 105 2.76 3.460 20.22  1  0    3    1 19.53360 0.6436862
```

The output always contains the original data (i.e., either data found using the `find_data()` function or passed to the `data` argument to `prediction()`). This makes it much simpler to pass predictions to, e.g., further summary or plotting functions.
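
Because the fitted values sit alongside the original columns, simple summaries need no extra joins. A small sketch, assuming the `fitted` column name shown in the output above (the grouping variable and the residual calculation are illustrative choices):

```r
library("prediction")
library("datasets")

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
p <- prediction(x)

# Average fitted value within each cylinder group, using a column
# carried over from the original data
avg_by_cyl <- tapply(p$fitted, p$cyl, mean)
avg_by_cyl

# Residuals computed directly from the returned data frame
resid_manual <- p$mpg - p$fitted
head(resid_manual)
```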

Additionally, the vast majority of methods accept an `at` argument, which can be used to obtain predicted values from modified versions of `data` in which the named variables are held at specific values:


```r
prediction(x, at = list(hp = seq_range(mtcars$hp, 5)))
```

```
## Data frame with 160 predictions from
##  lm(formula = mpg ~ cyl * hp + wt, data = mtcars)
## with average predictions:
```

```
##     hp      x
##   52.0 22.605
##  122.8 19.328
##  193.5 16.051
##  264.2 12.774
##  335.0  9.497
```

This serves, more or less, as a direct R port of the subset of Stata's `margins` functionality that calculates predictive marginal means and similar quantities. For calculation of marginal or partial effects, see the [**margins**](https://cran.r-project.org/package=margins) package.
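
The averaged predictions from an `at` specification can also be reproduced by hand, which may help clarify what is being computed. The sketch below is one way to do so using `build_datalist()` and base `predict()`; it illustrates the idea and is not a description of the package internals.

```r
library("prediction")
library("datasets")

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
hp_values <- seq_range(mtcars$hp, 5)

# prediction() stacks one copy of the data per `at` value:
# 32 rows x 5 values of hp
p_at <- prediction(x, at = list(hp = hp_values))
nrow(p_at)

# Manual equivalent: overwrite hp in each copy, predict, then average
manual <- sapply(
    build_datalist(mtcars, at = list(hp = hp_values)),
    function(d) mean(predict(x, newdata = d))
)

# Average fitted value at each hp value from the prediction() output
averaged <- tapply(p_at$fitted, p_at$hp, mean)
```

The two vectors, `manual` and `averaged`, should agree up to floating-point error.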

## Supported model classes

The currently supported model classes are:

 - "lm" from `stats::lm()`
 - "glm" from `stats::glm()`, `MASS::glm.nb()`, `glmx::glmx()`, `glmx::hetglm()`, `brglm::brglm()`
 - "ar" from `stats::ar()`
 - "Arima" from `stats::arima()`
 - "arima0" from `stats::arima0()`
 - "biglm" from `biglm::biglm()` (including `"ffdf"` backed models)
 - "betareg" from `betareg::betareg()`
 - "bruto" from `mda::bruto()`
 - "clm" from `ordinal::clm()`
 - "coxph" from `survival::coxph()`
 - "crch" from `crch::crch()`
 - "earth" from `earth::earth()`
 - "fda" from `mda::fda()`
 - "Gam" from `gam::gam()`
 - "gausspr" from `kernlab::gausspr()`
 - "gee" from `gee::gee()`
 - "glimML" from `aod::betabin()`, `aod::negbin()`
 - "glimQL" from `aod::quasibin()`, `aod::quasipois()`
 - "glmnet" from `glmnet::glmnet()`
 - "gls" from `nlme::gls()`
 - "hurdle" from `pscl::hurdle()`
 - "hxlr" from `crch::hxlr()`
 - "ivreg" from `AER::ivreg()`
 - "knnreg" from `caret::knnreg()`
 - "kqr" from `kernlab::kqr()`
 - "ksvm" from `kernlab::ksvm()`
 - "lda" from `MASS::lda()`
 - "lme" from `nlme::lme()`
 - "loess" from `stats::loess()`
 - "lqs" from `MASS::lqs()`
 - "mars" from `mda::mars()`
 - "mca" from `MASS::mca()`
 - "mclogit" from `mclogit::mclogit()`
 - "mda" from `mda::mda()`
 - "merMod" from `lme4::lmer()` and `lme4::glmer()`
 - "mnlogit" from `mnlogit::mnlogit()`
 - "mnp" from `MNP::mnp()`
 - "naiveBayes" from `e1071::naiveBayes()`
 - "nlme" from `nlme::nlme()`
 - "nls" from `stats::nls()`
 - "nnet" from `nnet::nnet()`, `nnet::multinom()`
 - "plm" from `plm::plm()`
 - "polr" from `MASS::polr()`
 - "ppr" from `stats::ppr()`
 - "princomp" from `stats::princomp()`
 - "qda" from `MASS::qda()`
 - "rlm" from `MASS::rlm()`
 - "rpart" from `rpart::rpart()`
 - "rq" from `quantreg::rq()`
 - "selection" from `sampleSelection::selection()`
 - "speedglm" from `speedglm::speedglm()`
 - "speedlm" from `speedglm::speedlm()`
 - "survreg" from `survival::survreg()`
 - "svm" from `e1071::svm()`
 - "svyglm" from `survey::svyglm()`
 - "tobit" from `AER::tobit()`
 - "train" from `caret::train()`
 - "truncreg" from `truncreg::truncreg()`
 - "zeroinfl" from `pscl::zeroinfl()`
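
The calling pattern is intended to be the same regardless of model class. As a brief sketch with a second class from the list, here is a logistic regression fitted with `stats::glm()` (the formula is an arbitrary illustration, and the assumption is that the default prediction is on the response scale, i.e. a probability):

```r
library("prediction")
library("datasets")

# Logistic regression: probability of a manual transmission given weight
m <- glm(am ~ wt, data = mtcars, family = binomial)
p <- prediction(m)

# Same return type as for "lm": a data frame with a fitted column
class(p)
range(p$fitted)
```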

## Requirements and Installation

[![CRAN](https://www.r-pkg.org/badges/version/prediction)](https://cran.r-project.org/package=prediction)
![Downloads](https://cranlogs.r-pkg.org/badges/prediction)
[![Build status](https://ci.appveyor.com/api/projects/status/a4tebeoa98cq07gy/branch/master?svg=true)](https://ci.appveyor.com/project/leeper/prediction/branch/master)
[![codecov.io](https://app.codecov.io/github/leeper/prediction?branch=master)](https://app.codecov.io/github/leeper/prediction?branch=master)
[![Project Status: Active - The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)

The development version of this package can be installed directly from GitHub using `remotes`:

```r
# install remotes only if it is not already available
if (!requireNamespace("remotes", quietly = TRUE)) {
    install.packages("remotes")
}
remotes::install_github("leeper/prediction")
```