• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import os
2import numpy as np
3
4import mindspore.nn as nn
5from mindspore import context
6from mindspore.common.tensor import Tensor
7from mindspore.common.initializer import TruncatedNormal
8from mindspore.common.parameter import ParameterTuple
9from mindspore.ops import operations as P
10from mindspore.ops import composite as C
11from mindspore.train.serialization import export
12
13
def weight_variable():
    """Create the initializer used for the layer parameters in this file.

    Returns:
        A new ``TruncatedNormal`` initializer constructed with ``0.02``.
    """
    initializer = TruncatedNormal(0.02)
    return initializer
16
17
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Create a bias-free 2D convolution with truncated-normal weights.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``. Defaults to 1.
        padding: Padding forwarded to ``nn.Conv2d`` (an int, or a sequence
            of ints). Defaults to 0.

    Returns:
        The configured ``nn.Conv2d`` cell.
    """
    if isinstance(padding, (tuple, list)):
        has_padding = any(padding)
    else:
        has_padding = padding != 0
    # "valid" mode only accepts zero padding, so an explicit non-zero
    # padding has to be requested through "pad" mode; with the default
    # padding of 0 the layer is built exactly as before.
    pad_mode = "pad" if has_padding else "valid"
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode=pad_mode)
23
24
def fc_with_initialize(input_channels, out_channels):
    """Create a dense layer whose weight and bias use ``weight_variable()``.

    Args:
        input_channels: Size of the layer input.
        out_channels: Size of the layer output.

    Returns:
        The configured ``nn.Dense`` cell.
    """
    return nn.Dense(
        input_channels,
        out_channels,
        weight_init=weight_variable(),
        bias_init=weight_variable(),
    )
29
30
class LeNet5(nn.Cell):
    """LeNet-5 style network: two conv/ReLU/max-pool stages and three dense layers.

    Args:
        batch_size: Number of samples per input batch; used as the leading
            dimension when the feature maps are flattened before the dense
            layers. Defaults to 32, the value that was previously hard-coded.
    """

    def __init__(self, batch_size=32):
        super(LeNet5, self).__init__()
        self.batch_size = batch_size
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # fc1 consumes 16 * 5 * 5 features per sample, so the input must
        # reduce to 16 feature maps of 5x5 after the second pooling stage.
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        """Run the forward pass and return the output of the last dense layer."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Flatten everything except the batch dimension for the dense layers.
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
58
59
class WithLossCell(nn.Cell):
    """Wrap a network so that calling it yields a softmax cross-entropy loss.

    Args:
        network: The cell whose output is fed to the loss.
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.network = network

    def construct(self, x, label):
        """Return the loss between ``network(x)`` and ``label``."""
        logits = self.network(x)
        loss_value = self.loss(logits, label)
        return loss_value
69
70
class TrainOneStepCell(nn.Cell):
    """Compute gradients of ``network`` w.r.t. its trainable parameters and
    pass them to a Momentum optimizer.

    Args:
        network: The cell to differentiate; it is switched to training mode.
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.network = network
        self.network.set_train()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = nn.Momentum(self.weights, learning_rate=0.1, momentum=0.9)
        self.hyper_map = C.HyperMap()
        # get_by_list=True: gradients are taken with respect to the weight list.
        self.grad = C.GradOperation(get_by_list=True)

    def construct(self, x, label):
        """Return the optimizer's result for the gradients at ``(x, label)``."""
        grad_fn = self.grad(self.network, self.weights)
        return self.optimizer(grad_fn(x, label))
85
86
def test_export_lenet_grad_mindir():
    """Export a LeNet5 training step to MINDIR and check the file is written.

    The exported file is removed again once its existence has been asserted.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    lenet = LeNet5()
    lenet.set_train()

    # Dummy batch: 32 single-channel 32x32 images and 32 ten-class label rows.
    images = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    labels = Tensor(np.zeros([32, 10]).astype(np.float32))

    train_step = TrainOneStepCell(WithLossCell(lenet))

    file_name = "lenet_grad"
    export(train_step, images, labels, file_name=file_name, file_format='MINDIR')

    exported_path = f"{file_name}.mindir"
    assert os.path.exists(exported_path)
    os.remove(exported_path)
99