# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dictionary

Tests for dict handling inside ``Cell.construct`` under GRAPH_MODE:
iterating ``dict.keys()`` / ``dict.values()``, adding a key by item
assignment, and passing a tuple that contains a dict as a network input.
"""
import numpy as np

from mindspore import Tensor, context
from mindspore.nn import Cell

# Every network in this module is compiled in graph mode. This is set once at
# import time, so it applies to all tests below.
context.set_context(mode=context.GRAPH_MODE)


class Net1(Cell):
    """Iterate over the keys and then the values of a constant dict.

    The ``x`` argument of ``construct`` is not used; only the dict literal
    built inside ``construct`` feeds the output.
    """

    def __init__(self):
        super().__init__()

    # NOTE: ``construct`` bodies in this file are parsed by the graph
    # compiler. The syntax used here (``dic.keys()`` / ``dic.values()``
    # loops) is what is under test, so keep it as written.
    def construct(self, x):
        # Dict whose keys and values are all compile-time constants.
        dic = {'x': 0, 'y': 1}
        output = []
        # Keys first, then values, so the result order is
        # key1, key2, value1, value2.
        for i in dic.keys():
            output.append(i)
        for j in dic.values():
            output.append(j)
        return output


class Net2(Cell):
    """Like ``Net1``, but one dict value is the network input ``x``."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # Mixes the input ``x`` with a constant value, so the dict is not
        # purely constant.
        dic = {'x': x, 'y': 1}
        output = []
        for i in dic.keys():
            output.append(i)
        for j in dic.values():
            output.append(j)
        return output


class Net3(Cell):
    """Add a dict entry by item assignment, then iterate keys and values.

    The ``x`` argument of ``construct`` is not used.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        dic = {'x': 0}
        # Adds a new key after construction; the added value is a tuple.
        dic['y'] = (0, 1)
        output = []
        for i in dic.keys():
            output.append(i)
        for j in dic.values():
            output.append(j)
        return output


def test_dict1():
    """Keys and values of a constant dict come back in insertion order.

    ``construct`` returns a list. The assertion compares the result against
    a tuple, so this test expects the graph output to come back as a tuple.
    """
    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(input_np)
    net = Net1()
    out_me = net(input_me)
    assert out_me == ('x', 'y', 0, 1)


def test_dict2():
    """Check that a dict holding the Tensor input compiles and runs.

    NOTE(review): the output is not asserted on, so only the absence of an
    exception is checked. Consider asserting on the keys and the constant
    value, and comparing the returned Tensor against ``input_np``.
    """
    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(input_np)
    net = Net2()
    net(input_me)


def test_dict3():
    """A key added by item assignment is visible to keys()/values().

    The expected output has the added tuple value ``(0, 1)`` as a single
    element, after the original value ``0``.
    """
    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(input_np)
    net = Net3()
    out_me = net(input_me)
    assert out_me == ('x', 'y', 0, (0, 1))


def test_dict4():
    """A tuple containing an int, a Tensor and a dict can be a graph input.

    The network concatenates the input tuple with itself, and the result is
    compared against the same concatenation done in plain Python.
    """
    class Net(Cell):
        """Return ``tuple_x + tuple_x`` (tuple concatenation)."""

        def __init__(self):
            super().__init__()

        def construct(self, tuple_x):
            output = tuple_x + tuple_x
            return output

    # The dict sits inside the tuple and holds a Tensor value of its own.
    x = (1, Tensor([1, 2, 3]), {"a": Tensor([1, 2, 3]), "b": 1})
    net = Net()
    out_me = net(x)
    # NOTE(review): Python tuple equality checks each element pair by
    # identity first, and only then with ``==``. For Tensor elements that are
    # not the same object, ``Tensor == Tensor`` is element-wise. Confirm it
    # reduces to a single truth value here, or compare the Tensor elements
    # explicitly (e.g. via ``asnumpy()``).
    assert out_me == x + x