Home
last modified time | relevance | path

Searched refs:probs_a (Results 1 – 4 of 4) sorted by relevance

/third_party/mindspore/tests/ut/python/nn/probability/distribution/
Dtest_geometric.py134 def construct(self, probs_b, probs_a): argument
136 kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
146 probs_a = Tensor([0.7], dtype=dtype.float32)
147 ans = ber_net(probs_b, probs_a)
161 def construct(self, probs_b, probs_a): argument
163 h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
173 probs_a = Tensor([0.7], dtype=dtype.float32)
174 ans = net(probs_b, probs_a)
Dtest_categorical.py158 def construct(self, probs_b, probs_a): argument
160 kl2 = self.c2.kl_loss('Categorical', probs_b, probs_a)
170 probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
171 ans = ber_net(probs_b, probs_a)
185 def construct(self, probs_b, probs_a): argument
187 h2 = self.c2.cross_entropy('Categorical', probs_b, probs_a)
197 probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
198 ans = net(probs_b, probs_a)
Dtest_bernoulli.py134 def construct(self, probs_b, probs_a): argument
136 kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
146 probs_a = Tensor([0.7], dtype=dtype.float32)
147 ans = ber_net(probs_b, probs_a)
161 def construct(self, probs_b, probs_a): argument
163 h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
173 probs_a = Tensor([0.7], dtype=dtype.float32)
174 ans = net(probs_b, probs_a)
/third_party/mindspore/mindspore/nn/probability/distribution/
Dcategorical.py241 probs_a = self._check_param_type(probs)
242 logits_a = self.log(probs_a)