# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Bayesian Network """

import mindspore.nn as nn

import mindspore.nn.probability.distribution as msd
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class BayesianNet(nn.Cell):
    """
    Base cell for Bayesian networks.

    We currently support 3 types of variables: x = observation, z = latent, y = condition.
    A Bayesian Network models a generative process for certain variables:
    p(x,z|y) or p(z|x,y) or p(x|z,y).

    Subclasses implement `construct` and use the `normal` / `bernoulli` helpers,
    which either draw a sample (when no observation is given) or use the given
    observation, and return the pair (sample, log_prob).
    """

    def __init__(self):
        super().__init__()
        self.normal_dist = msd.Normal(dtype=mstype.float32)
        self.bernoulli_dist = msd.Bernoulli(dtype=mstype.float32)

        self.reduce_sum = P.ReduceSum(keep_dims=True)
        # Used by the reparameterized path of `normal` to build a standard
        # normal whose parameters match the shape/dtype of `mean` and `std`.
        # These were previously referenced but never defined.
        self.zeros = P.ZerosLike()
        self.ones = P.OnesLike()

    def normal(self,
               name,
               observation=None,
               mean=None,
               std=None,
               seed=0,
               dtype=mstype.float32,
               shape=(),
               reparameterize=True):
        """
        Normal distribution wrapper.

        Args:
            name (str): Name of the variable. It is only type-checked here.
            observation (Tensor): Observed value. If None, a sample is drawn instead.
            mean (Tensor): Mean of the distribution.
            std (Tensor): Standard deviation of the distribution.
            seed (int): Currently unused.
            dtype: Currently unused; the underlying distribution is created with
                float32 in `__init__`.
            shape (tuple): Sample shape passed to the distribution's 'sample' call.
            reparameterize (bool): If True, the sample is computed as
                mean + std * epsilon with epsilon drawn from a standard normal,
                so it is an explicit function of `mean` and `std`.

        Returns:
            tuple: (sample, log_prob). `log_prob` is the log-probability of
            `sample`, summed over axis 1 with the reduced axis kept.

        Raises:
            TypeError: If `name` is not a string.
        """

        if not isinstance(name, str):
            raise TypeError("The type of `name` should be string")

        if observation is None:
            if reparameterize:
                # Reparameterization trick: epsilon ~ N(0, 1) with parameters
                # shaped like mean/std, then shift and scale.
                epsilon = self.normal_dist('sample', shape, self.zeros(mean), self.ones(std))
                sample = mean + std * epsilon
            else:
                sample = self.normal_dist('sample', shape, mean, std)
        else:
            sample = observation

        # Sum over axis 1 (keep_dims=True). NOTE(review): assumes inputs are at
        # least 2-D with the batch on axis 0 — confirm against callers.
        log_prob = self.reduce_sum(self.normal_dist(
            'log_prob', sample, mean, std), 1)
        return sample, log_prob

    def bernoulli(self,
                  name,
                  observation=None,
                  probs=None,
                  seed=0,
                  dtype=mstype.float32,
                  shape=()):
        """
        Bernoulli distribution wrapper.

        Args:
            name (str): Name of the variable. It is only type-checked here.
            observation (Tensor): Observed value. If None, a sample is drawn instead.
            probs (Tensor): Probability parameter of the distribution.
            seed (int): Currently unused.
            dtype: Currently unused; the underlying distribution is created with
                float32 in `__init__`.
            shape (tuple): Sample shape passed to the distribution's 'sample' call.

        Returns:
            tuple: (sample, log_prob). `log_prob` is the log-probability of
            `sample`, summed over axis 1 with the reduced axis kept.

        Raises:
            TypeError: If `name` is not a string.
        """

        if not isinstance(name, str):
            raise TypeError("The type of `name` should be string")

        if observation is None:
            sample = self.bernoulli_dist('sample', shape, probs)
        else:
            sample = observation

        # Same axis-1 reduction as `normal`; same rank assumption applies.
        log_prob = self.reduce_sum(
            self.bernoulli_dist('log_prob', sample, probs), 1)
        return sample, log_prob

    def construct(self, *inputs, **kwargs):
        """
        Define the generative process; must be implemented by subclasses.

        The expected inputs are 3 variables, in this order:
            x: data sample, observation
            z: latent variable
            y: conditional information

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError