# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# Default bounds handed to UniformInt. Per the MindSpore docs the op samples
# from the half-open interval [min_val, max_val).
MIN_VAL = 10
MAX_VAL = 100


class Net(nn.Cell):
    """Cell that draws a UniformInt sample of a fixed shape.

    Args:
        shape (tuple): Shape of the sampled output tensor.
        seed (int): First seed forwarded to ``P.UniformInt``.
        seed2 (int): Second seed forwarded to ``P.UniformInt``.
        min_val (int): Lower bound, passed to the op as an int32 tensor.
        max_val (int): Upper bound, passed to the op as an int32 tensor.
    """

    def __init__(self, shape, seed=0, seed2=0, min_val=MIN_VAL, max_val=MAX_VAL):
        super().__init__()
        self.shape = shape
        self.min_val = Tensor(min_val, mstype.int32)
        self.max_val = Tensor(max_val, mstype.int32)
        self.seed = seed
        self.seed2 = seed2
        self.uniformint = P.UniformInt(seed, seed2)

    def construct(self):
        return self.uniformint(self.shape, self.min_val, self.max_val)


def _sample(shape, seed, seed2):
    """Run a fresh ``Net`` once and return its output flattened to numpy.

    Asserts that the output has the requested shape and that every value
    lies in [MIN_VAL, MAX_VAL).
    """
    output = Net(shape, seed, seed2)()
    assert output.shape == shape
    values = output.asnumpy().flatten()
    assert values.min() >= MIN_VAL
    assert values.max() < MAX_VAL
    return values


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
    """Check UniformInt output shape, value range and seeded reproducibility."""
    first = _sample((5, 6, 8), seed=10, seed2=10)
    second = _sample((5, 6, 8), seed=0, seed2=10)
    # The two runs share seed2 (10) but differ in seed (10 vs 0). The
    # equality below therefore presumes the kernel derives its generator
    # state from seed2 whenever seed2 is nonzero.
    # NOTE(review): confirm against the CPU UniformInt kernel.
    assert (first == second).all()

    # Both seeds zero: the values are not pinned by this test, so only the
    # shape and the value range are checked.
    _sample((130, 120, 141), seed=0, seed2=0)