• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16
17import mindspore.nn as nn
18from mindspore import Tensor, Parameter
19from mindspore import context
20from mindspore.common import dtype as mstype
21from mindspore.nn.optim import Momentum
22from mindspore.ops import functional as F
23from mindspore.ops import operations as P
24from mindspore.train import Model
25from mindspore.train.loss_scale_manager import FixedLossScaleManager
26from ....dataset_mock import MindData
27
# Select graph execution mode for everything defined in this module; this runs
# once, at import time.
context.set_context(mode=context.GRAPH_MODE)
29
30
class MindDataSet(MindData):
    """Mock dataset whose batches are tuples of all-ones tensors.

    The base class is always configured with ``size=2``, ``batch_size=32`` and
    ``input_indexs=(0, 1)``.

    Args:
        dataset_types: NumPy dtypes, one per produced tensor.
        dataset_shapes: Shapes, one per produced tensor.
    """

    def __init__(self, dataset_types, dataset_shapes):
        super().__init__(size=2,
                         batch_size=32,
                         np_types=dataset_types,
                         output_shapes=dataset_shapes,
                         input_indexs=(0, 1))

    def __next__(self):
        """Return the next tuple of all-ones tensors.

        Raises:
            StopIteration: once the iteration counter has exceeded the size.
        """
        # NOTE(review): the guard only trips when the counter is strictly
        # greater than the size; if the counter starts at 0 this produces
        # size + 1 batches -- confirm against MindData before changing it.
        if self._iter_num > self._size:
            raise StopIteration
        self._iter_num += 1
        return tuple(
            Tensor(np.ones(dims).astype(dtype))
            for dims, dtype in zip(self._output_shapes, self._np_types)
        )
46
47
class Net(nn.Cell):
    """Cell computing ``matmul(input_, weight) + bias`` with all-ones parameters.

    Args:
        in_features: Second dimension of the weight matrix.
        out_features: First dimension of the weight matrix and length of the bias.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        ones_weight = np.ones([out_features, in_features]).astype(np.float32)
        ones_bias = np.ones([out_features]).astype(np.float32)
        # Parameters are registered weight first, then bias.
        self.weight = Parameter(Tensor(ones_weight), name="weight")
        self.bias = Parameter(Tensor(ones_bias), name="bias")
        self.matmul = P.MatMul()
        self.add = P.Add()

    def construct(self, input_):
        """Multiply ``input_`` by the weight and add the bias."""
        product = self.matmul(input_, self.weight)
        return self.add(product, self.bias)
59
60
class NetFP16(nn.Cell):
    """Variant of the matmul-plus-bias cell that computes in float16.

    Parameters are stored as float32; the input, weight and bias are cast to
    float16 before the arithmetic and the result is cast back to float32.

    Args:
        in_features: Second dimension of the weight matrix.
        out_features: First dimension of the weight matrix and length of the bias.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        ones_weight = np.ones([out_features, in_features]).astype(np.float32)
        ones_bias = np.ones([out_features]).astype(np.float32)
        # Parameters are registered weight first, then bias.
        self.weight = Parameter(Tensor(ones_weight), name="weight")
        self.bias = Parameter(Tensor(ones_bias), name="bias")
        self.matmul = P.MatMul()
        self.add = P.Add()
        self.cast = P.Cast()

    def construct(self, input_):
        """Run matmul and bias-add in float16 and return a float32 result."""
        input_fp16 = self.cast(input_, mstype.float16)
        weight_fp16 = self.cast(self.weight, mstype.float16)
        product = self.matmul(input_fp16, weight_fp16)
        bias_fp16 = self.cast(self.bias, mstype.float16)
        total = self.add(product, bias_fp16)
        return self.cast(total, mstype.float32)
75
76
def get_axis(x):
    """Return the range ``0 .. len(shape(x))``, i.e. one index per dimension of ``x``.

    Args:
        x: Value whose shape is queried with ``P.Shape``.

    Returns:
        The result of ``F.make_range(0, rank)`` where ``rank`` is the length of
        the shape tuple.
    """
    get_shape = P.Shape()
    rank = F.tuple_len(get_shape(x))
    return F.make_range(0, rank)
83
84
class MSELoss(nn.Cell):
    """Loss cell returning the mean of the squared difference over every axis."""

    def __init__(self):
        super().__init__()
        # reduce_sum is created here but not used by construct().
        self.reduce_sum = P.ReduceSum()
        self.square = P.Square()
        self.reduce_mean = P.ReduceMean()

    def construct(self, data, label):
        """Compute ``mean((data - label) ** 2)`` reduced over all dimensions.

        Args:
            data: Network output.
            label: Target value subtracted from ``data``.
        """
        error = data - label
        squared_error = self.square(error)
        all_axes = get_axis(error)
        return self.reduce_mean(squared_error, all_axes)
95
96
def test_auto_parallel_flag():
    """Check that training under auto-parallel mode sets the ``auto_parallel`` flag.

    Builds a ``Model`` around ``NetFP16`` with a fixed loss scale manager, trains
    it for two epochs on the mock dataset and asserts that the generated train
    network carries a truthy ``auto_parallel`` flag.
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=1)
    try:
        dataset_types = (np.float32, np.float32)
        dataset_shapes = ((16, 16), (16, 16))

        dataset = MindDataSet(dataset_types, dataset_shapes)
        net = NetFP16(16, 16)
        net.set_train()
        scale_manager = FixedLossScaleManager()
        loss = MSELoss()
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None, loss_scale_manager=scale_manager)
        model.train(2, dataset)
        assert model._train_network.get_flags()["auto_parallel"]
    finally:
        # Always restore the global auto-parallel context, even when training
        # or the assertion fails, so the setting cannot leak into other tests.
        context.reset_auto_parallel_context()
112