# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for ``nn.Conv2d`` wrapped in a minimal ``nn.Cell``."""
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from ..ut_filter import non_graph_engine

# Module-level values kept for backward compatibility; the tests in this
# file do not reference them.
weight = Tensor(np.ones([2, 2]))
in_channels = 3
out_channels = 64

# Every input element is set to this value so outputs are easy to predict.
_INPUT_FILL = 0.01


def _make_input(shape):
    """Return a float32 ``Tensor`` of ``shape`` filled with ``_INPUT_FILL``."""
    return Tensor(np.full(shape, _INPUT_FILL, dtype=np.float32))


class Net(nn.Cell):
    """Thin ``nn.Cell`` wrapper around a single ``nn.Conv2d`` layer.

    All constructor arguments are forwarded to ``nn.Conv2d``. With the
    defaults here (``pad_mode='pad'``, ``padding=0``) no padding is applied.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
        kernel_size: Convolution kernel size (int or tuple).
        stride, pad_mode, padding, dilation, group, has_bias, weight_init,
        bias_init: Forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self,
                 cin,
                 cout,
                 kernel_size,
                 stride=1,
                 pad_mode='pad',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros'):
        super().__init__()
        # Keyword arguments protect against positional drift if the
        # ``nn.Conv2d`` signature gains or reorders parameters.
        self.conv = nn.Conv2d(in_channels=cin,
                              out_channels=cout,
                              kernel_size=kernel_size,
                              stride=stride,
                              pad_mode=pad_mode,
                              padding=padding,
                              dilation=dilation,
                              group=group,
                              has_bias=has_bias,
                              weight_init=weight_init,
                              bias_init=bias_init)

    def construct(self, input_x):
        """Apply the wrapped convolution to ``input_x``."""
        return self.conv(input_x)


@non_graph_engine
def test_compile():
    """Batch of 3 through a 3 -> 6 channel 3x3 convolution."""
    net = Net(3, 6, (3, 3), bias_init='zeros')
    output = net(_make_input([3, 3, 32, 32])).asnumpy()
    # Stride 1, no padding: spatial size shrinks from 32 to 32 - 3 + 1 = 30.
    assert output.shape == (3, 6, 30, 30)


@non_graph_engine
def test_compile2():
    """Single sample through a 3 -> 1 channel 3x3 convolution."""
    net = Net(3, 1, (3, 3), bias_init='zeros')
    output = net(_make_input([1, 3, 32, 32])).asnumpy()
    assert output.shape == (1, 1, 30, 30)


@non_graph_engine
def test_compile3():
    """Upper-case ``'ONES'`` weight init yields a predictable output."""
    net = Net(3, 1, (3, 3), weight_init='ONES')
    output = net(_make_input([1, 3, 32, 32])).asnumpy()
    assert output.shape == (1, 1, 30, 30)
    # All-ones 3x3x3 kernel over a constant input with no bias: every output
    # element is 3 * 3 * 3 * _INPUT_FILL. Loose tolerance allows for
    # backend-dependent accumulation precision.
    assert np.allclose(output, 27 * _INPUT_FILL, rtol=1e-3)