# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File  : test_compile.py
@Author:
@Date  : 2019-03-20
@Desc  : Tests for the MindSpore compile path: building a Model and running
         predict() on the same graph with different inputs.
"""
import logging
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, Model, context
from mindspore.nn.optim import Momentum
from mindspore.ops.composite import add_flags
from ...ut_filter import non_graph_engine

log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)


class Net(nn.Cell):
    """Conv2d (3 -> 64 channels, kernel 3, no bias) -> ReLU -> Flatten."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        return self.flatten(self.relu(self.conv(x)))


# Shared loss function used by test_build.
loss = nn.MSELoss()


# Test case 1: exercise the newer compiler interface
# (the older _build_train_graph entry point is deprecated).
def test_build():
    """Construct a Model from Net with an MSE loss and a Momentum optimizer.

    Nothing is trained or predicted here; the test only checks that the
    Model can be constructed.
    """
    # These tensors are built and then discarded; they only exercise Tensor
    # construction from random numpy data.
    # NOTE(review): looks like leftover inputs/labels — confirm whether they
    # were meant to be passed somewhere.
    Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
    Tensor(np.random.randint(0, 10, [1, 10]))

    network = Net()
    optimizer = Momentum(network.get_parameters(), learning_rate=0.1, momentum=0.9)
    Model(network, optimizer=optimizer, loss_fn=loss, metrics=None)


# Test case 2: run the same graph with different input arguments.
class Net2(nn.Cell):
    """A single ReLU cell."""

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(x)


@non_graph_engine
def test_different_args_run():
    """Predict twice on one graph-mode Model with two random inputs.

    The two outputs are expected to differ (not allclose within rtol=atol=0.01).
    """
    shape = (2, 3, 4, 5)
    first_np = np.random.randn(*shape).astype(np.float32)
    first_input = Tensor(first_np)
    second_np = np.random.randn(*shape).astype(np.float32)
    second_input = Tensor(second_np)

    # NOTE(review): the flag is spelled 'predit' — possibly a typo for
    # 'predict'; confirm whether anything reads this flag before renaming.
    network = add_flags(Net2(), predit=True)
    # Switches the global execution mode; it is not restored afterwards.
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(network)

    # Run both predictions first, then convert both results to numpy.
    predictions = [model.predict(tensor) for tensor in (first_input, second_input)]
    first_out, second_out = (pred.asnumpy() for pred in predictions)

    for value in (first_np, second_np, first_out, second_out):
        print(value)
    assert not np.allclose(first_out, second_out, rtol=0.01, atol=0.01)