1import os
2
3import pytest
4
5import mindspore.dataset as ds
6import mindspore.dataset.transforms.c_transforms as CT
7import mindspore.dataset.vision.c_transforms as CV
8import mindspore.nn as nn
9from mindspore import ParameterTuple
10from mindspore import context
11from mindspore.common import dtype as mstype
12from mindspore.common.initializer import Normal
13from mindspore.dataset.vision import Inter
14from mindspore.nn import Cell
15from mindspore.ops import composite as C
16from mindspore.ops import functional as F
17from mindspore.ops import operations as P
18from mindspore.train.dataset_helper import DatasetHelper
19from mindspore.train.serialization import save_checkpoint
20
# MultitypeFuncGraph dispatchers used by the accumulation cells below:
# _sum_op adds freshly computed gradients into the running gradient sum,
# _clear_op resets that running sum to zero after an optimizer step.
_sum_op = C.MultitypeFuncGraph("grad_sum_op")
_clear_op = C.MultitypeFuncGraph("clear_op")
23
24
@_sum_op.register("Tensor", "Tensor")
def _cumulative_gard(grad_sum, grad):
    """Accumulate `grad` into `grad_sum` in place via AssignAdd.

    NOTE(review): "gard" in the name is a typo for "grad"; kept as-is since
    the name is part of the module surface.
    """
    add = P.AssignAdd()
    return add(grad_sum, grad)
30
31
@_clear_op.register("Tensor", "Tensor")
def _clear_grad_sum(grad_sum, zero):
    """Reset one accumulated-gradient tensor back to zero."""
    cleared = F.assign(grad_sum, zero)
    # F.depend ties the assign to the returned flag so graph mode keeps
    # the side effect alive.
    return F.depend(True, cleared)
38
39
class LeNet5(nn.Cell):
    """Classic LeNet-5 convolutional network for 32x32 single-channel input.

    Args:
        num_class (int): Number of output classes. Default: 10.
        num_channel (int): Number of input channels. Default: 1.

    Returns:
        Tensor, the un-normalized class logits.

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Two convolutional stages followed by a three-layer classifier.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # conv -> relu -> pool, twice.
        out = self.max_pool2d(self.relu(self.conv1(x)))
        out = self.max_pool2d(self.relu(self.conv2(out)))
        # Flatten and run the fully connected head.
        out = self.flatten(out)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        return self.fc3(out)
72
73
class TrainForwardBackward(Cell):
    """Forward/backward cell that accumulates gradients instead of applying them.

    Each call computes the loss, back-propagates, and adds the resulting
    gradients into `grad_sum` through the `_sum_op` hyper-map.  The stored
    optimizer is not invoked here; TrainOptim applies `grad_sum` later.
    """
    def __init__(self, network, optimizer, grad_sum, sens=1.0):
        super(TrainForwardBackward, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        # Trainable parameters to differentiate with respect to.
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad_sum = grad_sum
        # get_by_list -> gradients w.r.t. `weights`; sens_param -> the
        # output sensitivity (loss scale) is supplied explicitly.
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.hyper_map = C.HyperMap()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        # Sensitivity tensor: same dtype/shape as the loss, filled with `sens`.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        # F.depend makes the accumulation a dependency of the returned loss
        # so graph mode does not prune the side effect.
        return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))
93
94
class TrainOptim(Cell):
    """Applies the accumulated gradients `grad_sum` through `optimizer`."""

    def __init__(self, optimizer, grad_sum):
        super(TrainOptim, self).__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.grad_sum = grad_sum

    def construct(self):
        # Delegate the actual parameter update to the wrapped optimizer.
        updated = self.optimizer(self.grad_sum)
        return updated
103
104
class TrainClear(Cell):
    """Zeroes every accumulated-gradient tensor after an optimizer step."""

    def __init__(self, grad_sum, zeros):
        super(TrainClear, self).__init__(auto_prefix=False)
        self.grad_sum = grad_sum
        self.zeros = zeros
        self.hyper_map = C.HyperMap()

    def construct(self):
        # Map _clear_op over each (grad_sum, zeros) pair.
        succeeded = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
        return succeeded
115
116
class GradientAccumulation:
    """Manual gradient-accumulation trainer.

    Splits each optimizer step into `mini_steps` forward/backward passes:
    gradients are summed into `_grad_sum` on every batch, and the optimizer
    is applied (then the sum cleared) once every `mini_steps` batches.

    Args:
        network (Cell): The network to train (without a loss head).
        loss_fn (Cell): Loss function wrapped around `network`.
        optimizer (Optimizer): Optimizer whose parameters are cloned to
            create the gradient-sum and all-zeros buffers.
    """
    def __init__(self, network, loss_fn, optimizer):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer

        params = self._optimizer.parameters
        # Persistent buffers: the running gradient sum and a matching
        # all-zeros tuple used to reset it.
        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
        self._zeros = params.clone(prefix="zeros", init='zeros')
        self._train_forward_backward = self._build_train_forward_backward_network()
        self._train_optim = self._build_train_optim()
        self._train_clear = self._build_train_clear()

    def _build_train_forward_backward_network(self):
        """Build forward and backward network."""
        network = self._network
        network = nn.WithLossCell(network, self._loss_fn)
        loss_scale = 1.0
        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
        return network

    def _build_train_optim(self):
        """Build optimizer network."""
        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
        return network

    def _build_train_clear(self):
        """Build clear network."""
        network = TrainClear(self._grad_sum, self._zeros).set_train()
        return network

    def train_process(self, epoch, train_dataset, mini_steps=None):
        """Train for `epoch` epochs, stepping the optimizer every `mini_steps` batches.

        The data is passed to the network directly (no dataset sinking).

        Args:
            epoch (int): Number of epochs to run.
            train_dataset (Dataset): Training dataset.
            mini_steps (int, optional): Accumulation window.  Defaults to 1
                (step the optimizer on every batch).
        """
        if mini_steps is None:
            # Fix: the original left mini_steps as None, so calling without
            # it raised "TypeError: ... % NoneType" below.  Default to
            # applying gradients every batch.
            mini_steps = 1
        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)

        for i in range(epoch):
            step = 0
            for k, next_element in enumerate(dataset_helper):
                loss = self._train_forward_backward(*next_element)
                if (k + 1) % mini_steps == 0:
                    step += 1
                    print("epoch:", i + 1, "step:", step, "loss is ", loss)
                    self._train_optim()
                    self._train_clear()
            # NOTE(review): gradients from a trailing partial window (when
            # the batch count is not a multiple of mini_steps) are neither
            # applied nor cleared and carry into the next epoch — confirm
            # this is intended.

            train_dataset.reset()

        save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")
167
168
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST pipeline used for train or test.

    Resizes images to 32x32, rescales pixels to [0, 1], applies the
    standard MNIST mean/std normalization, converts HWC->CHW, then
    shuffles, batches, and repeats.
    """
    dataset = ds.MnistDataset(data_path)

    # Geometry and normalization constants.
    target_height, target_width = 32, 32
    scale, offset = 1.0 / 255.0, 0.0
    norm_scale = 1 / 0.3081
    norm_offset = -1 * 0.1307 / 0.3081

    # Image transforms, applied in the same order as separate map stages.
    image_ops = [
        CV.Resize((target_height, target_width), interpolation=Inter.LINEAR),  # bilinear
        CV.Rescale(scale, offset),
        CV.Rescale(norm_scale, norm_offset),
        CV.HWC2CHW(),
    ]
    label_cast = CT.TypeCast(mstype.int32)

    dataset = dataset.map(operations=label_cast, input_columns="label",
                          num_parallel_workers=num_parallel_workers)
    for image_op in image_ops:
        dataset = dataset.map(operations=image_op, input_columns="image",
                              num_parallel_workers=num_parallel_workers)

    # Shuffle with a 10000-sample buffer (as in the LeNet train script),
    # then batch and repeat.
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(repeat_size)

    return dataset
204
205
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gradient_accumulation():
    """End-to-end run: LeNet-5 on MNIST with 4-step gradient accumulation."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    train_data = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32)

    net = LeNet5(10)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optimizer = nn.Momentum(net.trainable_params(), 0.01, 0.9)
    trainer = GradientAccumulation(net, loss_fn, optimizer)

    print("============== Starting Training ==============")
    trainer.train_process(2, train_data, mini_steps=4)
221