# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context, Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn.optim import Momentum
from mindspore.nn.wrap.cell_wrapper import WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

np.random.seed(1)


# Gradient operation that returns gradients w.r.t. an explicit parameter list.
grad_by_list = C.GradOperation(get_by_list=True)


def weight_variable():
    """Create a TruncatedNormal(0.02) weight initializer."""
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv layer with TruncatedNormal weight initialization and no bias."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Fully connected layer with TruncatedNormal weight and bias initialization."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)
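

# Minimal sketch (not executed by pytest) of what the helpers above build,
# assuming a single 32x32 grayscale input as used by the tests below: with
# pad_mode="valid", a 5x5 kernel shrinks 32x32 to 28x28, which is why LeNet's
# first Dense layer below expects 16 * 5 * 5 inputs after two conv/pool stages.
def _example_helper_usage():
    """Sketch only: exercise conv() and fc_with_initialize() in isolation."""
    conv_layer = conv(1, 6, 5)              # Conv2d: 1 -> 6 channels, 5x5 kernel
    fc_layer = fc_with_initialize(120, 84)  # Dense: 120 -> 84, TruncatedNormal init
    x = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
    y = conv_layer(x)                       # expected shape: (1, 6, 28, 28)
    return y, fc_layer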


class LeNet(nn.Cell):
    """
    LeNet-5 network.

    Args:
        num_class (int): Number of classes. Default: 10.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet(num_class=10)
    """

    def __init__(self, num_class=10):
        super(LeNet, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class CrossEntropyLoss(nn.Cell):
    """
    Softmax cross-entropy loss with one-hot labels, summed over the batch and
    divided by the fixed batch size (32).
    """

    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.num = Tensor(32.0, mstype.float32)

    def construct(self, logits, label):
        label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss = self.cross_entropy(logits, label)[0]
        loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num)
        return loss


class GradWrap(nn.Cell):
    """
    Wrap a network and compute gradients of its output with respect to all
    trainable parameters.
    """

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return grad_by_list(self.network, weights)(x, label)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ascend_pynative_lenet():
    context.set_context(mode=context.PYNATIVE_MODE)

    epoch_size = 20
    batch_size = 32
    inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
    labels = Tensor(np.ones([batch_size]).astype(np.int32))

    net = LeNet()
    criterion = CrossEntropyLoss()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)

    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()
    total_time = 0

    for epoch in range(0, epoch_size):
        start_time = time.time()
        fw_output = net(inputs)
        loss_output = criterion(fw_output, labels)
        grads = train_network(inputs, labels)
        optimizer(grads)
        end_time = time.time()
        cost_time = end_time - start_time
        total_time = total_time + cost_time

        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
    assert loss_output.asnumpy() < 0.004
    assert loss_output.asnumpy() > 0.003
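

# Minimal sketch (not executed by pytest) of a single training step as driven
# by test_ascend_pynative_lenet above: the loss value and the gradients come
# from two separate calls, one through the bare network and one through the
# GradWrap cell. The test below collapses these into a single call with
# nn.ForwardValueAndGrad.
def _example_gradwrap_step(net, criterion, train_network, optimizer, inputs, labels):
    """Sketch only: one PyNative step with the GradOperation-based wrapper."""
    loss = criterion(net(inputs), labels)  # forward pass for the loss value
    grads = train_network(inputs, labels)  # separate pass for the gradients
    optimizer(grads)                       # in-place Momentum update
    return loss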


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_lenet_with_new_interface():
    context.set_context(mode=context.PYNATIVE_MODE)

    epoch_size = 20
    batch_size = 32
    inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
    labels = Tensor(np.ones([batch_size]).astype(np.int32))

    net = LeNet()
    criterion = CrossEntropyLoss()
    net_with_criterion = WithLossCell(net, criterion)
    net_with_criterion.set_train()

    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
    optimizer = Momentum(weights, 0.1, 0.9)

    forward_value_and_grad = nn.ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True)
    total_time = 0
    for epoch in range(0, epoch_size):
        start_time = time.time()
        loss_output, grads = forward_value_and_grad(inputs, labels)
        optimizer(grads)
        end_time = time.time()
        cost_time = end_time - start_time
        total_time = total_time + cost_time

        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
    assert loss_output.asnumpy() < 0.005
    assert loss_output.asnumpy() > 0.003
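

if __name__ == "__main__":
    # Convenience entry point (an addition, not part of the original suite):
    # run both tests directly with the Python interpreter, without pytest.
    test_ascend_pynative_lenet()
    test_pynative_lenet_with_new_interface()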