獨斷論

Statistical Rethinking Chapter 9 Exercise Solutions



부르칸 2021. 4. 1. 07:45

9M1 Solution

Only the prior for sigma in the R script from the Chapter 9 text needs to change.

Compare m91 (sigma ~ dexp(1)) and m911 (sigma ~ dunif(0, 1)) below.

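library(rethinking)

rm(list=ls())

# Load the data and keep only the countries that have GDP data
data(rugged) 
d = rugged
d$log_gdp = log(d$rgdppc_2000)
dd = d[complete.cases(d$rgdppc_2000), ]
# Rescale: log GDP relative to its mean, ruggedness relative to its maximum
dd$log_gdp_std = dd$log_gdp / mean(dd$log_gdp)
dd$rugged_std = dd$rugged / max(dd$rugged)
# Index variable: 1 = Africa, 2 = all other continents
dd$cid = ifelse(dd$cont_africa == 1, 1, 2)

# Pass ulam a list that contains only the variables used in the model
dat_slim = list(
  log_gdp_std = dd$log_gdp_std, 
  rugged_std = dd$rugged_std, 
  cid = as.integer(dd$cid)
)
rm(rugged, d, dd)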

m91 = ulam(
  alist(
    log_gdp_std ~ dnorm(mu, sigma), 
    mu <- a[cid] + b[cid]*(rugged_std - 0.215), 
    a[cid] ~ dnorm(1, 0.1), 
    b[cid] ~ dnorm(0, 0.3),
    sigma ~ dexp(1)
  ),
  data = dat_slim
)

m911 = ulam(
  alist(
    log_gdp_std ~ dnorm(mu, sigma), 
    mu <- a[cid] + b[cid]*(rugged_std - 0.215), 
    a[cid] ~ dnorm(1, 0.1), 
    b[cid] ~ dnorm(0, 0.3),
    sigma ~ dunif(0, 1)
  ),
  data = dat_slim 
)

precis(m91, depth = 2)
precis(m911, depth = 2)

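Using dunif(0, 1) instead of dexp(1) as the prior for sigma has almost no effect on the posterior distribution. In both fits sigma has a mean of 0.11 with an 89% interval of 0.10 to 0.12, and the posterior means of a and b are identical to two decimal places.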

> precis(m91, depth = 2)
       mean   sd  5.5% 94.5% n_eff Rhat4
a[1]   0.89 0.02  0.86  0.91   583     1
a[2]   1.05 0.01  1.03  1.07   679     1
b[1]   0.13 0.07  0.02  0.25   705     1
b[2]  -0.14 0.06 -0.24 -0.05   678     1
sigma  0.11 0.01  0.10  0.12   622     1

> precis(m911, depth = 2)
       mean   sd  5.5% 94.5% n_eff Rhat4
a[1]   0.89 0.02  0.86  0.91   656     1
a[2]   1.05 0.01  1.04  1.07   746     1
b[1]   0.13 0.08  0.01  0.25   656     1
b[2]  -0.14 0.05 -0.23 -0.05   431     1
sigma  0.11 0.01  0.10  0.12   800     1
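One way to see why the two priors give the same answer is to evaluate both prior densities over the range where the posterior for sigma actually lies. The sketch below uses only base R; the three grid values are taken from the precis tables above, and the variable names are made up for illustration.

```r
# Evaluate both sigma priors over the range the posterior actually covers.
sigma_grid <- c(0.10, 0.11, 0.12)   # 5.5%, mean, 94.5% from the precis output

prior_exp  <- dexp(sigma_grid, rate = 1)           # sigma ~ dexp(1)
prior_unif <- dunif(sigma_grid, min = 0, max = 1)  # sigma ~ dunif(0, 1)

# Ratio of the highest to the lowest prior density across the interval.
# A value close to 1 means the prior is nearly flat there.
ratio_exp  <- max(prior_exp) / min(prior_exp)
ratio_unif <- max(prior_unif) / min(prior_unif)

print(round(c(exponential = ratio_exp, uniform = ratio_unif), 3))
```

Both priors are close to flat in this region (the exponential density varies by only about 2%), so with 170 observations the likelihood dominates and the choice between them barely matters.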

 

9M2 Solution

m912 = ulam(
  alist(
    log_gdp_std ~ dnorm(mu, sigma), 
    mu <- a[cid] + b[cid]*(rugged_std - 0.215), 
    a[cid] ~ dnorm(1, 0.1), 
    b[cid] ~ dexp(0.3),
    sigma ~ dexp(1)
  ),
  data = dat_slim 
)

precis(m91, depth = 2)
precis(m912, depth = 2)

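In m912, only the prior for b[cid] is changed to dexp(0.3). Because the exponential prior puts no probability on negative values, b[2], which was -0.14 under the normal prior, is forced to be positive and ends up near zero (0.02), while b[1] is nearly unchanged.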

> precis(m91, depth = 2)
       mean   sd  5.5% 94.5% n_eff Rhat4
a[1]   0.89 0.02  0.86  0.91   841     1
a[2]   1.05 0.01  1.03  1.07  1028     1
b[1]   0.13 0.08  0.01  0.25   639     1
b[2]  -0.14 0.06 -0.23 -0.05   605     1
sigma  0.11 0.01  0.10  0.12   786     1

> precis(m912, depth = 2)
      mean   sd 5.5% 94.5% n_eff Rhat4
a[1]  0.89 0.02 0.86  0.91   659     1
a[2]  1.05 0.01 1.03  1.06  1095     1
b[1]  0.14 0.08 0.03  0.27   168     1
b[2]  0.02 0.02 0.00  0.05   576     1
sigma 0.11 0.01 0.10  0.12   384     1
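The shift in b[2] follows from the support of the two priors. As a rough check in base R (variable names are illustrative only), the prior probability of a negative slope can be computed under each prior:

```r
# Prior probability that a slope is negative under each prior for b.
p_neg_normal <- pnorm(0, mean = 0, sd = 0.3)  # b ~ dnorm(0, 0.3)
p_neg_exp    <- pexp(0, rate = 0.3)           # b ~ dexp(0.3)

# The mean of an exponential distribution is 1 / rate.
mean_exp <- 1 / 0.3

print(c(p_neg_normal = p_neg_normal,
        p_neg_exp    = p_neg_exp,
        mean_exp     = mean_exp))
```

The normal prior splits its mass evenly between negative and positive slopes, whereas the exponential prior rules out negative slopes entirely, so the posterior for b[2] can only pile up against zero.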

 

9M3 Solution

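Let's write the m91 model directly in rstan.

In RStudio, click File -> New File -> Stan File. A file opens with some template content; delete all of it and paste in the code below. On Windows, leave the last two lines of the file blank. Save the file as ex9m3.stan in R's working directory. Here it was saved in d:\tmp\rcode\.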

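data {
  // 170 countries remain after dropping those without GDP data.
  // Older array syntax; recent Stan versions write this as: array[170] int<lower=0> cid;
  int<lower=0> cid[170];
  vector[170] log_gdp;
  vector[170] rugged;  
}

parameters {
  // One intercept and one slope per continent index (1 = Africa, 2 = other)
  vector[2] a;
  vector[2] b;
  real<lower=0> sigma;
}

model {
  vector[170] mu;
  // Priors, same as in m91
  sigma ~ exponential(1);
  b ~ normal(0, 0.3);
  a ~ normal(1, 0.1);
  // Linear model, centered at the mean standardized ruggedness (0.215)
  for(i in 1:170){
    mu[i] = a[cid[i]] + b[cid[i]] * (rugged[i] - 0.215);
  }
  log_gdp ~ normal(mu, sigma);
}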
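The for loop in the model block computes the same thing as the vectorized line mu <- a[cid] + b[cid]*(rugged_std - 0.215) in the ulam formula. A small base R sketch of that equivalence follows; the cid and rugged values are made up for illustration and are not the rugged data, and a and b are the rounded posterior means reported earlier.

```r
# Hypothetical inputs for illustration only (not the rugged data).
a      <- c(0.89, 1.05)               # intercepts, rounded posterior means
b      <- c(0.13, -0.14)              # slopes, rounded posterior means
cid    <- c(1L, 2L, 2L, 1L)           # made-up continent index
rugged <- c(0.10, 0.215, 0.50, 0.30)  # made-up standardized ruggedness

# Loop form, mirroring the Stan model block.
mu_loop <- numeric(length(cid))
for (i in seq_along(cid)) {
  mu_loop[i] <- a[cid[i]] + b[cid[i]] * (rugged[i] - 0.215)
}

# Vectorized form, as written in the ulam formula.
mu_vec <- a[cid] + b[cid] * (rugged - 0.215)

print(all.equal(mu_loop, mu_vec))
```

Indexing a and b by cid picks the intercept and slope that belong to each row's continent, which is what a[cid[i]] and b[cid[i]] do one observation at a time in Stan.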

Now write the R script.

rm(list=ls())
setwd("d:/tmp/rcode")
library(rstan)
library(rethinking)

data(rugged)
d = rugged; rm(rugged);
d$log_gdp = log(d$rgdppc_2000)
dd = d[complete.cases(d$rgdppc_2000), ]
dd$log_gdp_std = dd$log_gdp / mean(dd$log_gdp)
dd$rugged_std = dd$rugged / max(dd$rugged)
dd$cid = ifelse(dd$cont_africa == 1, 1, 2)

mydat = list(
  log_gdp = dd$log_gdp_std,
  rugged = dd$rugged_std, 
  cid = as.integer(dd$cid)
)
rm(d, dd)  # rugged was already removed above

wm1 = stan(file = 'ex9m3.stan', data = mydat, iter = 1000, warmup = 300)
wm2 = stan(file = 'ex9m3.stan', data = mydat, iter = 1000, warmup = 500)
wm3 = stan(file = 'ex9m3.stan', data = mydat, iter = 1000, warmup = 700)
print(wm1)
print(wm2)
print(wm3)

The results are as follows.

> print(wm1)
Inference for Stan model: ex9m3.
4 chains, each with iter=1000; warmup=300; thin=1; 
post-warmup draws per chain=700, total post-warmup draws=2800.

        mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
a[1]    0.89    0.00 0.02   0.86   0.88   0.89   0.90   0.92  3368    1
a[2]    1.05    0.00 0.01   1.03   1.04   1.05   1.06   1.07  3693    1
b[1]    0.13    0.00 0.08  -0.02   0.08   0.13   0.18   0.29  3187    1
b[2]   -0.14    0.00 0.06  -0.25  -0.18  -0.14  -0.10  -0.03  2937    1
sigma   0.11    0.00 0.01   0.10   0.11   0.11   0.12   0.12  3396    1
lp__  285.16    0.04 1.53 281.57 284.30 285.49 286.31 287.29  1356    1

Samples were drawn using NUTS(diag_e) at Wed Mar 31 15:42:23 2021.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

> print(wm2)
Inference for Stan model: ex9m3.
4 chains, each with iter=1000; warmup=500; thin=1; 
post-warmup draws per chain=500, total post-warmup draws=2000.

        mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
a[1]    0.89    0.00 0.02   0.86   0.88   0.89   0.90   0.92  2167    1
a[2]    1.05    0.00 0.01   1.03   1.04   1.05   1.06   1.07  2443    1
b[1]    0.13    0.00 0.08  -0.02   0.08   0.14   0.18   0.27  2309    1
b[2]   -0.14    0.00 0.06  -0.25  -0.18  -0.14  -0.11  -0.03  2260    1
sigma   0.11    0.00 0.01   0.10   0.11   0.11   0.12   0.12  1896    1
lp__  285.11    0.05 1.64 281.04 284.29 285.45 286.29 287.24   997    1

Samples were drawn using NUTS(diag_e) at Wed Mar 31 15:42:26 2021.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

> print(wm3)
Inference for Stan model: ex9m3.
4 chains, each with iter=1000; warmup=700; thin=1; 
post-warmup draws per chain=300, total post-warmup draws=1200.

        mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
a[1]    0.89    0.00 0.02   0.86   0.88   0.89   0.90   0.92  1311 1.00
a[2]    1.05    0.00 0.01   1.03   1.04   1.05   1.06   1.07  1651 1.00
b[1]    0.14    0.00 0.07  -0.01   0.09   0.14   0.19   0.28  1182 1.00
b[2]   -0.14    0.00 0.06  -0.25  -0.18  -0.14  -0.10  -0.03  1290 1.00
sigma   0.11    0.00 0.01   0.10   0.11   0.11   0.12   0.12  1430 1.00
lp__  285.19    0.07 1.59 281.26 284.33 285.57 286.41 287.26   600 1.01

Samples were drawn using NUTS(diag_e) at Wed Mar 31 15:42:29 2021.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

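n_eff decreases as warmup increases. With iter fixed at 1000, a longer warmup leaves fewer post-warmup draws (2800, 2000, and 1200 in total), so n_eff shrinks with them. The posterior means and Rhat are essentially the same in all three runs, which suggests that 300 warmup iterations are already enough for this model.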
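The post-warmup totals in the output headers follow from chains * (iter - warmup). The sketch below reproduces them in base R and compares them with the n_eff values for sigma copied from the three print() outputs; the variable names are illustrative only.

```r
chains <- 4                      # number of chains shown in the printed output
iter   <- 1000
warmup <- c(wm1 = 300, wm2 = 500, wm3 = 700)

# Draws kept after warmup, summed over all chains.
post_warmup <- chains * (iter - warmup)

# n_eff for sigma, copied from the three print() outputs.
n_eff_sigma <- c(wm1 = 3396, wm2 = 1896, wm3 = 1430)

print(rbind(post_warmup,
            n_eff_sigma,
            ratio = round(n_eff_sigma / post_warmup, 2)))
```

The ratio of n_eff to post-warmup draws stays near 1 in every run, so the drop in n_eff here mostly reflects the smaller number of retained draws rather than worse mixing.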
