The prediction and margins packages are a combined effort to port the functionality of Stata’s (closed source) margins command to (open source) R. prediction is focused on one function, prediction(), that provides type-safe methods for generating predictions from fitted regression models.
prediction() is an S3 generic that always returns a "data.frame" class object rather than the mix of vectors, lists, etc. that are returned by the predict() methods for various model types. It provides a key piece of underlying infrastructure for the margins package. Users interested in generating marginal (partial) effects, like those generated by Stata’s margins, dydx(*) command, should consider using margins() from the sibling project, margins.
In addition to prediction(), this package provides a number of utility functions for generating useful predictions:
find_data(), an S3 generic with methods that find the data frame used to estimate a regression model. This is a wrapper around get_all_vars() that attempts to locate data as well as modify it according to the subset and na.action arguments used in the original modelling call.
mean_or_mode() and median_or_mode(), which provide a convenient way to compute the data needed for predicted values at means (or at medians), respecting the differences between factor and numeric variables.
seq_range(), which generates a vector of n values based upon the range of values in a variable.
build_datalist(), which generates a list of data frames from an input data frame and a specified set of replacement at values (mimicking the atlist option of Stata’s margins command).
A major downside of the predict() methods for common modelling classes is that the result is not type-safe: its class depends on the arguments supplied. Consider the following simple example:
library("stats")
library("datasets")
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(predict(x))
## [1] "numeric"
class(predict(x, se.fit = TRUE))
## [1] "list"
prediction solves this issue by providing a wrapper around predict(), called prediction(), that always returns a tidy data frame with a very simple print() method:
library("prediction")
(p <- prediction(x))
## Data frame with 32 predictions from
## lm(formula = mpg ~ cyl * hp + wt, data = mtcars)
## with average prediction: 20.0906
class(p)
## [1] "prediction" "data.frame"
head(p)
## mpg cyl disp hp drat wt qsec vs am gear carb fitted se.fitted
## 1 21.0 6 160 110 3.90 2.620 16.46 0 1 4 4 21.90488 0.6927034
## 2 21.0 6 160 110 3.90 2.875 17.02 0 1 4 4 21.10933 0.6266557
## 3 22.8 4 108 93 3.85 2.320 18.61 1 1 4 1 25.64753 0.6652076
## 4 21.4 6 258 110 3.08 3.215 19.44 1 0 3 1 20.04859 0.6041400
## 5 18.7 8 360 175 3.15 3.440 17.02 0 0 3 2 17.25445 0.7436172
## 6 18.1 6 225 105 2.76 3.460 20.22 1 0 3 1 19.53360 0.6436862
The output always contains the original data (i.e., either the data found using the find_data() function or the data passed to the data argument of prediction()). This makes it much simpler to pass predictions to, e.g., further summary or plotting functions.
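To make the idea of a type-safe, data-preserving wrapper concrete, here is a minimal base-R sketch. predict_df() is a hypothetical helper written for illustration only and is not part of the prediction package. It assumes a model whose predict() method accepts se.fit = TRUE and returns a list with fit and se.fit components, as predict.lm() does.

```r
library("stats")
library("datasets")

# Hypothetical helper (NOT the package's prediction()): always returns a
# data frame made of the input data plus fitted values and standard errors
predict_df <- function(model, data) {
  # predict.lm() returns a list (fit, se.fit, df, residual.scale) when se.fit = TRUE
  p <- predict(model, newdata = data, se.fit = TRUE)
  out <- data
  out$fitted <- unname(p$fit)
  out$se.fitted <- unname(p$se.fit)
  out
}

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
p <- predict_df(x, mtcars)
class(p)
head(p[, c("mpg", "hp", "wt", "fitted", "se.fitted")])
```

The result has the same class whether or not standard errors are wanted, and every original column travels with the predictions. The package's own prediction() goes further: it locates the data itself via find_data() and dispatches across many model classes.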
Additionally, the vast majority of methods allow the passing of an at argument, which can be used to obtain predicted values using a modified version of the data, with selected variables held at specific values:
prediction(x, at = list(hp = seq_range(mtcars$hp, 5)))
## Data frame with 160 predictions from
## lm(formula = mpg ~ cyl * hp + wt, data = mtcars)
## with average predictions:
## hp x
## 52.0 22.605
## 122.8 19.328
## 193.5 16.051
## 264.2 12.774
## 335.0 9.497
This more or less serves as a direct R port of (the subset of
functionality of) Stata’s margins
command that calculates
predictive marginal means, etc. For calculation of marginal or partial
effects, see the margins
package.
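The at mechanism can be approximated in a few lines of base R. The sketch builds one copy of the data per value of hp, predicts on each copy, and averages the results. It is a sketch under assumptions, not the package's implementation. In particular, it assumes seq_range(mtcars$hp, 5) yields five evenly spaced values from min(hp) to max(hp), which is consistent with the hp column printed above. It also uses plain lapply() in place of build_datalist().

```r
library("stats")
library("datasets")

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)

# Five evenly spaced hp values spanning the observed range
# (assumed to mirror seq_range(mtcars$hp, 5))
hp_values <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 5)

# One modified copy of the data per hp value, in the spirit of build_datalist():
# every row keeps its other covariates, but hp is held at the given value
datalist <- lapply(hp_values, function(v) {
  d <- mtcars
  d$hp <- v
  d
})

# Predict on each copy (5 x 32 = 160 predictions) and average within each copy
avg_pred <- vapply(datalist, function(d) mean(predict(x, newdata = d)), numeric(1))
data.frame(hp = hp_values, x = avg_pred)
```

Because the model is linear in hp for each observation and the hp values are evenly spaced, consecutive averages differ by a constant amount. If the assumption about seq_range() holds, the averages should line up with the table printed by prediction(x, at = ...) above.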
The currently supported model classes are:
stats::lm()
stats::glm(), MASS::glm.nb(), glmx::glmx(), glmx::hetglm(), brglm::brglm()
stats::ar()
stats::arima()
stats::arima0()
biglm::biglm() (including "ffdf" backed models)
betareg::betareg()
mda::bruto()
ordinal::clm()
survival::coxph()
crch::crch()
earth::earth()
mda::fda()
gam::gam()
kernlab::gausspr()
gee::gee()
aod::betabin(), aod::negbin(), aod::quasibin(), aod::quasipois()
glmnet::glmnet()
nlme::gls()
pscl::hurdle()
crch::hxlr()
AER::ivreg()
caret::knnreg()
kernlab::kqr()
kernlab::ksvm()
MASS::lda()
nlme::lme()
stats::loess()
MASS::lqs()
mda::mars()
MASS::mca()
mclogit::mclogit()
mda::mda()
lme4::lmer() and lme4::glmer()
mnlogit::mnlogit()
MNP::mnp()
e1071::naiveBayes()
nlme::nlme()
stats::nls()
nnet::nnet(), nnet::multinom()
plm::plm()
MASS::polr()
stats::ppr()
stats::princomp()
MASS::qda()
MASS::rlm()
rpart::rpart()
quantreg::rq()
sampleSelection::selection()
speedglm::speedglm()
speedglm::speedlm()
survival::survreg()
e1071::svm()
survey::svyglm()
AER::tobit()
caret::train()
truncreg::truncreg()
pscl::zeroinfl()
The development version of this package can be installed directly from GitHub using remotes:
if (!requireNamespace("remotes", quietly = TRUE)) {
  install.packages("remotes")
}
remotes::install_github("leeper/prediction")