Home
last modified time | relevance | path

Searched refs:probs_b (Results 1 – 4 of 4) sorted by relevance

/third_party/mindspore/tests/ut/python/nn/probability/distribution/
Dtest_geometric.py134 def construct(self, probs_b, probs_a): argument
135 kl1 = self.g1.kl_loss('Geometric', probs_b)
136 kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
145 probs_b = Tensor([0.3], dtype=dtype.float32)
147 ans = ber_net(probs_b, probs_a)
161 def construct(self, probs_b, probs_a): argument
162 h1 = self.g1.cross_entropy('Geometric', probs_b)
163 h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
172 probs_b = Tensor([0.3], dtype=dtype.float32)
174 ans = net(probs_b, probs_a)
Dtest_categorical.py158 def construct(self, probs_b, probs_a): argument
159 kl1 = self.c1.kl_loss('Categorical', probs_b)
160 kl2 = self.c2.kl_loss('Categorical', probs_b, probs_a)
169 probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32)
171 ans = ber_net(probs_b, probs_a)
185 def construct(self, probs_b, probs_a): argument
186 h1 = self.c1.cross_entropy('Categorical', probs_b)
187 h2 = self.c2.cross_entropy('Categorical', probs_b, probs_a)
196 probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32)
198 ans = net(probs_b, probs_a)
Dtest_bernoulli.py134 def construct(self, probs_b, probs_a): argument
135 kl1 = self.b1.kl_loss('Bernoulli', probs_b)
136 kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
145 probs_b = Tensor([0.3], dtype=dtype.float32)
147 ans = ber_net(probs_b, probs_a)
161 def construct(self, probs_b, probs_a): argument
162 h1 = self.b1.cross_entropy('Bernoulli', probs_b)
163 h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
172 probs_b = Tensor([0.3], dtype=dtype.float32)
174 ans = net(probs_b, probs_a)
/third_party/mindspore/mindspore/nn/probability/distribution/
Dcategorical.py229 def _kl_loss(self, dist, probs_b, probs=None): argument
239 probs_b = self._check_value(probs_b, 'probs_b')
240 probs_b = self.cast(probs_b, self.parameter_type)
243 logits_b = self.log(probs_b)
247 def _cross_entropy(self, dist, probs_b, probs=None): argument
257 return self._entropy(probs) + self._kl_loss(dist, probs_b, probs)