#!/usr/bin/env python
# coding: utf-8

# DO NOT EDIT
# Autogenerated from the notebook treatment_effect.ipynb.
# Edit the notebook and then sync the output with this file.
#
# flake8: noqa
# DO NOT EDIT

# ## Treatment effects under conditional independence
#
# Author: Josef Perktold
#
# This notebook illustrates the basic usage of the new treatment effect
# functionality in statsmodels.
#
# The main class is
# `statsmodels.treatment.treatment_effects.TreatmentEffect`.
#
#
# This class estimates treatment effects and potential outcomes using 5
# different methods: `ipw`, `ra`, `aipw`, `aipw_wls` and `ipw_ra`. The last
# three methods require both a treatment or selection model and an outcome
# model.
# Standard errors and inference are based on the joint GMM representation
# of the selection or treatment model, the outcome model and the effect
# functions. The approach for inference follows Stata; however, Stata
# supports a wider range of models.
# Estimation and inference are valid under conditional independence or
# ignorability.
#
# The outcome model is currently limited to a linear model based on OLS.
# Treatment is currently restricted to binary treatment which can be
# either Logit or Probit.
#
# The example follows Cattaneo.

import os
import numpy as np
from numpy.testing import assert_allclose
import pandas as pd

from statsmodels.regression.linear_model import OLS
from statsmodels.discrete.discrete_model import Probit
from statsmodels.treatment.treatment_effects import (TreatmentEffect)

from statsmodels.treatment.tests.results import results_teffects as res_st

# Load data for example: the csv file sits in the same directory as the
# Stata reference results module (imported as ``res_st``).
file_name = 'cataneo2.csv'
cur_dir = os.path.abspath(os.path.dirname(res_st.__file__))
file_path = os.path.join(cur_dir, file_name)
dta_cat = pd.read_csv(file_path)

# Stata reference results, each paired with the name of the TreatmentEffect
# method that reproduces it.
methods_st = [
    ("ra", res_st.results_ra),
    ("ipw", res_st.results_ipw),
    ("aipw", res_st.results_aipw),
    ("aipw_wls", res_st.results_aipw_wls),
    ("ipw_ra", res_st.results_ipwra),
]
# Names of the TreatmentEffect estimation methods, in the same order.
methods = [name for name, _ in methods_st]

# allow wider display of data frames
pd.set_option('display.width', 500)

# Show the first rows. A bare expression only displays in a notebook;
# when this file runs as a script it must be printed explicitly.
print(dta_cat.head())

# ### Create TreatmentEffect instance and compute ipw
#
# The TreatmentEffect class requires
# - an OLS model instance for the outcome model,
# - a results instance of the selection model and
# - a treatment indicator variable.
#
# In the following example we use Probit as the selection model. Using
# Logit is also supported.
#

# Treatment indicator variable, taken from the data column that is also
# the dependent variable of the selection model below.
tind = np.asarray(dta_cat['mbsmoke_'])

# Treatment selection model, estimated with Probit (Logit is also
# supported by TreatmentEffect).
formula = 'mbsmoke_ ~ mmarried_ + mage + mage2 + fbaby_ + medu'
res_probit = Probit.from_formula(formula, data=dta_cat).fit()

# Outcome model: an unfitted OLS model instance.
formula_outcome = 'bweight ~ prenatal1_ + mmarried_ + mage + fbaby_'
mod = OLS.from_formula(formula_outcome, data=dta_cat)

teff = TreatmentEffect(mod, tind, results_select=res_probit)

# After creating the TreatmentEffect instance, we can call any of the 5
# methods to compute potential outcomes, POM0, POM1, and average treatment
# effect, ATE. POM0 is the potential outcome for the no treatment group,
# POM1 is the potential outcome for the treatment group, treatment effect is
# POM1 - POM0.
#
# For example, `teff.ipw()` computes POM and ATE using inverse probability
# weighting. The probability of treatment is also commonly called the
# propensity score. The `summary` of the estimation includes standard errors
# and confidence intervals for POM and ATE.
#
#
# Standard errors and other inferential statistics are based on the
# Generalized Method of Moments (GMM) representation of the selection and
# outcome models and the moment conditions for the target statistics.
# Method `ipw` uses the selection model but not the outcome model.
# Method `ra` uses the outcome model but not the selection model.
# The doubly robust estimators `aipw`, `aipw_wls` and `ipw_ra` include both
# selection and outcome models, where at least one of those two has to be
# correctly specified to get consistent estimates of the treatment effect.
# Of the target statistics POM0, POM1 and ATE, moment conditions are
# defined only for POM0 and ATE. The remaining POM1 is computed as a linear
# combination of POM0 and ATE.
#
# The internal gmm results are attached to the treatment results as
# `results_gmm`.
#
# By default, the treatment effect methods compute the average treatment
# effect, where the average is taken over the sample observations.
# Option `effect_group` can be used to compute either the average treatment
# effect on the treated, ATT, using `effect_group=1`, or the average
# treatment effect on the untreated using `effect_group=0`.
#

# Inverse probability weighting. Results are printed explicitly because
# bare expressions produce no output when this file runs as a script.
res = teff.ipw()
print(res)

print(res.summary_frame())

# the internal GMM results used for inference
print(res.results_gmm.summary())

# **average treatment effect on the treated**
#
# see more below

# Printed explicitly: as bare expressions the results would be computed
# and then discarded when this file runs as a script.
print(teff.ipw(effect_group=1))

# **average treatment effect on the untreated**

print(teff.ipw(effect_group=0))

# Other methods to compute ATE work in the same or a similar way as
# `ipw`, for example regression adjustment `ra` and doubly robust `ipw_ra`.

# Regression adjustment. Results are printed explicitly because bare
# expressions produce no output when this file runs as a script.
res_ra = teff.ra()
print(res_ra)

print(res_ra.summary_frame())

# doubly robust `ipw_ra`, effect on the treated (effect_group=1)
ra2 = teff.ipw_ra(effect_group=1, return_results=True)
print(ra2.summary_frame())

# ## All methods in TreatmentEffect
#
# The following computes and prints ATE and POM for all methods.
# (We include the call to TreatmentEffect as a reminder.)

teff = TreatmentEffect(mod, tind, results_select=res_probit)

for m in methods:
    # look up the estimation method by name and run it with default options
    estimate = getattr(teff, m)
    res = estimate()
    print("\n", m)
    print(res.summary_frame())

# ## Results in Stata
#
# The results in statsmodels are very close to the results in Stata
# because both packages use the same approach.

for m, st in methods_st:
    print("\n", m)
    # first two rows (labelled ATE and POM0) and first six columns of the
    # stored Stata results table
    res = pd.DataFrame(
        data=st.table[:2, :6],
        columns=st.table_colnames[:6],
        index=["ATE", "POM0"],
    )
    print(res)

# ### Treatment effects without inference
#
# It is possible to compute POM and ATE without computing standard errors
# and inferential statistics. In this case the GMM model is not computed.

for m in methods:
    print("\n", m)
    # return_results=False: only point estimates, no GMM-based inference
    estimate = getattr(teff, m)
    res = estimate(return_results=False)
    print(res)

# ## Treatment effect on the treated
#
# Treatment effects on subgroups are not available for `aipw` and
# `aipw_wls`.
#
# `effect_group` chooses the group for which the treatment effect and
# potential outcomes are computed.
# Options are
# "all" for sample average treatment effect,
# `1` for average treatment effect on the treated and
# `0` for average treatment effect on the untreated.
#
# Note: The row labels in the pandas dataframe, POM and ATE, are the same
# even for treatment effect on subgroups.

for m in methods:
    # subgroup effects are not available for aipw and aipw_wls
    if not m.startswith("aipw"):
        res = getattr(teff, m)(effect_group=1)
        print("\n", m)
        print(res.summary_frame())

# ### Treatment effect on the untreated
#
# Similar to ATT, we can compute average treatment effect on the untreated
# by using option `effect_group=0`.

for m in methods:
    # subgroup effects are not available for aipw and aipw_wls
    if not m.startswith("aipw"):
        res = getattr(teff, m)(effect_group=0)
        print("\n", m)
        print(res.summary_frame())

# The docstring for the TreatmentEffect class and its methods can be
# obtained using help
#
# `help(teff)`
