• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" Bayesian Network """
16
17import mindspore.nn as nn
18
19import mindspore.nn.probability.distribution as msd
20from mindspore.common import dtype as mstype
21from mindspore.ops import operations as P
22
23
class BayesianNet(nn.Cell):
    """
    Base cell for Bayesian networks built from Normal and Bernoulli nodes.

    We currently support 3 types of variables: x = observation, z = latent,
    y = condition. A Bayesian Network models a generative process for certain
    variables: p(x,z|y) or p(z|x,y) or p(x|z,y).

    Subclasses implement `construct` and create nodes through `normal` and
    `bernoulli`. Each wrapper returns the (sampled or observed) value together
    with its log-probability summed over axis 1, with that axis kept.
    """

    def __init__(self):
        super().__init__()
        # Both distributions are created without default parameters, so every
        # call must pass them explicitly.
        self.normal_dist = msd.Normal(dtype=mstype.float32)
        self.bernoulli_dist = msd.Bernoulli(dtype=mstype.float32)

        self.reduce_sum = P.ReduceSum(keep_dims=True)
        # Used by the reparameterized Normal path to build standard-normal
        # parameters (0 and 1) shaped and typed like `mean` / `std`.
        self.zeros_like = P.ZerosLike()
        self.ones_like = P.OnesLike()

    def normal(self,
               name,
               observation=None,
               mean=None,
               std=None,
               seed=0,
               dtype=mstype.float32,
               shape=(),
               reparameterize=True):
        """
        Normal distribution wrapper.

        Args:
            name (str): Name of the node. Only its type is validated here.
            observation (Tensor): Observed value. If None, a sample is drawn.
            mean (Tensor): Mean of the Normal. Required.
            std (Tensor): Standard deviation of the Normal. Required.
            seed (int): Currently unused; kept for interface compatibility.
            dtype: Currently unused; kept for interface compatibility. The
                underlying distribution is always created with float32.
            shape (tuple): Sample shape forwarded to the distribution's
                'sample' call.
            reparameterize (bool): When True and no observation is given, draw
                standard-normal noise and return `mean + std * noise`, so the
                sample stays differentiable with respect to `mean` and `std`.

        Returns:
            tuple: (sample, log_prob). `log_prob` is the Normal log-density of
            `sample`, summed over axis 1 with that axis kept.

        Raises:
            TypeError: If `name` is not a string.
            ValueError: If `mean` or `std` is None.
        """
        if not isinstance(name, str):
            raise TypeError("The type of `name` should be string")
        # log_prob below always needs both parameters, so fail early and clearly.
        if mean is None or std is None:
            raise ValueError("`mean` and `std` must be provided")

        if observation is None:
            if reparameterize:
                # Reparameterization trick: sample eps ~ N(0, 1), then shift
                # and scale it rather than sampling N(mean, std) directly.
                epsilon = self.normal_dist('sample', shape,
                                           self.zeros_like(mean),
                                           self.ones_like(std))
                sample = mean + std * epsilon
            else:
                sample = self.normal_dist('sample', shape, mean, std)
        else:
            sample = observation

        # NOTE(review): assumes axis 1 is the per-example feature axis; confirm.
        log_prob = self.reduce_sum(
            self.normal_dist('log_prob', sample, mean, std), 1)
        return sample, log_prob

    def bernoulli(self,
                  name,
                  observation=None,
                  probs=None,
                  seed=0,
                  dtype=mstype.float32,
                  shape=()):
        """
        Bernoulli distribution wrapper.

        Args:
            name (str): Name of the node. Only its type is validated here.
            observation (Tensor): Observed value. If None, a sample is drawn.
            probs (Tensor): Bernoulli probabilities. Required.
            seed (int): Currently unused; kept for interface compatibility.
            dtype: Currently unused; kept for interface compatibility. The
                underlying distribution is always created with float32.
            shape (tuple): Sample shape forwarded to the distribution's
                'sample' call.

        Returns:
            tuple: (sample, log_prob). `log_prob` is the Bernoulli
            log-probability of `sample`, summed over axis 1 with that axis kept.

        Raises:
            TypeError: If `name` is not a string.
            ValueError: If `probs` is None.
        """
        if not isinstance(name, str):
            raise TypeError("The type of `name` should be string")
        # log_prob below always needs `probs`, so fail early and clearly.
        if probs is None:
            raise ValueError("`probs` must be provided")

        if observation is None:
            sample = self.bernoulli_dist('sample', shape, probs)
        else:
            sample = observation

        # NOTE(review): assumes axis 1 is the per-example feature axis; confirm.
        log_prob = self.reduce_sum(
            self.bernoulli_dist('log_prob', sample, probs), 1)
        return sample, log_prob

    def construct(self, *inputs, **kwargs):
        """
        Define the network's generative process; must be overridden.

        We currently fix the parameters of the construct function. The inputs
        must consist of 3 variables, in this order:
            x: data sample (observation)
            z: latent variable
            y: conditional information

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError
96