causaltree 実装メモ

谷口友哉

setup

RのcausalTreeパッケージは Susan Athey氏 の GitHub リポにある
(https://github.com/susanathey/causalTree)
Github からインストールするために devtools というパッケージを事前にインストールする必要がある

#未インストールの場合はインストールする。
#install.packages('devtools')

devtoolsを使ってcausalTreeパッケージをインストール

# devtools::install_github('susanathey/causalTree')

ライブラリ読み込み

library(causalTree)
library(tidyverse)

サンプルデータとしてcausalTreeに入っているデータ’simulation.1’を使う

data(simulation.1)
#データの確認
glimpse(simulation.1)
## Rows: 500
## Columns: 12
## $ x1        <dbl> -1.033764718, -2.212021289, 0.816665094, 0.948383003, -0.3…
## $ x2        <dbl> 1.096283952, -2.376494211, -0.393203916, 0.185135177, 0.88…
## $ x3        <dbl> 1.44293704, -0.06114225, -0.73867288, -0.20959247, -0.6639…
## $ x4        <dbl> 1.10412393, 0.68054135, 0.19381487, -0.20531452, 1.1514359…
## $ x5        <dbl> -0.18084280, -0.09248799, 1.01578250, -1.82232166, -1.5087…
## $ x6        <dbl> -1.04777073, -1.75208767, 1.09590458, 2.27424736, -0.92324…
## $ x7        <dbl> 0.7518880, 0.7996756, -1.2007716, -1.6440004, -0.6371071, …
## $ x8        <dbl> 1.14393981, 1.07846167, 0.36498574, -0.70872635, 0.8016476…
## $ x9        <dbl> -0.1912169, -0.1751672, 1.3911513, 1.6412298, 0.5429438, -…
## $ x10       <dbl> 1.00188041, -2.03595993, 1.19111637, -1.54651067, 0.357306…
## $ y         <dbl> 2.46427360, 0.79174783, 0.08207566, -0.23772178, 0.4574083…
## $ treatment <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…

実装

Susan Athey氏 のGitHubリポにあるドキュメントを参考にする (https://github.com/susanathey/causalTree/blob/master/briefintro.pdf)

Splitting rules

4つの異なる分割ルールが用意されている
1. TOT: Transformed Outcome Trees
2. CT: Causal Trees
3. fit: Fit-based Trees
4. tstats: Squared T-statistic Trees

causalTree function offers four different splitting rules for user to choose.
Each splitting rule corresponds to a specific risk function,   
and each split at a node aims to minimize the risk function.

Discrete splitting

試しにtreeを構築してみる

#tree <- causalTree(y ~ x1 + x2 + x3 + x4, data = simulation.1,
                   #treatment = simulation.1$treatment,
                   #split.Rule = "TOT",
                   #cv.option = "fit", cv.Honest = F,
                   #split.Bucket = T,xval = 10,
                   #cv.alpha = 0.5, propensity = 0.5)
ret <- capture.output({
    tree <- causalTree(y ~ x1 + x2 + x3 + x4, data = simulation.1,
                       treatment = simulation.1$treatment,
                       split.Rule = "TOT",
                       cv.option = "fit", cv.Honest = F,
                       split.Bucket = T,xval = 10)
} )

tree を可視化するにはrpart.plot関数を使う
roundint = FALSEとしておかないと警告が出ることがある
capture.output() を causalTree() に適用するとメッセージ出力を抑制できる

rpart.plot(tree,roundint = FALSE)

  • 剪定をしていないため、深いtreeが構築された

引数の説明

第1引数 : 目的変数と説明変数
data : treeの構築に使用するデータ
treatment : dataに格納したデータで処置を表す2値データ
split.Rule : treeの分割基準.’TOT’,‘CT’,‘fit’,‘tstats’
cv.option : cv(クロスバリデーション)のオプション.’CT’,‘TOT’,‘fit’,‘matching’
cv.Honest : cvをHonest型で実施するかどうか.TRUE or FALSE
split.Bucket : 離散分割を使用するかどうか.TRUE or FALSE
xval : cvのホールド数.xval = 10は10 folds cv
cv.alpha : \(\hat{-EMSE}_\tau\left(S^{tr,cv},N^{est},\Pi,\alpha\right)=\alpha・\frac{1}{N^{tr,cv}}\sum_{i\in S^{tr,cv}}\hat{\tau}^2\left(X_i;S^{tr,cv},\Pi\right)-(1-\alpha)・\left(\frac{1}{N^{tr,cv}}+\frac{1}{N^{est}}\right)・\sum_{\ell\in\Pi}\left(\frac{S_{S^{tr,cv}_{treat}}^2(\ell)}{p}+\frac{S_{S^{tr,cv}_{control}}^2(\ell)}{1-p}\right)\)\(\alpha\)を選択. 式中の第2項の調整のための追加係数
propensity : 処置確率.TOT splitting ruleで必要

Cross Validation and Pruning

剪定に使用される最小のcvエラーに対応する複雑度パラメータを選択するためにcv treeを構築
cvでは、エラーを計算するために異なる評価基準を選択することができる

Cross validation optionsは以下の4種類が用意されている
1. TOT : There is no “honest” option for this method.
2. CT : CT-A or CT-H. cv.alphaで調整可能.
3. fit : fit-A or fit-H. cv.alphaで調整可能.
4. matching : 以下のリスク関数で評価 \[\hat{MSE}_\tau(S^{tr,cv},S^{tr,tr},\Pi)=\sum_{i \in S^{tr,cv}}\left(\tau^*(X_i,W_i;S)-\frac{\hat{\tau}(X_i;S^{tr,tr},\Pi)+\hat{\tau}(X_{n(W_i,X_i;S^{tr,cv})};S^{tr,tr},\Pi)}{2}\right)^2\]
ここで \(\tau^*\left(X_i,W_i;S\right)\equiv(2W_i-1)\left(Y_i-Y_{n\left(W_i,X_i;S\right)}\right)\) であり, \(n\left(W_i,X_i;S\right)\) は特徴空間における \(\left(X_i,Y_i,W_i\right)\) の最近傍と定義

Example

treeを構築
以下では honest splitting rule を CT-H (split.Rule =“CT”, split.Honest = T)とし、 cv method を matching (cv.option = “matching” and cv.Honest = F)の設定を考える

tree1 <- causalTree(y ~ x1 + x2 + x3 + x4, data = simulation.1,
                    treatment = simulation.1$treatment, 
                    split.Rule = "CT",split.Honest = T,
                    cv.option = "matching", cv.Honest = F,
                    split.Bucket = F, xval = 10)
## [1] 2
## [1] "CT"

treeを可視化

rpart.plot(tree1,roundint = FALSE)

複雑度パラメータ(cp)と正規化クロスバリデーションエラー(xerror)を確認するための cptable を出力

tree1$cptable
##              CP nsplit rel error    xerror         xstd
## 1  1.145837e-02      0 1.0000000 1.0000000 0.0023046092
## 2  1.941622e-03      1 0.9885416 0.5522208 0.0013198525
## 3  1.862875e-03      2 0.9866000 0.5628324 0.0013266241
## 4  1.031403e-03      3 0.9847371 0.4987659 0.0011544581
## 5  7.763028e-04      5 0.9826743 0.4830998 0.0010693470
## 6  5.128562e-04      8 0.9803454 0.4788877 0.0010203462
## 7  4.327552e-04     10 0.9793197 0.4799875 0.0010191770
## 8  3.525900e-04     11 0.9788870 0.4755650 0.0009755064
## 9  3.052253e-04     13 0.9781818 0.4776425 0.0009740526
## 10 2.632508e-04     14 0.9778766 0.5143858 0.0010576026
## 11 2.198657e-04     16 0.9773500 0.5434564 0.0012039792
## 12 1.945478e-04     17 0.9771302 0.5885634 0.0012653941
## 13 1.887622e-04     18 0.9769356 0.6105212 0.0013211017
## 14 1.792430e-04     19 0.9767469 0.6048267 0.0013213458
## 15 1.677139e-04     20 0.9765676 0.6157248 0.0013312279
## 16 1.366938e-04     25 0.9756676 0.6225027 0.0013750740
## 17 1.111979e-04     28 0.9752575 0.6320591 0.0013705208
## 18 1.096178e-04     32 0.9747981 0.6303267 0.0013753677
## 19 1.002128e-04     36 0.9743597 0.6319194 0.0013886842
## 20 7.376079e-05     37 0.9742595 0.6333912 0.0013976805
## 21 4.858865e-05     38 0.9741857 0.6322808 0.0013954828
## 22 2.870558e-05     39 0.9741371 0.6356926 0.0014113942
## 23 0.000000e+00     40 0.9741084 0.6325458 0.0014066692

プロットしたtreeは大きくて深いため、最小のcvエラー(xerror:cptable[,4])に対応する複雑度パラメータopcp(:cptable[,1])を選択し、prune()関数を用いて剪定する

opcp <- tree1$cptable[, 1][which.min(tree1$cptable[,4])]
optree <- prune(tree1, cp = opcp)
rpart.plot(optree,roundint = FALSE)

Honest Estimation

In addtion to causalTree, we also support one-step honest re-estimation in function honest.causalTree. It can fit a causalTree model and get honest estimation results with
tree structre built on training sample (including cross validation) and leaf treatment effect
estimates taken from estimation sample.

causalTreeに加えて、関数 honest.causalTreeでワンステップのhonestな再推定もサポートしている
これは、causalTreeモデルを適合させ、学習サンプル(クロスバリデーションを含む)と推定サンプルから取得した葉の処置効果推定値に基づいて構築された木構造を用いて、honestな推定結果を得ることができる

まずはデータをトレーニングデータと推定データに分ける

n <- nrow(simulation.1)
trIdx <- which(simulation.1$treatment == 1)
conIdx <- which(simulation.1$treatment == 0)
train_idx <- c(sample(trIdx, length(trIdx) / 2),
               sample(conIdx, length(conIdx) / 2))
train_data <- simulation.1[train_idx, ]
est_data <- simulation.1[-train_idx, ]

honest.causalTree()関数で木を構築

honestTree <- honest.causalTree(y ~ x1 + x2 + x3 + x4,
                                data = train_data,
                                treatment =train_data$treatment,
                                est_data = est_data,
                                est_treatment=est_data$treatment,
                                split.Rule = "CT", split.Honest = T,
                                HonestSampleSize = nrow(est_data),
                                split.Bucket = T, cv.option = "fit",
                                cv.Honest = F)
## [1] 6
## [1] "CTD"

引数の説明

第1引数          : 目的変数と説明変数
data             : treeの構築に使用するデータ
treatment        : dataに格納したデータで処置を表す2値データ
est_data         : 葉の推定に使用するデータ
est_treatment    : est_dataに格納したデータで処置を表す2値データ
split.Rule     : treeの分割基準.'TOT','CT'
split.Honest     : Honest型の分割をするか.TRUE or FALSE
HonestSampleSize : est_dataの行数
split.Bucket     : 離散分割を使用するかどうか.
cv.option        : cv(クロスバリデーション)のオプション.'CT','TOT','fit','matching'
cv.Honest        : cvをHonest型で実施するかどうか.

treeを可視化

rpart.plot(honestTree,roundint = FALSE)

深くなったtreeを剪定する

opcp <- honestTree$cptable[,1][which.min(honestTree$cptable[,4])]
opTree <- prune(honestTree, opcp)

剪定後のtreeを可視化

rpart.plot(opTree)

ランダムフォレストでよく使われる変数重要度を可視化

変数重要度を可視化する関数

#関数作成
plotVarImp <- function(ranger_fit, top=NULL){
  library(ggplot2)
  
  n <-length(ranger_fit$variable.importance)
  
  pd <- data.frame(Variable = names(ranger_fit$variable.importance[1:n]),
                   Importance = as.numeric(ranger_fit$variable.importance[1:n])) %>% 
    arrange(desc(Importance))
  
  if(is.null(top)){
    pd <- arrange(pd, Importance)
  } else {
    pd <- arrange(pd[1:top,], Importance)
  }
  p <- ggplot(pd, aes(x=factor(Variable, levels=unique(Variable)), y=Importance)) +
    geom_bar(stat="identity") +
    xlab("Variables") + 
    coord_flip()
  plot(p)
}

plotVarImp(opTree)

memo

causalTree():treeの構築はadaptive
honest.causalTree():treeの構築がhonest