import os

import numpy as np

import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.train.serialization import export


def weight_variable():
    """Return a truncated-normal initializer used for weights and biases."""
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free Conv2d layer with truncated-normal weight init."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Build a Dense layer with truncated-normal weight and bias init."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


class LeNet5(nn.Cell):
    """Classic LeNet-5: two conv/pool stages followed by three dense layers."""

    def __init__(self):
        super(LeNet5, self).__init__()
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class WithLossCell(nn.Cell):
    """Wrap a network together with a softmax cross-entropy loss."""

    def __init__(self, network):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.network = network

    def construct(self, x, label):
        predict = self.network(x)
        return self.loss(predict, label)


class TrainOneStepCell(nn.Cell):
    """Compute gradients of the wrapped loss cell and apply a Momentum update."""

    def __init__(self, network):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_train()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
        self.hyper_map = C.HyperMap()
        self.grad = C.GradOperation(get_by_list=True)

    def construct(self, x, label):
        weights = self.weights
        grads = self.grad(self.network, weights)(x, label)
        return self.optimizer(grads)


def test_export_lenet_grad_mindir():
    """Export the LeNet-5 training graph to MINDIR and check the file is written."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    network = LeNet5()
    network.set_train()
    predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([32, 10]).astype(np.float32))
    net = TrainOneStepCell(WithLossCell(network))
    file_name = "lenet_grad"
    export(net, predict, label, file_name=file_name, file_format='MINDIR')
    verify_name = file_name + ".mindir"
    assert os.path.exists(verify_name)
    os.remove(verify_name)
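

# Optional entry point (an addition, not part of the original test file): lets the
# export check be run directly with `python` instead of through pytest. Note that
# the test sets device_target="Ascend", so running it requires an Ascend backend.
if __name__ == "__main__":
    test_export_lenet_grad_mindir()
    print("lenet_grad.mindir exported and verified")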