• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16import pytest
17
18import mindspore
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.ops import operations as P
23
class RandpermNet(nn.Cell):
    """Minimal Cell whose forward pass is a single ``P.Randperm`` call.

    Args:
        max_length: forwarded to ``P.Randperm`` as its ``max_length`` attribute.
        pad: forwarded to ``P.Randperm`` as its ``pad`` attribute.
        dtype: forwarded to ``P.Randperm`` as its output ``dtype`` attribute.
    """

    def __init__(self, max_length, pad, dtype):
        super().__init__()
        # The primitive is created once with fixed attributes; construct()
        # only invokes it.
        perm_op = P.Randperm(max_length, pad, dtype)
        self.randperm = perm_op

    def construct(self, x):
        # x is handed to the primitive unchanged.
        return self.randperm(x)
31
32
def randperm(max_length, pad, dtype, n):
    """Run ``Randperm`` on GPU in graph mode and validate its output.

    Builds a ``RandpermNet`` with the given attributes, feeds it ``n`` as a
    one-element int32 tensor, and checks that:

    * the output has exactly ``max_length`` elements of the requested dtype;
    * the first ``n`` elements are a permutation of ``0 .. n-1``;
    * every remaining element equals ``pad``.

    Args:
        max_length: ``max_length`` attribute passed to ``P.Randperm``.
        pad: ``pad`` attribute passed to ``P.Randperm``.
        dtype: MindSpore dtype passed to ``P.Randperm``.
        n: number of values to permute (sent to the op as an int32 tensor).

    Raises:
        AssertionError: if any of the checks above fails.
        RuntimeError: propagated from the op itself (e.g. ``n > max_length``).
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    x = Tensor(np.array([n]).astype(np.int32))
    randperm_net = RandpermNet(max_length, pad, dtype)
    output = randperm_net(x).asnumpy()

    # Check the length first: if the kernel returned fewer than max_length
    # elements, the pad check below would silently compare an empty slice.
    assert output.shape == (max_length,), output.shape
    # The original only printed the dtype; assert it matches the request.
    assert output.dtype == mindspore.dtype_to_nptype(dtype), output.dtype

    # verify permutation
    np.testing.assert_array_equal(np.arange(n), np.sort(output[:n]))

    # verify pad (np.full keeps pad's own value, independent of output dtype)
    np.testing.assert_array_equal(np.full(max_length - n, pad), output[n:])
52
53
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_int8():
    """int8 output: 5 permuted values followed by 3 pad entries of -1."""
    randperm(max_length=8, pad=-1, dtype=mindspore.int8, n=5)
59
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_int16():
    """int16 output where n equals max_length, so no padding is expected."""
    randperm(max_length=3, pad=0, dtype=mindspore.int16, n=3)
65
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_int32():
    """int32 output: 2 permuted values followed by 2 pad entries of -6."""
    randperm(max_length=4, pad=-6, dtype=mindspore.int32, n=2)
71
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_int64():
    """int64 output: 4 permuted values followed by 8 pad entries of 128."""
    randperm(max_length=12, pad=128, dtype=mindspore.int64, n=4)
77
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_uint8():
    """uint8 output: 5 permuted values followed by 3 pad entries of 1."""
    randperm(max_length=8, pad=1, dtype=mindspore.uint8, n=5)
83
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_uint16():
    """uint16 output where n equals max_length, so no padding is expected."""
    randperm(max_length=8, pad=0, dtype=mindspore.uint16, n=8)
89
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_uint32():
    """uint32 output: 3 permuted values followed by 1 pad entry of 8."""
    randperm(max_length=4, pad=8, dtype=mindspore.uint32, n=3)
95
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_uint64():
    """uint64 output where n equals max_length, so no padding is expected."""
    randperm(max_length=5, pad=4, dtype=mindspore.uint64, n=5)
101
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_n_too_large():
    """Asking for more values (n=2) than max_length (1) must raise RuntimeError."""
    expected_msg = "n (2) cannot exceed max_length_ (1)"
    with pytest.raises(RuntimeError) as exc_info:
        randperm(max_length=1, pad=0, dtype=mindspore.int32, n=2)
    assert expected_msg in str(exc_info.value)
109