ill-identified diary

所属組織の見解などとは一切関係なく小難しい話しかしません

(R) Stan の出力加工方法

この記事は最終更新日から3年以上が経過しています

概要

  • 忙しくて2ヶ月連続無更新になりそうになっているところをなんとか回避したいという妥協の産物

  • stan および rstan のモデルの事後診断機能がやや物足りないのでそれを補うヒント

  • 本当に簡単な話

2016/10/07:
Accessing the contents of a stanfit object という開発チームの解説記事が出たのでこっちのほうが正確だと思います。

stan は高速で便利な MC シミュレーション専用プログラムだが, 単体ではトレースプロットや stan_rhat など事後分布の収束などを判定する機能はあるものの, モデル評価の指標はない. データの形式さえ合わせれば stan 以外のモジュールを使って計算できるが, 結果を格納する stanfit の構造を調べるのがめんどくさい. あるいは他の方法で計算結果を視覚化したい, という場合もあるだろう. そういう人のためにヒント集を書いておく*1. MCMC の適切な診断方法については各種専門書を呼んで欲しい. たとえば 小西他 (2008) など*2.

出力結果をテーブル状にする

stanfit オブジェクトでなく, テーブル状にすれば ggplot2 などに stan の結果を渡すことができる.

まず, 事後サンプルは配列型にすることができる.

stan.sample <- sampling(hogehoge)  # MCサンプル計算
array.stan <- as.array(stan.sample)

すると, [サンプルサイズ,chain数, パラメータ数] の3次元配列array.stanが得られる. そこで,

for( i in 1:ncol(stan.sample) ){
  temp <- as.data.frame(array.stan[,i,])
  temp$chain <- i
  temp$iteration <- 1:nrow(temp)
  if ( i==1 ) df <- temp
  else df <- rbind(df, temp)
}
rm(i, temp)

とする. ポイントは, ncol() に stanfit オブジェクト(ここでは stan.sample) を与えると chain の数を返すことだ. 実行すると以下のようなデータフレームになる(iteration, chain は後で追加したので右端に来る).

chain iteration パラメータ1 パラメータ2 ... lp__
1 1 X.XX Y.YY ... ...
1 2 X.XX Y.YY ... ...
1 ... ... ... ... ...
2 1 ... ... ... ...
2 2 ... ... ... ...
2 ... ... ... ... ...

このように, サンプルが行, パラメータが列のデータフレームが得られる. stanfit オブジェクトに直接 as.data.frame() でもテーブル上にできるが, chain= を指定している場合, どの chain から得られたサンプルなのかの情報が欠落してしまう. chain=1 のときのみにしたほうがよい. さらに tidyr

df <- df %>% gather(key=pars, value=val, -chain, -iteration)

とすれば, 値がすべて縦持ちになり, dplyrggplot2 へ容易に接続でき, 分位点や平均値を集計したり, 信用区間を図示したりできる*3. また, 特にモデル選択でよく用いられる DIC は, サンプルの個体ごとに評価した対数尤度の平均と, サンプル平均で評価した対数尤度 (厳密には deviance ) を使って求める. 現在 DIC を計算してくれるパッケージはないようなので, 計算したい場合はこの方法でデータフレームを作ると計算しやすいかもしれない. その場合は stan のプログラムで対数尤度を計算する必要もある. DIC でなく WAIC を計算する場合も,
ito-hi.blog.so-net.ne.jp
で言及されているように対数尤度が必要なので, stan を使うときは必ず対数尤度を計算するようプログラムする習慣にしたほうがいいかもしれない *4.

また, rstan::get_posterior_mean() 関数は事後サンプルの chain ごとの平均値だけを返す. 出力は

パラメータ名 mean-chain1 mean-chain2 ... mean-all chains
パラメータ1 ... ... ... ...
パラメータ2 ... ... ... ...
... ... ... ... ...


のようにパラメータが縦持ちになる. 最後の mean-all chain 列が全体の平均となる.( なぜRその他の言語の変数名としてイレギュラーな名称にしたのか). これは事後平均だけを見て簡単な確認をしたい場合に使える.

などとかいていたのだが, rstan::stan() および rstan::sampling() には sample_file= オプションがあることを今になって知った. このオプションは chain ごとにサンプリング結果をテキストで出してくれるので, 一旦テキストファイルで出してから読み込んだほうが楽かもしれない. ただしヘッダとして先頭から数行に実行時オプションが書かれるので, うまく読み飛ばす必要がある.

参考文献


小西 貞則・越智 義道・大森 祐浩. (2008) 計算統計学の方法 -ブートストラップ・EMアルゴリズムMCMC-. シリーズ予測と発見の科学 5, 朝倉書店.

計算統計学の方法―ブートストラップ・EMアルゴリズム・MCMC (シリーズ予測と発見の科学 5)

計算統計学の方法―ブートストラップ・EMアルゴリズム・MCMC (シリーズ予測と発見の科学 5)

*1:そういえば先日 stan の開発者が来日講演を行なったらしいが, その辺の話題に言及はあったのだろうか.

*2:先日アマゾンリストよりプレゼントしていただいた方, ありがとうざいます.

*3:余談だが, rstan じたい ggplot2 に依存している

*4:sasmcmc プロシジャのように与える確率分布が事前分布か尤度かを明示することで対数事後確率だけでなく対数尤度も自動で計算することは原理的に不可能でないと思うのだが, そういう機能を stan 追加する予定はないのだろうか