import os

import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.dataset.vision import Inter
from mindspore.nn import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import save_checkpoint

_sum_op = C.MultitypeFuncGraph("grad_sum_op")
_clear_op = C.MultitypeFuncGraph("clear_op")


@_sum_op.register("Tensor", "Tensor")
def _cumulative_grad(grad_sum, grad):
    """Add the current gradient to the cumulative gradient."""
    add = P.AssignAdd()
    return add(grad_sum, grad)


@_clear_op.register("Tensor", "Tensor")
def _clear_grad_sum(grad_sum, zero):
    """Assign zeros to clear grad_sum."""
    success = True
    success = F.depend(success, F.assign(grad_sum, zero))
    return success


class LeNet5(nn.Cell):
    """
    LeNet-5 network.

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of input channels. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class TrainForwardBackward(Cell):
    """Run one forward/backward pass and accumulate the gradients into grad_sum."""
    def __init__(self, network, optimizer, grad_sum, sens=1.0):
        super(TrainForwardBackward, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad_sum = grad_sum
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.hyper_map = C.HyperMap()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))


class TrainOptim(Cell):
    """Apply the accumulated gradients with the optimizer."""
    def __init__(self, optimizer, grad_sum):
        super(TrainOptim, self).__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.grad_sum = grad_sum

    def construct(self):
        return self.optimizer(self.grad_sum)


class TrainClear(Cell):
    """Reset grad_sum to zeros after an optimizer step."""
    def __init__(self, grad_sum, zeros):
        super(TrainClear, self).__init__(auto_prefix=False)
        self.grad_sum = grad_sum
        self.zeros = zeros
        self.hyper_map = C.HyperMap()

    def construct(self):
        success = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
        return success
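

# How the three cells above compose into one accumulation "macro step".
# A minimal sketch for orientation only -- `net_with_loss`, `opt`, `grad_sum`,
# `zeros`, `data` and `label` are placeholders; the GradientAccumulation
# driver below builds and runs the real objects this way:
#
#   forward_backward = TrainForwardBackward(net_with_loss, opt, grad_sum)
#   apply_optim = TrainOptim(opt, grad_sum)
#   clear = TrainClear(grad_sum, zeros)
#
#   for _ in range(mini_steps):              # accumulate over N mini-batches
#       loss = forward_backward(data, label)
#   apply_optim()                            # one update on the summed gradient
#   clear()                                  # zero grad_sum for the next macro step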


class GradientAccumulation:
    """Run training with gradient accumulation over `mini_steps` mini-batches."""
    def __init__(self, network, loss_fn, optimizer):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer

        params = self._optimizer.parameters
        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
        self._zeros = params.clone(prefix="zeros", init='zeros')
        self._train_forward_backward = self._build_train_forward_backward_network()
        self._train_optim = self._build_train_optim()
        self._train_clear = self._build_train_clear()

    def _build_train_forward_backward_network(self):
        """Build the forward and backward network."""
        network = self._network
        network = nn.WithLossCell(network, self._loss_fn)
        loss_scale = 1.0
        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
        return network

    def _build_train_optim(self):
        """Build the optimizer network."""
        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
        return network

    def _build_train_clear(self):
        """Build the clear network."""
        network = TrainClear(self._grad_sum, self._zeros).set_train()
        return network

    def train_process(self, epoch, train_dataset, mini_steps=None):
        """
        Training process. The data is passed to the network directly; every
        `mini_steps` batches the accumulated gradient is applied and cleared.
        """
        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)

        for i in range(epoch):
            step = 0
            for k, next_element in enumerate(dataset_helper):
                loss = self._train_forward_backward(*next_element)
                if (k + 1) % mini_steps == 0:
                    step += 1
                    print("epoch:", i + 1, "step:", step, "loss is ", loss)
                    self._train_optim()
                    self._train_clear()

            train_dataset.reset()

        save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create the MNIST dataset for train or test.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # after rescaling to [0, 1], apply (x - 0.1307) / 0.3081,
    # the standard MNIST mean/std normalization
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = CT.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply dataset operations
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in the LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
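

# Note on gradient scale (an observation added here, not part of the original
# script): _sum_op accumulates raw per-batch gradients, so after mini_steps
# batches grad_sum holds the SUM of mini_steps mean-gradients -- roughly
# mini_steps times the mean gradient of one batch of size
# mini_steps * batch_size. With the same nominal learning rate, the effective
# step size is therefore scaled by mini_steps relative to true large-batch
# training.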


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gradient_accumulation():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    ds_train = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32)

    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = GradientAccumulation(network, net_loss, net_opt)

    print("============== Starting Training ==============")
    model.train_process(2, ds_train, mini_steps=4)
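

# A minimal reload sketch (added for illustration; not exercised by the test
# above). load_checkpoint and load_param_into_net restore the weights that
# train_process saved; because TrainForwardBackward is built with
# auto_prefix=False, the saved LeNet5 parameter names should match a freshly
# constructed network, and unrelated checkpoint entries (e.g. the cloned
# grad_sum parameters) are simply not loaded.
def _load_trained_lenet(ckpt_path="gradient_accumulation.ckpt"):
    from mindspore.train.serialization import load_checkpoint, load_param_into_net

    net = LeNet5(10)
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)  # switch to inference behavior
    return net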