10/19 09:39

# 二、BNT参数学习(MATLAB)

#模型设置#

``````N=4;
dag=zeros(N,N);
C=1;S=2;R=3;W=4;
dag(C,[R S])=1;
dag(R,W)=1;
dag(S,W)=1;``````

#生成多项式条件概率#

``````false = 1; true = 2;
ns = 2*ones(1,N); % binary nodes 节点状态数
figure
draw_graph(dag);
bnet = mk_bnet(dag, ns);
bnet.CPD{

C} = tabular_CPD(bnet, C, [0.5 0.5]);
bnet.CPD{

R} = tabular_CPD(bnet, R, [0.8 0.2 0.2 0.8]);
bnet.CPD{

S} = tabular_CPD(bnet, S, [0.5 0.9 0.5 0.1]);
bnet.CPD{

W} = tabular_CPD(bnet, W, [1 0.1 0.1 0.01 0 0.9 0.9 0.99]);

bnet.CPD{

W}

CPT = cell(1,N);
for i=1:N
s=struct(bnet.CPD{

i});  % 创建或转换为结构数组。
CPT{

i}=s.CPT;
end``````

#构造样本数据#

``````nsamples =5000;
samples = cell(N, nsamples); %创建单元格数组

for i=1:nsamples
samples(:,i) = sample_bnet(bnet); %SAMPLE_BNET从贝叶斯网络生成随机样本。
end
data = cell2num(samples); %CELL2NUM将2D单元格数组转换为2D数字数组``````

#建立贝叶斯网络#

``bnet2 = mk_bnet(dag, ns);  ``

#手动构造条件概率表cpt#

``````bnet2.CPD{

C} = tabular_CPD(bnet2, C, 'clamped', 1, 'CPT', [0.5 0.5], ...
'prior_type', 'dirichlet', 'dirichlet_weight', 0);
bnet2.CPD{

R} = tabular_CPD(bnet2, R, 'prior_type', 'dirichlet', 'dirichlet_weight', 0);
bnet2.CPD{

S} = tabular_CPD(bnet2, S, 'prior_type', 'dirichlet', 'dirichlet_weight', 0);
bnet2.CPD{

W} = tabular_CPD(bnet2, W, 'prior_type', 'dirichlet', 'dirichlet_weight', 0); % tabular_CPD生成多项式条件概率``````

#显示估计的参数#

``````Parameter_MLE=bnet2;
CPT_MLE=cell(1,N);
for i=1:N
s=struct(Parameter_MLE.CPD{

i});
CPT_MLE{

i}=s.CPT;
end

Parameter_MLE_W = CPT_MLE{

4};``````

#从完全观察到的数据中查找MLE#
#先验0的贝叶斯更新等效于ML估计#

``````% Find MLEs from fully observed data 从完全观察到的数据中查找MLE
bnet4 = learn_params(bnet2, samples); %LEARN_PARAMS查找完全观察的模型的最大似然参数

% Bayesian updating with 0 prior is equivalent to ML estimation 先验0的贝叶斯更新等效于ML估计
bnet5 = bayes_update_params(bnet2, samples);  %给定完全观察到的数据，BAYES_UPDATE_PARAMS贝叶斯参数更新``````

#显示学习参数结果#

``````% MLE
CPT4 = cell(1,N);
for i=1:N
s=struct(bnet4.CPD{

i});  % violate object privacy
CPT4{

i}=s.CPT ;
end
CPT4{

4}

% Bayesian
CPT5 = cell(1,N);
for i=1:N
s=struct(bnet5.CPD{

i});  % violate object privacy
CPT5{

i}=s.CPT ;
assert(approxeq(CPT5{

i}, CPT4{

i}));
end
CPT5{

4}
``````

T=cputime;

E=cputime-T;
disp(E)

0
0 收藏

0 评论
0 收藏
0