# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class RandpermNet(nn.Cell):
    """Minimal Cell wrapping the ``P.Randperm`` primitive so it can be run as a network."""

    def __init__(self, max_length, pad, dtype):
        super().__init__()
        self.randperm = P.Randperm(max_length, pad, dtype)

    def construct(self, x):
        return self.randperm(x)


def randperm(max_length, pad, dtype, n):
    """Run Randperm on GPU in graph mode and verify its output.

    Args:
        max_length: length of the output requested from the operator.
        pad: value every position after the first ``n`` is expected to hold.
        dtype: MindSpore dtype requested for the output.
        n: number of leading elements that must form a permutation of
            ``0 .. n-1``; passed to the operator as a one-element int32 tensor.

    Raises:
        AssertionError: if the output dtype, shape, permutation prefix or
            padding suffix is not as expected.
        RuntimeError: propagated from the operator (e.g. when ``n`` exceeds
            ``max_length``, as exercised by ``test_randperm_n_too_large``).
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    x = Tensor(np.array([n]).astype(np.int32))
    randperm_net = RandpermNet(max_length, pad, dtype)
    output_tensor = randperm_net(x)

    # The requested dtype must actually be honoured (previously only printed).
    assert output_tensor.dtype == dtype, \
        f"expected dtype {dtype}, got {output_tensor.dtype}"

    output = output_tensor.asnumpy()
    # n permuted values followed by (max_length - n) pad values; without this
    # check a too-short output would silently pass the pad verification below.
    assert output.shape == (max_length,), \
        f"expected shape ({max_length},), got {output.shape}"

    # verify permutation: the first n values are exactly 0 .. n-1 in some order
    np.testing.assert_array_equal(np.arange(n), np.sort(output[:n]))

    # verify pad: every remaining slot holds the pad value
    np.testing.assert_array_equal(np.full(max_length - n, pad), output[n:])


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_int8():
    """int8 output: 5 permuted values padded with -1 up to length 8."""
    randperm(8, -1, mindspore.int8, 5)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_int16():
    """int16 output with n == max_length, so no padding is expected."""
    randperm(3, 0, mindspore.int16, 3)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_int32():
    """int32 output: 2 permuted values padded with -6 up to length 4."""
    randperm(max_length=4, pad=-6, dtype=mindspore.int32, n=2)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_int64():
    """int64 output: 4 permuted values padded with 128 up to length 12."""
    randperm(max_length=12, pad=128, dtype=mindspore.int64, n=4)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_uint8():
    """uint8 output: 5 permuted values padded with 1 up to length 8."""
    randperm(max_length=8, pad=1, dtype=mindspore.uint8, n=5)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_uint16():
    """uint16 output with n == max_length, so no padding is expected."""
    randperm(max_length=8, pad=0, dtype=mindspore.uint16, n=8)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_uint32():
    """uint32 output: 3 permuted values padded with 8 up to length 4."""
    randperm(max_length=4, pad=8, dtype=mindspore.uint32, n=3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_uint64():
    """uint64 output with n == max_length, so no padding is expected."""
    randperm(max_length=5, pad=4, dtype=mindspore.uint64, n=5)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randperm_n_too_large():
    """Requesting more values (n=2) than max_length (1) must raise RuntimeError."""
    with pytest.raises(RuntimeError) as info:
        randperm(max_length=1, pad=0, dtype=mindspore.int32, n=2)
    # Checked after the context manager exits, so the message check actually runs.
    assert "n (2) cannot exceed max_length_ (1)" in str(info.value)