# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Check that training under ``auto_parallel`` sets the ``auto_parallel`` flag on the train network."""
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.nn.optim import Momentum
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from ....dataset_mock import MindData

context.set_context(mode=context.GRAPH_MODE)


class MindDataSet(MindData):
    """Mock dataset whose items are tuples of all-ones Tensors.

    Each item holds one Tensor per entry of ``dataset_shapes``, cast to the
    matching numpy dtype in ``dataset_types``.
    """

    def __init__(self, dataset_types, dataset_shapes):
        super(MindDataSet, self).__init__(size=2, batch_size=32,
                                          np_types=dataset_types,
                                          output_shapes=dataset_shapes,
                                          input_indexs=(0, 1))

    def __next__(self):
        """Return the next item, or raise ``StopIteration`` once ``_iter_num`` exceeds ``_size``."""
        # NOTE(review): assuming MindData starts `_iter_num` at 0, this check lets
        # `_size + 1` items through before StopIteration. Confirm the extra item
        # is intended before relying on the item count.
        if self._size < self._iter_num:
            raise StopIteration
        self._iter_num += 1
        return tuple(Tensor(np.ones(shape).astype(type_))
                     for shape, type_ in zip(self._output_shapes, self._np_types))


class Net(nn.Cell):
    """Fully connected layer ``input_ @ weight.T + bias`` with all-ones float32 parameters."""

    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        # Stored as [out_features, in_features].
        self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
        # transpose_b so that (batch, in_features) @ [out, in]^T is well-formed
        # when in_features != out_features; without it only square layers work.
        self.matmul = P.MatMul(transpose_b=True)
        self.add = P.Add()

    def construct(self, input_):
        output = self.add(self.matmul(input_, self.weight), self.bias)
        return output


class NetFP16(nn.Cell):
    """Same layer as ``Net``, but computed in float16 and cast back to float32.

    Parameters are stored in float32. The input, weight and bias are cast to
    float16 for the matmul/add, and the result is cast back to float32.
    """

    def __init__(self, in_features, out_features):
        super(NetFP16, self).__init__()
        # Stored as [out_features, in_features].
        self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
        # transpose_b to match the [out, in] weight layout (see Net).
        self.matmul = P.MatMul(transpose_b=True)
        self.add = P.Add()
        self.cast = P.Cast()

    def construct(self, input_):
        output = self.cast(
            self.add(self.matmul(self.cast(input_, mstype.float16), self.cast(self.weight, mstype.float16)),
                     self.cast(self.bias, mstype.float16)), mstype.float32)
        return output


def get_axis(x):
    """Return ``make_range(0, rank(x))``, i.e. every axis of ``x``, for a full reduction."""
    shape_op = P.Shape()
    shape = shape_op(x)
    length = F.tuple_len(shape)
    perm = F.make_range(0, length)
    return perm


class MSELoss(nn.Cell):
    """Mean squared error: mean of ``(data - label) ** 2`` over all axes."""

    def __init__(self):
        super(MSELoss, self).__init__()
        self.square = P.Square()
        self.reduce_mean = P.ReduceMean()

    def construct(self, data, label):
        diff = data - label
        return self.reduce_mean(self.square(diff), get_axis(diff))


def test_auto_parallel_flag():
    """Train ``NetFP16`` under ``auto_parallel`` and check the train network's ``auto_parallel`` flag.

    The auto-parallel context is reset in ``finally`` so that a failure in
    ``model.train`` or in the assertion does not leave ``auto_parallel`` mode
    configured for tests that run afterwards.
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=1)
    try:
        dataset_types = (np.float32, np.float32)
        dataset_shapes = ((16, 16), (16, 16))

        dataset = MindDataSet(dataset_types, dataset_shapes)
        net = NetFP16(16, 16)
        net.set_train()
        scale_manager = FixedLossScaleManager()
        loss = MSELoss()
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None, loss_scale_manager=scale_manager)
        model.train(2, dataset)
        assert model._train_network.get_flags()["auto_parallel"]
    finally:
        context.reset_auto_parallel_context()