# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C, operations as P
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from tests.dataset_mock import MindData

context.set_context(mode=context.GRAPH_MODE)


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class AllToAllNet(nn.Cell):
    def __init__(self, strategy1):
        super(AllToAllNet, self).__init__()
        self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
        self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
        self.transpose1 = P.Transpose().shard(strategy1)

    def construct(self, x):
        x = self.matmul(x, self.matmul_weight)
        x = self.transpose1(x, (1, 0))
        return x


def all_to_all_net(strategy1):
    return AllToAllNet(strategy1=strategy1)


def loss_scale_manager_common(strategy1):
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = all_to_all_net(strategy1)

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss.softmax_cross_entropy.shard(((8, 1), (8, 1)))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    scale_manager = DynamicLossScaleManager(32, 2, 2000)
    model = Model(net, loss, opt, loss_scale_manager=scale_manager)
    # When no GE backend exists, outputs = self._train_network(*next_element)
    # returns the input tensor instead, so model.train is expected to raise TypeError.
    try:
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    except TypeError:
        pass
    else:
        assert False

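
# A minimal illustration (not part of the original tests) of the shard strategy
# notation used throughout this file: each inner tuple gives the number of
# slices per dimension of one operator input. For example, ((1, 1), (1, 8)) on
# MatMul keeps the activation whole and splits the weight's second dimension
# across 8 devices. The helper below is hypothetical and only shows the
# per-device slice shape such a strategy implies.
def _sliced_shape(shape, strategy):
    # Divide each dimension by its slice count; e.g. _sliced_shape((128, 256), (1, 8))
    # gives (128, 32), the [128, 32] weight slice held by each of the 8 devices.
    return tuple(dim // cuts for dim, cuts in zip(shape, strategy))
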
def fixme_test_dataset_interface_sens_scalar():
    # Disabled: fails with "The type of sens node is not Tensor or Parameter, it is unsupported now."
    strategy1 = ((8, 1),)
    loss_scale_manager_common(strategy1)


class TrainOneStepCell(nn.Cell):

    def __init__(self, network, optimizer):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)

    def construct(self, data, sens):
        weights = self.weights
        loss = self.network(data)
        grads = self.grad(self.network, weights)(data, sens)
        self.optimizer(grads)
        return loss


def loss_scale_manager_sens(strategy1, sens):
    learning_rate = 0.1
    momentum = 0.9
    device_num = 8
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num)
    predict = Tensor(np.ones([32 * device_num, 128]), dtype=ms.float32)
    net = all_to_all_net(strategy1)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    train_net = TrainOneStepCell(net, opt)
    train_net.set_train()
    train_net(predict, sens)


def test_dataset_interface_sens_shape_not_equal_loss():
    strategy1 = ((8, 1),)
    sens = Tensor(np.ones([256, 1024]), dtype=ms.float32)
    try:
        loss_scale_manager_sens(strategy1, sens)
    except (ValueError, TypeError, RuntimeError):
        pass


def test_dataset_interface_sens_shape_equal_loss():
    strategy1 = ((4, 2),)
    sens = Tensor(np.ones([256, 256]), dtype=ms.float32)
    loss_scale_manager_sens(strategy1, sens)


def test_input_not_in_parameter_layout_dict():
    class Net(nn.Cell):
        def __init__(self, strategy1):
            super(Net, self).__init__()
            self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
            self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
            self.transpose1 = P.Transpose().shard(strategy1)

        def construct(self, x):
            x = self.matmul(x, self.matmul_weight)
            x = self.transpose1(x, (1, 0))
            return x

    strategy1 = ((8, 1),)
    device_num = 8
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num)
    predict = Tensor(np.ones([32 * device_num, 128]), dtype=ms.float32)
    net = Net(strategy1)
    net.set_train()
    net(predict)
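

# A minimal sketch of the dynamic loss-scale policy that
# DynamicLossScaleManager(32, 2, 2000) above configures (an assumption about
# its semantics, not MindSpore's implementation): start at scale 32, divide by
# the factor 2 on overflow, and multiply by 2 after 2000 consecutive
# overflow-free steps.
def _update_loss_scale(scale, overflow, good_steps, factor=2, window=2000):
    if overflow:
        return max(scale / factor, 1), 0  # shrink the scale and reset the counter
    good_steps += 1
    if good_steps >= window:
        return scale * factor, 0  # grow the scale after a stable window
    return scale, good_steps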