• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import pytest
16import numpy as np
17import mindspore.nn as nn
18import mindspore.ops.operations as P
19import mindspore.nn.probability.distribution as msd
20from mindspore import context, Tensor
21from mindspore.ops import composite as C
22from mindspore.common import dtype as mstype
23from mindspore import dtype
24
# Configure MindSpore once at import time: every test below runs in graph mode on an Ascend device.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
26
27
class Sampling(nn.Cell):
    """
    Test cell that draws three samples from one Normal(0, 1) distribution object.

    Args:
        shape (tuple): shape of each drawn sample.
        seed (int): seed handed to ``msd.Normal``. Default: 0.
    """

    def __init__(self, shape, seed=0):
        super().__init__()
        self.n1 = msd.Normal(0, 1, seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, mean=None, sd=None):
        # Identical arguments on every call; the caller checks the draws differ.
        draw_a = self.n1.sample(self.shape, mean, sd)
        draw_b = self.n1.sample(self.shape, mean, sd)
        draw_c = self.n1.sample(self.shape, mean, sd)
        return draw_a, draw_b, draw_c
43
44
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_sample_graph():
    """Three successive Normal samples from the same cell must pairwise differ."""
    sampler = Sampling((2, 3), seed=0)
    first, second, third = sampler()
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
56
57
class CompositeNormalNet(nn.Cell):
    """
    Test cell that calls ``C.normal`` three times with identical arguments.

    Args:
        shape (tuple): output shape of each draw. Default: None.
        seed (int): seed forwarded to every ``C.normal`` call. Default: 0.
    """

    def __init__(self, shape=None, seed=0):
        super().__init__()
        self.shape = shape
        self.seed = seed

    def construct(self, mean, stddev):
        # Same shape, parameters and seed each time.
        out_a = C.normal(self.shape, mean, stddev, self.seed)
        out_b = C.normal(self.shape, mean, stddev, self.seed)
        out_c = C.normal(self.shape, mean, stddev, self.seed)
        return out_a, out_b, out_c
69
70
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_composite_normal():
    """Three ``C.normal`` draws with the same seed must pairwise differ."""
    loc = Tensor(0.0, mstype.float32)
    scale = Tensor(1.0, mstype.float32)
    first, second, third = CompositeNormalNet((3, 2, 4))(loc, scale)
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
83
84
class CompositeLaplaceNet(nn.Cell):
    """
    Test cell that calls ``C.laplace`` three times with identical arguments.

    Args:
        shape (tuple): output shape of each draw. Default: None.
        seed (int): seed forwarded to every ``C.laplace`` call. Default: 0.
    """

    def __init__(self, shape=None, seed=0):
        super().__init__()
        self.shape = shape
        self.seed = seed

    def construct(self, mean, lambda_param):
        # Same shape, parameters and seed each time.
        out_a = C.laplace(self.shape, mean, lambda_param, self.seed)
        out_b = C.laplace(self.shape, mean, lambda_param, self.seed)
        out_c = C.laplace(self.shape, mean, lambda_param, self.seed)
        return out_a, out_b, out_c
96
97
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_composite_laplace():
    """Three ``C.laplace`` draws with the same seed must pairwise differ."""
    loc = Tensor(1.0, mstype.float32)
    lam = Tensor(1.0, mstype.float32)
    first, second, third = CompositeLaplaceNet((3, 2, 4))(loc, lam)
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
110
111
class CompositeGammaNet(nn.Cell):
    """
    Test cell that calls ``C.gamma`` three times with identical arguments.

    Args:
        shape (tuple): output shape of each draw. Default: None.
        seed (int): seed forwarded to every ``C.gamma`` call. Default: 0.
    """

    def __init__(self, shape=None, seed=0):
        super().__init__()
        self.shape = shape
        self.seed = seed

    def construct(self, alpha, beta):
        # Same shape, parameters and seed each time.
        out_a = C.gamma(self.shape, alpha, beta, self.seed)
        out_b = C.gamma(self.shape, alpha, beta, self.seed)
        out_c = C.gamma(self.shape, alpha, beta, self.seed)
        return out_a, out_b, out_c
123
124
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_composite_gamma():
    """Three ``C.gamma`` draws with the same seed must pairwise differ."""
    concentration = Tensor(1.0, mstype.float32)
    rate = Tensor(1.0, mstype.float32)
    first, second, third = CompositeGammaNet((3, 2, 4))(concentration, rate)
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
137
138
class CompositePoissonNet(nn.Cell):
    """
    Test cell that calls ``C.poisson`` three times with identical arguments.

    Args:
        shape (tuple): output shape of each draw. Default: None.
        seed (int): seed forwarded to every ``C.poisson`` call. Default: 0.
    """

    def __init__(self, shape=None, seed=0):
        super().__init__()
        self.shape = shape
        self.seed = seed

    def construct(self, mean):
        # Same shape, mean and seed each time.
        out_a = C.poisson(self.shape, mean, self.seed)
        out_b = C.poisson(self.shape, mean, self.seed)
        out_c = C.poisson(self.shape, mean, self.seed)
        return out_a, out_b, out_c
150
151
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_composite_poisson():
    """Three ``C.poisson`` draws with the same seed must pairwise differ."""
    rate = Tensor(2.0, mstype.float32)
    first, second, third = CompositePoissonNet((3, 2, 4))(rate)
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
163
164
class CompositeUniformNet(nn.Cell):
    """
    Test cell that calls ``C.uniform`` three times with identical arguments.

    Args:
        shape (tuple): output shape of each draw. Default: None.
        seed (int): seed forwarded to every ``C.uniform`` call. Default: 0.
    """

    def __init__(self, shape=None, seed=0):
        super().__init__()
        self.shape = shape
        self.seed = seed

    def construct(self, a, b):
        # Same shape, bounds and seed each time.
        out_a = C.uniform(self.shape, a, b, self.seed)
        out_b = C.uniform(self.shape, a, b, self.seed)
        out_c = C.uniform(self.shape, a, b, self.seed)
        return out_a, out_b, out_c
176
177
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_composite_uniform():
    """Three ``C.uniform`` draws with the same seed must pairwise differ."""
    low = Tensor(0.0, mstype.float32)
    high = Tensor(1.0, mstype.float32)
    first, second, third = CompositeUniformNet((3, 2, 4))(low, high)
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
190
191
class StandardNormalNet(nn.Cell):
    """
    Test cell that invokes one ``P.StandardNormal`` primitive three times.

    Args:
        shape (tuple): output shape of each draw.
        seed (int): first seed of the primitive. Default: 0.
        seed2 (int): second seed of the primitive. Default: 0.
    """

    def __init__(self, shape, seed=0, seed2=0):
        super().__init__()
        self.shape = shape
        self.seed = seed
        self.seed2 = seed2
        self.standard_normal = P.StandardNormal(seed, seed2)

    def construct(self):
        # Reuse the same primitive instance for every draw.
        out_a = self.standard_normal(self.shape)
        out_b = self.standard_normal(self.shape)
        out_c = self.standard_normal(self.shape)
        return out_a, out_b, out_c
205
206
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_standard_normal():
    """Three ``StandardNormal`` draws from one primitive must pairwise differ."""
    first, second, third = StandardNormalNet((4, 16))()
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
217
218
class StandardLaplaceNet(nn.Cell):
    """
    Test cell that invokes one ``P.StandardLaplace`` primitive three times.

    Args:
        shape (tuple): output shape of each draw.
        seed (int): first seed of the primitive. Default: 0.
        seed2 (int): second seed of the primitive. Default: 0.
    """

    def __init__(self, shape, seed=0, seed2=0):
        super().__init__()
        self.shape = shape
        self.seed = seed
        self.seed2 = seed2
        self.standard_laplace = P.StandardLaplace(seed, seed2)

    def construct(self):
        # Reuse the same primitive instance for every draw.
        out_a = self.standard_laplace(self.shape)
        out_b = self.standard_laplace(self.shape)
        out_c = self.standard_laplace(self.shape)
        return out_a, out_b, out_c
232
233
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_standard_laplace():
    """Three ``StandardLaplace`` draws from one primitive must pairwise differ."""
    first, second, third = StandardLaplaceNet((4, 16))()
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
244
245
class GammaNet(nn.Cell):
    """
    Test cell that invokes one ``P.Gamma`` primitive three times.

    Args:
        shape (tuple): output shape of each draw.
        alpha (Tensor): alpha input passed on every call.
        beta (Tensor): beta input passed on every call.
        seed (int): first seed of the primitive. Default: 0.
        seed2 (int): second seed of the primitive. Default: 0.
    """

    def __init__(self, shape, alpha, beta, seed=0, seed2=0):
        super().__init__()
        self.shape = shape
        self.alpha = alpha
        self.beta = beta
        self.seed = seed
        self.seed2 = seed2
        self.gamma = P.Gamma(seed, seed2)

    def construct(self):
        # The distribution parameters are fixed at construction time.
        out_a = self.gamma(self.shape, self.alpha, self.beta)
        out_b = self.gamma(self.shape, self.alpha, self.beta)
        out_c = self.gamma(self.shape, self.alpha, self.beta)
        return out_a, out_b, out_c
261
262
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gamma():
    """Three ``Gamma`` draws from one primitive must pairwise differ."""
    concentration = Tensor(1.0, mstype.float32)
    rate = Tensor(1.0, mstype.float32)
    first, second, third = GammaNet((4, 16), concentration, rate)()
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
275
276
class PoissonNet(nn.Cell):
    """
    Test cell that invokes one ``P.Poisson`` primitive three times.

    Args:
        shape (tuple): output shape of each draw.
        seed (int): first seed of the primitive. Default: 0.
        seed2 (int): second seed of the primitive. Default: 0.
    """

    def __init__(self, shape, seed=0, seed2=0):
        super().__init__()
        self.shape = shape
        self.seed = seed
        self.seed2 = seed2
        self.poisson = P.Poisson(seed, seed2)

    def construct(self, mean):
        # Reuse the same primitive instance and mean for every draw.
        out_a = self.poisson(self.shape, mean)
        out_b = self.poisson(self.shape, mean)
        out_c = self.poisson(self.shape, mean)
        return out_a, out_b, out_c
290
291
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_poisson():
    """Three ``Poisson`` draws from one primitive must pairwise differ."""
    rate = Tensor(5.0, mstype.float32)
    first, second, third = PoissonNet(shape=(4, 16))(rate)
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
303
304
class UniformIntNet(nn.Cell):
    """
    Test cell that invokes one ``P.UniformInt`` primitive three times.

    Args:
        shape (tuple): output shape of each draw.
        seed (int): first seed of the primitive. Default: 0.
        seed2 (int): second seed of the primitive. Default: 0.
    """

    def __init__(self, shape, seed=0, seed2=0):
        super().__init__()
        self.shape = shape
        self.seed = seed
        self.seed2 = seed2
        self.uniform_int = P.UniformInt(seed, seed2)

    def construct(self, minval, maxval):
        # Reuse the same primitive instance and bounds for every draw.
        out_a = self.uniform_int(self.shape, minval, maxval)
        out_b = self.uniform_int(self.shape, minval, maxval)
        out_c = self.uniform_int(self.shape, minval, maxval)
        return out_a, out_b, out_c
318
319
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_uniform_int():
    """Three ``UniformInt`` draws from one primitive must pairwise differ."""
    low = Tensor(1, mstype.int32)
    high = Tensor(5, mstype.int32)
    first, second, third = UniformIntNet((4, 16))(low, high)
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
332
333
class UniformRealNet(nn.Cell):
    """
    Test cell that invokes one ``P.UniformReal`` primitive three times.

    Args:
        shape (tuple): output shape of each draw.
        seed (int): first seed of the primitive. Default: 0.
        seed2 (int): second seed of the primitive. Default: 0.
    """

    def __init__(self, shape, seed=0, seed2=0):
        super().__init__()
        self.shape = shape
        self.seed = seed
        self.seed2 = seed2
        self.uniform_real = P.UniformReal(seed, seed2)

    def construct(self):
        # Reuse the same primitive instance for every draw.
        out_a = self.uniform_real(self.shape)
        out_b = self.uniform_real(self.shape)
        out_c = self.uniform_real(self.shape)
        return out_a, out_b, out_c
347
348
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_uniform_real():
    """Three ``UniformReal`` draws from one primitive must pairwise differ."""
    first, second, third = UniformRealNet((4, 16))()
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
359
360
class DropoutGenMaskNet(nn.Cell):
    """
    Test cell that invokes one ``P.DropoutGenMask`` primitive three times.

    Args:
        shape (tuple): shape of the mask requested on every call.
        seed (int): value passed as ``Seed0`` to ``P.DropoutGenMask``. Default: 0.
        seed2 (int): value passed as ``Seed1`` to ``P.DropoutGenMask``. Default: 0.
    """

    def __init__(self, shape, seed=0, seed2=0):
        super(DropoutGenMaskNet, self).__init__()
        self.shape = shape
        # Seeds are configurable, matching the other random-op test cells in
        # this file; the defaults reproduce the previous hard-coded zeros.
        self.seed = seed
        self.seed2 = seed2
        self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed, Seed1=seed2)

    def construct(self, keep_prob):
        # Same primitive instance, shape and keep_prob on every call.
        s1 = self.dropout_gen_mask(self.shape, keep_prob)
        s2 = self.dropout_gen_mask(self.shape, keep_prob)
        s3 = self.dropout_gen_mask(self.shape, keep_prob)
        return s1, s2, s3
372
373
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dropout_gen_mask():
    """Three ``DropoutGenMask`` masks from one primitive must pairwise differ."""
    prob = Tensor(0.5, mstype.float32)
    first, second, third = DropoutGenMaskNet((2, 4, 5))(prob)
    pairs = ((first, second), (first, third), (second, third))
    assert all((lhs != rhs).any() for lhs, rhs in pairs), \
        "The results should be different!"
385
386
class RandomChoiceWithMaskNet(nn.Cell):
    """
    Test cell that invokes one ``P.RandomChoiceWithMask`` primitive three times.

    Args:
        count (int): number of items each call selects. Default: 4.
        seed (int): first seed of the primitive. Default: 0.
        seed2 (int): second seed of the primitive. Default: 0.
    """

    def __init__(self, count=4, seed=0, seed2=0):
        super(RandomChoiceWithMaskNet, self).__init__()
        # The defaults reproduce the previously hard-coded count=4, seed=0.
        self.rnd_choice_mask = P.RandomChoiceWithMask(count=count, seed=seed, seed2=seed2)

    def construct(self, x):
        # Only the index output of each call is kept; the mask output is discarded.
        index1, _ = self.rnd_choice_mask(x)
        index2, _ = self.rnd_choice_mask(x)
        index3, _ = self.rnd_choice_mask(x)
        return index1, index2, index3
397
398
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_random_choice_with_mask():
    """Three ``RandomChoiceWithMask`` index draws on the same input must pairwise differ."""
    mode = context.get_context('mode')
    assert (mode == context.GRAPH_MODE), 'GRAPH_MODE required but got ' + str(mode)
    net = RandomChoiceWithMaskNet()
    # np.bool was a deprecated alias of the builtin and was removed in NumPy 1.24;
    # np.bool_ is the NumPy boolean scalar type and works on all versions.
    x = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool_))
    index1, index2, index3 = net(x)
    assert ((index1 != index2).any() and (index1 != index3).any() and (index2 != index3).any()), \
        "The results should be different!"
411
412
class RandomCategoricalNet(nn.Cell):
    """
    Test cell that invokes one ``P.RandomCategorical`` primitive (int64 output) three times.

    Args:
        num_sample (int): number of samples requested on every call.
    """

    def __init__(self, num_sample):
        super().__init__()
        self.random_categorical = P.RandomCategorical(mstype.int64)
        self.num_sample = num_sample

    def construct(self, logits, seed=0):
        # Same logits, sample count and seed on every call.
        out_a = self.random_categorical(logits, self.num_sample, seed)
        out_b = self.random_categorical(logits, self.num_sample, seed)
        out_c = self.random_categorical(logits, self.num_sample, seed)
        return out_a, out_b, out_c
424
425
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_random_categorical():
    """
    Draw categorical samples three times and sanity-check every result.

    The three outputs may legitimately coincide, so instead of comparing them
    this checks that each has shape (batch, num_sample) and contains only
    valid class indices into the last axis of the logits.
    """
    num_sample = 8
    net = RandomCategoricalNet(num_sample)
    logits = np.random.random((10, 5)).astype(np.float32)
    batch_size, num_classes = logits.shape
    x = Tensor(logits)
    for sample in net(x):
        out = sample.asnumpy()
        assert out.shape == (batch_size, num_sample), \
            "Unexpected sample shape: " + str(out.shape)
        assert ((out >= 0) & (out < num_classes)).all(), \
            "Sampled class index out of range!"
436