(R) Stan の出力加工方法
概要
忙しくて2ヶ月連続無更新になりそうになっているところをなんとか回避したいという妥協の産物
stan
およびrstan
のモデルの事後診断機能がやや物足りないのでそれを補うヒント- 本当に簡単な話
2016/10/07:
Accessing the contents of a stanfit object という開発チームの解説記事が出たのでこっちのほうが正確だと思います。
stan は高速で便利な MC シミュレーション専用プログラムだが, 単体ではトレースプロットや stan_rhat など事後分布の収束などを判定する機能はあるものの, モデル評価の指標はない. データの形式さえ合わせれば stan 以外のモジュールを使って計算できるが, 結果を格納する stanfit の構造を調べるのがめんどくさい. あるいは他の方法で計算結果を視覚化したい, という場合もあるだろう. そういう人のためにヒント集を書いておく*1. MCMC の適切な診断方法については各種専門書を呼んで欲しい. たとえば 小西他 (2008) など*2.
出力結果をテーブル状にする
stanfit オブジェクトでなく, テーブル状にすれば ggplot2
などに stan の結果を渡すことができる.
まず, 事後サンプルは配列型にすることができる.
stan.sample <- sampling(hogehoge) # MCサンプル計算 array.stan <- as.array(stan.sample)
すると, [サンプルサイズ,chain数, パラメータ数] の3次元配列array.stan
が得られる. そこで,
for( i in 1:ncol(stan.sample) ){ temp <- as.data.frame(array.stan[,i,]) temp$chain <- i temp$iteration <- 1:nrow(temp) if ( i==1 ) df <- temp else df <- rbind(df, temp) } rm(i, temp)
とする. ポイントは, ncol()
に stanfit オブジェクト(ここでは stan.sample
) を与えると chain の数を返すことだ. 実行すると以下のようなデータフレームになる(iteration, chain は後で追加したので右端に来る).
chain | iteration | パラメータ1 | パラメータ2 | ... | lp__ |
---|---|---|---|---|---|
1 | 1 | X.XX | Y.YY | ... | ... |
1 | 2 | X.XX | Y.YY | ... | ... |
1 | ... | ... | ... | ... | ... |
2 | 1 | ... | ... | ... | ... |
2 | 2 | ... | ... | ... | ... |
2 | ... | ... | ... | ... | ... |
このように, サンプルが行, パラメータが列のデータフレームが得られる. stanfit
オブジェクトに直接 as.data.frame()
でもテーブル上にできるが, chain=
を指定している場合, どの chain から得られたサンプルなのかの情報が欠落してしまう. chain=1
のときのみにしたほうがよい. さらに tidyr
で
df <- df %>% gather(key=pars, value=val, -chain, -iteration)
とすれば, 値がすべて縦持ちになり, dplyr
や ggplot2
へ容易に接続でき, 分位点や平均値を集計したり, 信用区間を図示したりできる*3. また, 特にモデル選択でよく用いられる DIC は, サンプルの個体ごとに評価した対数尤度の平均と, サンプル平均で評価した対数尤度 (厳密には deviance ) を使って求める. 現在 DIC を計算してくれるパッケージはないようなので, 計算したい場合はこの方法でデータフレームを作ると計算しやすいかもしれない. その場合は stan のプログラムで対数尤度を計算する必要もある. DIC でなく WAIC を計算する場合も,
ito-hi.blog.so-net.ne.jp
で言及されているように対数尤度が必要なので, stan を使うときは必ず対数尤度を計算するようプログラムする習慣にしたほうがいいかもしれない *4.
また, rstan::get_posterior_mean()
関数は事後サンプルの chain ごとの平均値だけを返す. 出力は
パラメータ名 | mean-chain1 | mean-chain2 | ... | mean-all chains |
---|---|---|---|---|
パラメータ1 | ... | ... | ... | ... |
パラメータ2 | ... | ... | ... | ... |
... | ... | ... | ... | ... |
のようにパラメータが縦持ちになる. 最後の mean-all chain
列が全体の平均となる.( なぜRその他の言語の変数名としてイレギュラーな名称にしたのか). これは事後平均だけを見て簡単な確認をしたい場合に使える.
などとかいていたのだが, rstan::stan()
および rstan::sampling()
には sample_file=
オプションがあることを今になって知った. このオプションは chain ごとにサンプリング結果をテキストで出してくれるので, 一旦テキストファイルで出してから読み込んだほうが楽かもしれない. ただしヘッダとして先頭から数行に実行時オプションが書かれるので, うまく読み飛ばす必要がある.
参考文献
小西 貞則・越智 義道・大森 祐浩. (2008) 計算統計学の方法 -ブートストラップ・EMアルゴリズム・MCMC-. シリーズ予測と発見の科学 5, 朝倉書店.

計算統計学の方法―ブートストラップ・EMアルゴリズム・MCMC (シリーズ予測と発見の科学 5)
- 作者: 小西貞則,越智義道,大森裕浩
- 出版社/メーカー: 朝倉書店
- 発売日: 2008/03/25
- メディア: 単行本
- 購入: 5人 クリック: 62回
- この商品を含むブログ (7件) を見る