-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_factorgraph.py
More file actions
48 lines (47 loc) · 1.38 KB
/
inference_factorgraph.py
File metadata and controls
48 lines (47 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from create_factorgraph import create_fg
from numbskull import NumbSkull
# Build the factor-graph components and bundle them in the positional order
# that loadFactorGraph() is called with below.
weight, variable, factor, fmap, domain_mask, edges = create_fg()
subgraph = (weight, variable, factor, fmap, domain_mask, edges)

# Options for the NumbSkull instance used for parameter learning.
learning_options = {
    "n_inference_epoch": 1000,
    "n_learning_epoch": 1000,
    "stepsize": 0.01,
    "decay": 0.95,
    "reg_param": 1e-6,
    "regularization": 2,
    "truncation": 10,
    "quiet": True,
    "verbose": False,
    "learn_non_evidence": False,  # need to test
    "sample_evidence": False,
    "burn_in": 10,
    "nthreads": 1,
}
ns_learing = NumbSkull(**learning_options)
ns_learing.loadFactorGraph(*subgraph)

# Learn the factor-graph parameters.
ns_learing.learning()
# Prepare for factor-graph inference: once parameter learning has finished,
# mark every weight as fixed and copy over the value held by the learning
# instance's first factor graph.
# NOTE(review): this assumes learning writes the learned values back into
# factorGraphs[0].weight[...]['initialValue'] -- confirm against numbskull.
for position, weight_entry in enumerate(weight):
    weight_entry["isFixed"] = True
    learned_entry = ns_learing.factorGraphs[0].weight[position]
    weight_entry["initialValue"] = learned_entry["initialValue"]
# Options for the NumbSkull instance used for inference; note the step size
# here is 0.001.
inference_options = {
    "n_inference_epoch": 1000,
    "n_learning_epoch": 1000,
    "stepsize": 0.001,
    "decay": 0.95,
    "reg_param": 1e-6,
    "regularization": 2,
    "truncation": 10,
    "quiet": True,
    "verbose": False,
    "learn_non_evidence": False,  # need to test
    "sample_evidence": False,
    "burn_in": 10,
    "nthreads": 1,
}
ns_inference = NumbSkull(**inference_options)
ns_inference.loadFactorGraph(*subgraph)

# Run factor-graph inference.
ns_inference.inference()

# Fetch the inference results for the variables and print them.
marginals = ns_inference.factorGraphs[0].marginals
print(marginals)