贝叶斯网络入门和R代码实践视频教程

xxxspy 2024-01-04 11:27:59
Categories: Tags:

本教程旨在介绍使用贝叶斯网络学习和推理的基础知识,
如何使用R语言(bnlearn)完成一个贝叶斯网络结构建立和参数学习,
以介绍图算法建模的典型数据分析工作流程。 要点包括:

概率原理

在介绍贝叶斯网络之前, 需要先理解几个概念:

贝叶斯网络原理

定义

如果你对数学公式不感兴趣, 你可以跳转到下面的实战部分。

贝叶斯网络(Bayesian network)的定义:

所以, 一个网络的分布就是:

一个典型的有向无环图就是:

根据上面的公式, 我们的这个图暗含的概率是:

因为A节点没有父节点, 那么A发生的概率就是P(A)它是一个先验概率, 而节点E有两个父节点, 那么E的值的分布概率就是一个条件概率P(E|AS).

学习

所谓的学习就是从数据中学习网络的结构和参数, 所以学习分为两个阶段, 结构的学习和参数的学习。

节点就是变量, 哪个节点连接哪个节点, 这就是结构, 在贝叶斯网络构建之前, 我们不知道如何构建这个网络, 这用到的方法就是结构学习。
一旦有了网络结构, 下一步就是决定网络的参数是多少, 所谓的参数就是所有节点的联合概率分布。

推断

贝叶斯网络的推断就是当你知道网络中某个节点的取值时, 你可以推断其他节点的取值。这其实就是贝叶斯网络的用武之地。
通常我们有两种推断, 一种是条件概率的推断, 一种是极大后验概率的推断。

条件概率的推断(conditional probability)可以参考下图:

当你知道E节点的取值为uni时, 那么可以推断其他节点的条件概率。

极大后验概率推断可以参考下图:

当你知道某个节点的取值时, 你找到所有节点的取值, 使得这种组合发生的概率最大。

工具

我们这个教程主要是教大家如何使用R语言的“bnlearn”完成一个贝叶斯网络, 主要工具就是“bnlearn”。
bnlearn是一个专门用来学习贝叶斯网络的Python包,包含贝叶斯网络的结构学习、参数学习和推理三个方面的功能。其中,结构学习包含基于约束的算法、基于得分的算法和混合算法,参数学习包括最大似然估计和贝叶斯估计两种方法。此外,bnlearn还提供了自助法(bootstrap)、交叉验证(cross-validation)和随机模拟(stochastic simulation)等功能,附加的绘图功能需要调用前述的Rgraphviz和lattice包。

bnlearn 入门

bnlearn安装

使用R studio 或者任何其他R代码的编辑器, 运行如下命令进行安装:

1
install.packages("bnlearn")
输出(stream):
Warning message: "package 'bnlearn' is in use and will not be installed"

加载bnlearn到工作空间, 这样你才可以使用bnlearn:

1
library(bnlearn)

案例数据集

这篇教程基于数据集”transportation_survey”, 这个数据集可以在下面下载, 如果找不到可以联系我.

这个数据集非常简单, 每个变量解释如下:

看到这些变量, 你应该能想到, 这项研究的目的是探讨旅行方式受到哪些因素的影响, 我们经通过今天的教程, 探讨这个问题.

1
2
3
# 读取数据, 保存到survey这个数据框里
survey <- read.csv("transportation_survey.csv", header = TRUE)
head(survey)
输出(html):
A data.frame: 6 × 6
ASEORT
<chr><chr><chr><chr><chr><chr>
1adultFhighempsmalltrain
2youngMhighempbig car
3adultMuni empbig other
4old Funi empbig car
5youngFuni empbig car
6youngFuni empbig car

创建一个网络结构

创建一个网络, 使用字符来表示一个网络, 下面这段字符”[A][S][E|A:S][O|E][R|E][T|O:R]”看起来比较复杂, 我们解释一下:

1
2
dag = model2network("[A][S][E|A:S][O|E][R|E][T|O:R]")
dag
输出(plain):

Random/Generated Bayesian network

model:
[A][S][E|A:S][O|E][R|E][T|O:R]
nodes: 6
arcs: 6
undirected arcs: 0
directed arcs: 6
average markov blanket size: 2.67
average neighbourhood size: 2.00
average branching factor: 1.00

generation algorithm: Empty

除了使用字符串的方法构建网络, 还可以使用节点之间的关系, 例如下面的代码,
arc.set是一个矩阵, 每一行表示一个关系, 第一列表示父节点, 第二列表示子节点, 下面的代码与上面的代码表示的相同的网络:

1
2
3
4
5
6
7
8
9
10
11
arc.set = matrix(c("A", "E",
"S", "E",
"E", "O",
"E", "R",
"O", "T",
"R", "T"),
byrow = TRUE, ncol = 2,
dimnames = list(NULL, c("from", "to")))
dag = empty.graph(c("A", "S", "E", "O", "R", "T"))
arcs(dag) = arc.set
dag
输出(plain):

Random/Generated Bayesian network

model:
[A][S][E|A:S][O|E][R|E][T|O:R]
nodes: 6
arcs: 6
undirected arcs: 0
directed arcs: 6
average markov blanket size: 2.67
average neighbourhood size: 2.00
average branching factor: 1.00

generation algorithm: Empty

网络可视化

我们可以使用 plot 函数可视化一个网络, 但是如下你看到的, 这个图看起来没有规律,
比较乱, 这就是为什么我们不推荐使用 plot.

1
plot(dag)

那么我们还可以使用 graphviz.plot 这个函数,
这个函数不是来自于bnlearn, 它来自于 Rgraphviz , 所以我们先安装一下.

1
2
install.packages("BiocManager")
BiocManager::install("Rgraphviz")
输出(stream):
package 'BiocManager' successfully unpacked and MD5 sums checked The downloaded binary packages are in C:\Users\syd\AppData\Local\Temp\RtmpK8WHs0\downloaded_packages
输出(stream):
'getOption("repos")' replaces Bioconductor standard repositories, see 'help("repositories", package = "BiocManager")' for details. Replacement repositories: CRAN: https://cran.r-project.org Bioconductor version 3.18 (BiocManager 1.30.22), R 4.3.1 (2023-06-16 ucrt) Warning message: "package(s) not installed when version(s) same as or greater than current; use `force = TRUE` to re-install: 'Rgraphviz'" Old packages: 'askpass', 'bookdown', 'brio', 'bslib', 'cli', 'cluster', 'cowplot', 'cpp11', 'curl', 'data.table', 'datawizard', 'DBI', 'dbplyr', 'desc', 'dplyr', 'effectsize', 'emmeans', 'evaluate', 'fansi', 'fontawesome', 'foreign', 'ggplot2', 'ggrepel', 'glue', 'gtable', 'haven', 'htmltools', 'insight', 'jsonlite', 'KernSmooth', 'knitr', 'labeling', 'lattice', 'lifecycle', 'lme4', 'lubridate', 'markdown', 'Matrix', 'MatrixModels', 'mgcv', 'minqa', 'mvtnorm', 'nlme', 'openssl', 'parameters', 'performance', 'pkgload', 'plyr', 'prettyunits', 'processx', 'progress', 'psych', 'ragg', 'Rcpp', 'RcppEigen', 'rematch', 'rlang', 'rmarkdown', 'rpart', 'rprojroot', 'sass', 'scales', 'spatial', 'stringi', 'stringr', 'survival', 'systemfonts', 'testthat', 'textshaping', 'tinytex', 'utf8', 'vctrs', 'vroom', 'waldo', 'withr', 'xfun', 'xml2', 'yaml'

加载 Rgraphviz:

1
library("Rgraphviz")
1
2
3
# 可视化

graphviz.plot(dag, layout = "dot")

你现在看到的图是不是已经好多了, 你还可以更改布局:

1
graphviz.plot(dag, layout = "fdp")
1
graphviz.plot(dag, layout = "circo")

更改颜色:

1
2
3
hlight <- list(nodes = nodes(dag), arcs = arcs(dag),
col = "blue", textCol = "blue")
pp <- graphviz.plot(dag, highlight = hlight)
1
2
3
4
5
6
7
# 箭头颜色也可以更改:

edgeRenderInfo(pp) <- list(col = c("S~E" = "black", "E~R" = "black"),
lwd = c("S~E" = 3, "E~R" = 3))
# 需要重新渲染才能看到图

renderGraph(pp)

根据数据求解图结构

图结构是可以由数据学习而来, 研究者可能对最初是的图结构没有预先定义,
他也不知道使用什么图结构比较好, 那么这时候可是用函数 hc 来求解图结构:

1
2
3
4
5
6
7
8
9
10
survey$A = as.factor(survey$A)
survey$R = as.factor(survey$R)
survey$E = as.factor(survey$E)
survey$O = as.factor(survey$O)
survey$S = as.factor(survey$S)
survey$T = as.factor(survey$T)

dag = hc(survey)

dag
输出(plain):

Bayesian network learned via Score-based methods

model:
[A][S][O][E|A:S][R|E][T|R]
nodes: 6
arcs: 4
undirected arcs: 0
directed arcs: 4
average markov blanket size: 1.67
average neighbourhood size: 1.33
average branching factor: 0.67

learning algorithm: Hill-Climbing
score: BIC (disc.)
penalization coefficient: 3.107304
tests used in the learning procedure: 35
optimized: TRUE
1
2
3
4
# 可视化

graphviz.plot(dag, layout = "dot")

你可以看到, 这个网络结构与我们预先定义的网络结构有些相似, 但是O节点没啥用,
也就是这个变量似乎与其他变量没什么关系, 这时候你注意,
如果研究者认为O变量很重要, 可以重新加入到你的模型中,
如果O可有可无, 就应该从模型中删除,
不必完全相信算法给你的结果.

参数学习

在继续讲解bnlearn的函数之前, 需要先理解今天的案例用到的图模型是分类节点构成的模型,
那么模型的参数就是分类变量的条件分布概率, 如果你学过概率论应该知道, 我们可以用表格来表示条件分布概率:

1
2
3
4
##       E
## O high uni
## emp 0.98 0.92
## self 0.04 0.08

上面的概率是如何计算的呢, 请看公式:

$$ \hat{Pr}(O = emp | E = high) = \frac{\hat{Pr}(O = emp, E = high)}{\hat{Pr}(E = high)}= \frac{\text{number of observations for which O = emp and E = high}}{\text{number of observations for which E = high}} $$

用通俗的语言来说就是, 在 E为high的条件下, O为emp的概率是0.98, 在你的数据中, E为high同时O为emp的样本数除以E为high的样本数就是这个条件概率.
手动计算这个概率就是:

1
2
3
4
5
6
N.E.high = nrow(subset(survey, E=="high"))
N.E.high.O.emp = nrow(subset(survey, E=="high" & O=="emp"))

N.E.high.O.emp / N.E.high


输出(html):
0.982758620689655

bnlearn 提供了 函数 bn.fit 帮助我们求所有的条件概率:

1
2
fitted = bn.fit(dag, survey)
fitted
输出(plain):

Bayesian network parameters

Parameters of node A (multinomial distribution)

Conditional probability table:
adult old young
0.388 0.138 0.474

Parameters of node S (multinomial distribution)

Conditional probability table:
F M
0.522 0.478

Parameters of node E (multinomial distribution)

Conditional probability table:

, , S = F

A
E adult old young
high 0.5188679 0.8717949 0.1206897
uni 0.4811321 0.1282051 0.8793103

, , S = M

A
E adult old young
high 0.7613636 0.9000000 0.7685950
uni 0.2386364 0.1000000 0.2314050


Parameters of node O (multinomial distribution)

Conditional probability table:
emp self
0.98 0.02

Parameters of node R (multinomial distribution)

Conditional probability table:

E
R high uni
big 0.75862069 0.93809524
small 0.24137931 0.06190476

Parameters of node T (multinomial distribution)

Conditional probability table:

R
T big small
car 0.69784173 0.53012048
other 0.13908873 0.08433735
train 0.16306954 0.38554217

这个 fitted 包含了所有条件概率, 从输出中你已经看到.
下面我们可以看下这个模型的预测效果:

1
2
3
4
# 使用fitted 得到的参数预测survey中的T节点, 可以看下预测效果
pred = predict(fitted, data=survey, node="T")

table(pred, survey$T)
输出(plain):

pred car other train
car 335 65 100
other 0 0 0
train 0 0 0

结果是挺差的, 但是这个结果可以看出, 我们的模型把所有样本都预测为car, 因为T是一个分布很不均匀的变量,
car占的比例过高, 导致模型倾向于预测分布概率较高的分类. 另外导致这个问题的还有模型的结构, 你可以看下模型的结构:

对T有直接影响的就只有R, 也就是说预测T的时候, 只用到了R的取值, 信息损失严重, 所以我们有必要对模型进行修正.

1
graphviz.plot(dag, layout = "dot")

hc 函数提供了白名单功能, 也就是白名单之中的关系将被保留到模型中, 不受算法的影响, 我们为了保留尽量多的信息,
将所有节点对T的关系都加入白名单, 看代码:

1
2
3
4
5
6
7
8
9
10
11
survey <- read.csv("transportation_survey.csv", header = TRUE)
whitelist = data.frame(from=c("A", "E", "O", "S", "R"), to=c("T", "T", "T", "T", "T"))
survey$A = as.factor(survey$A)
survey$R = as.factor(survey$R)
survey$E = as.factor(survey$E)
survey$O = as.factor(survey$O)
survey$S = as.factor(survey$S)
survey$T = as.factor(survey$T)

dag = hc(survey, whitelist=whitelist)
plot(dag)

新的模型结构好多了, 我们看下模型预测的结果, 已经有很多样本被预测为非car意外的类别,
当然模型还有很大的提升空间, 但是我们先到此位置, 作为一个入门级教程, 不适于再深入.

1
2
3
4
fitted = bn.fit(dag, data=survey)
pred = predict(fitted, data=survey, node="T")

table(pred, survey$T)
输出(plain):

pred car other train
car 323 60 78
other 1 3 1
train 11 2 21

除了从数据中学习模型参数,
研究者可以自己提供参数, 使用 custom.fit 函数, 这个需要你手动设定所有的条件概率, 例如下面代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
dag2 = model2network("[A][S][E|A:S][O|E][R|E][T|O:R]")
A.lv = c("young", "adult", "old")
S.lv = c("M", "F")
E.lv = c("high", "uni")
O.lv = c("emp", "self")
R.lv = c("big","small")
T.lv = c("car", "other", "train")
A.prob = array(c(0.30, 0.50, 0.20), dim = 3, dimnames = list(A = A.lv))
S.prob = array(c(0.60, 0.40), dim = 2, dimnames = list(S = S.lv))
O.prob = array(c(0.96, 0.04, 0.92, 0.08), dim = c(2, 2),
dimnames = list(O = O.lv, E = E.lv))
R.prob = array(c(0.25, 0.75, 0.20, 0.80), dim = c(2, 2),
dimnames = list(R = R.lv, E = E.lv))
E.prob = array(c(0.75, 0.25, 0.72, 0.28, 0.88, 0.12, 0.64,
0.36, 0.70, 0.30, 0.90, 0.10), dim = c(2, 3, 2),
dimnames = list(E = E.lv, A = A.lv, S = S.lv))

T.prob = array(c(0.48, 0.42, 0.10, 0.56, 0.36, 0.08, 0.58,
0.24, 0.18, 0.70, 0.21, 0.09), dim = c(3, 2, 2),
dimnames = list(T = T.lv, O = O.lv, R = R.lv))
cpt = list(A = A.prob, S = S.prob, E = E.prob, O = O.prob, R = R.prob, T = T.prob)
custom.fitted = custom.fit(dag2, cpt)

custom.fitted
输出(plain):

Bayesian network parameters

Parameters of node A (multinomial distribution)

Conditional probability table:
A
young adult old
0.3 0.5 0.2

Parameters of node E (multinomial distribution)

Conditional probability table:

, , S = M

A
E young adult old
high 0.75 0.72 0.88
uni 0.25 0.28 0.12

, , S = F

A
E young adult old
high 0.64 0.70 0.90
uni 0.36 0.30 0.10


Parameters of node O (multinomial distribution)

Conditional probability table:

E
O high uni
emp 0.96 0.92
self 0.04 0.08

Parameters of node R (multinomial distribution)

Conditional probability table:

E
R high uni
big 0.25 0.20
small 0.75 0.80

Parameters of node S (multinomial distribution)

Conditional probability table:
S
M F
0.6 0.4

Parameters of node T (multinomial distribution)

Conditional probability table:

, , R = big

O
T emp self
car 0.48 0.56
other 0.42 0.36
train 0.10 0.08

, , R = small

O
T emp self
car 0.58 0.70
other 0.24 0.21
train 0.18 0.09

这个预测的效果也不咋地:

1
2
3
4

pred = predict(custom.fitted, data=survey, node="T")

table(pred, survey$T)
输出(plain):

pred car other train
car 335 65 100
other 0 0 0
train 0 0 0

从模型中采样

当你得到了包含参数的网络模型, 例如上面代码中的 fitted, 你可以做很多事情,
例如, 你可以从这个网络中生成样本, 使用 rbn 函数生成10个样本:

1
2
sample.data = rbn(fitted, n = 100)
head(sample.data)
输出(html):
A data.frame: 6 × 6
ASEORT
<fct><fct><fct><fct><fct><fct>
1youngMhighempsmalltrain
2youngMhighempbig car
3youngFhighempbig car
4adultFhighempsmallcar
5adultFhighempbig other
6youngFuni empbig car

预测节点

我们可以对节点(变量)的取值进行预测, 例如, 假设知道了除T以外的所有节点(变量)的取值,
可以对T变量进行预测:

1
2
3
4
5
bn.bayes <- bn.fit(dag, data = survey)

pred = predict(bn.bayes, node = "T", data = survey)

pred
输出(html):
  1. car
  2. car
  3. car
  4. other
  5. car
  6. car
  7. car
  8. car
  9. car
  10. car
  11. car
  12. car
  13. car
  14. train
  15. car
  16. car
  17. car
  18. car
  19. car
  20. car
  21. car
  22. car
  23. car
  24. car
  25. car
  26. car
  27. car
  28. car
  29. train
  30. train
  31. car
  32. car
  33. car
  34. car
  35. car
  36. car
  37. car
  38. car
  39. train
  40. car
  41. car
  42. train
  43. car
  44. car
  45. car
  46. car
  47. car
  48. car
  49. car
  50. car
  51. car
  52. train
  53. car
  54. car
  55. car
  56. car
  57. car
  58. car
  59. car
  60. car
  61. car
  62. car
  63. car
  64. car
  65. car
  66. car
  67. car
  68. car
  69. car
  70. car
  71. train
  72. car
  73. car
  74. car
  75. car
  76. car
  77. car
  78. car
  79. car
  80. car
  81. car
  82. car
  83. car
  84. car
  85. car
  86. car
  87. car
  88. train
  89. car
  90. car
  91. car
  92. car
  93. car
  94. car
  95. car
  96. car
  97. car
  98. car
  99. train
  100. car
  101. car
  102. car
  103. car
  104. car
  105. car
  106. car
  107. car
  108. car
  109. car
  110. car
  111. car
  112. car
  113. car
  114. car
  115. car
  116. car
  117. car
  118. car
  119. car
  120. car
  121. car
  122. car
  123. car
  124. car
  125. car
  126. car
  127. car
  128. car
  129. car
  130. car
  131. car
  132. car
  133. train
  134. car
  135. car
  136. car
  137. car
  138. car
  139. car
  140. car
  141. car
  142. car
  143. train
  144. car
  145. car
  146. car
  147. car
  148. car
  149. car
  150. car
  151. car
  152. train
  153. train
  154. car
  155. car
  156. car
  157. car
  158. car
  159. car
  160. car
  161. car
  162. car
  163. car
  164. car
  165. car
  166. car
  167. car
  168. car
  169. car
  170. car
  171. car
  172. car
  173. car
  174. car
  175. car
  176. car
  177. car
  178. car
  179. car
  180. car
  181. car
  182. car
  183. car
  184. car
  185. car
  186. car
  187. car
  188. car
  189. car
  190. car
  191. car
  192. car
  193. car
  194. car
  195. car
  196. car
  197. car
  198. car
  199. car
  200. car
  201. car
  202. car
  203. car
  204. car
  205. car
  206. car
  207. car
  208. car
  209. car
  210. car
  211. car
  212. car
  213. car
  214. car
  215. car
  216. car
  217. car
  218. car
  219. car
  220. car
  221. car
  222. car
  223. car
  224. car
  225. car
  226. car
  227. car
  228. car
  229. car
  230. car
  231. car
  232. train
  233. car
  234. car
  235. car
  236. car
  237. car
  238. car
  239. car
  240. car
  241. car
  242. car
  243. car
  244. car
  245. car
  246. car
  247. car
  248. car
  249. car
  250. car
  251. car
  252. car
  253. car
  254. car
  255. car
  256. car
  257. car
  258. car
  259. car
  260. car
  261. car
  262. car
  263. car
  264. train
  265. car
  266. car
  267. car
  268. car
  269. car
  270. car
  271. car
  272. car
  273. car
  274. car
  275. car
  276. car
  277. train
  278. car
  279. car
  280. train
  281. car
  282. car
  283. car
  284. car
  285. car
  286. car
  287. car
  288. car
  289. car
  290. car
  291. car
  292. car
  293. car
  294. car
  295. car
  296. car
  297. car
  298. car
  299. car
  300. car
  301. car
  302. train
  303. car
  304. car
  305. car
  306. car
  307. car
  308. car
  309. car
  310. car
  311. car
  312. car
  313. car
  314. car
  315. car
  316. car
  317. car
  318. car
  319. train
  320. car
  321. car
  322. car
  323. car
  324. car
  325. car
  326. car
  327. car
  328. car
  329. car
  330. car
  331. car
  332. car
  333. car
  334. car
  335. car
  336. car
  337. car
  338. other
  339. car
  340. car
  341. train
  342. car
  343. car
  344. car
  345. car
  346. car
  347. car
  348. train
  349. car
  350. car
  351. car
  352. car
  353. car
  354. car
  355. car
  356. car
  357. car
  358. car
  359. car
  360. car
  361. car
  362. car
  363. car
  364. car
  365. car
  366. car
  367. car
  368. car
  369. car
  370. car
  371. car
  372. car
  373. car
  374. car
  375. car
  376. other
  377. car
  378. car
  379. car
  380. car
  381. car
  382. car
  383. car
  384. car
  385. car
  386. car
  387. car
  388. car
  389. other
  390. car
  391. car
  392. car
  393. car
  394. car
  395. car
  396. car
  397. car
  398. train
  399. train
  400. car
Levels:
  1. 'car'
  2. 'other'
  3. 'train'

或者可以计算某种情况发生的概率, 例如 当节点E取值为high, 想要知道 S == “M” 并且 T == “car” 这种情况的概率:

1
2
3

cpquery(fitted, event = (S == "M") & (T == "car"), evidence = (E == "high"))

输出(html):
0.451535012073129

高级教程

这篇教程就是一个贝叶斯网络的基础入门教程,
为了让童鞋们在真实研究中使用贝叶斯网络,
并且可以撰写高质量的分析报告,
我们还开发了贝叶斯网络的高阶视频教程,
这篇教程以论文<贝叶斯网络在老年抑郁症危险…HARLS数据库的实证分析>为模板,
视频演示了论文中用到的所有技术, 并且教你如何整理成论文中用到的数据格式.

数据下载

数据下载链接:https://pan.baidu.com/s/1eVL7JYL5XKxOkvJcmKgMrw?pwd=9xqp
提取码:9xqp

注意
统计咨询请加QQ 2726725926, 微信 shujufenxidaizuo, SPSS统计咨询是收费的, 不论什么模型都可以, 只限制于1个研究内.
跟我学统计可以代做分析, 每单几百元不等.
本文由jupyter notebook转换而来, 您可以在这里下载notebook
可以在微博上@mlln-cn向我免费题问
请记住我的网址: mlln.cn 或者 jupyter.cn