• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Tensor
20from mindspore import context
21from mindspore.common.parameter import Parameter, ParameterTuple
22from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
23from mindspore.nn.optim.momentum import Momentum
24from mindspore.ops import composite as C, operations as P
25from mindspore.train import Model
26from mindspore.context import ParallelMode
27from mindspore.train.loss_scale_manager import DynamicLossScaleManager
28from tests.dataset_mock import MindData
29
# Module-level setting: select MindSpore graph mode as soon as this file is imported.
context.set_context(mode=context.GRAPH_MODE)
31
32
class Dataset(MindData):
    """Mock dataset that yields one fixed ``(predict, label)`` pair ``length`` times."""

    def __init__(self, predict, label, length=3):
        """Store the sample pair and forward ``length`` to ``MindData`` as ``size``."""
        super().__init__(size=length)
        self.predict, self.label = predict, label
        self.index = 0
        self.length = length

    def __iter__(self):
        """Return the dataset itself; iteration state lives in ``self.index``."""
        return self

    def __next__(self):
        """Return the stored pair until ``length`` items were served, then stop."""
        if self.index < self.length:
            self.index += 1
            return self.predict, self.label
        raise StopIteration

    def reset(self):
        """Rewind the iteration counter to the beginning."""
        self.index = 0
52
53
class AllToAllNet(nn.Cell):
    """MatMul against a fixed [128, 256] weight followed by a (1, 0) Transpose.

    The MatMul shard strategy is hard-coded to ``((1, 1), (1, 8))``; the
    Transpose shard strategy is supplied by the caller through ``strategy1``.
    """

    def __init__(self, strategy1):
        super().__init__()
        self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
        initial_weight = Tensor(np.ones([128, 256]), dtype=ms.float32)
        self.matmul_weight = Parameter(initial_weight, name="weight")
        self.transpose1 = P.Transpose().shard(strategy1)

    def construct(self, x):
        """Multiply ``x`` by the weight, then swap the two axes of the product."""
        product = self.matmul(x, self.matmul_weight)
        return self.transpose1(product, (1, 0))
65
66
def all_to_all_net(strategy1):
    """Build an ``AllToAllNet`` whose Transpose uses the shard strategy ``strategy1``."""
    network = AllToAllNet(strategy1=strategy1)
    return network
69
70
def loss_scale_manager_common(strategy1):
    """Train ``AllToAllNet`` via ``Model`` with a dynamic loss scale and expect ``TypeError``.

    The auto-parallel context is reset and set to ``AUTO_PARALLEL`` with 8
    devices, a two-item mock dataset is built, and ``model.train`` is run for
    two epochs without dataset sinking.

    Args:
        strategy1: shard strategy forwarded to the Transpose op of the network.

    Raises:
        AssertionError: if ``model.train`` finishes without raising ``TypeError``.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = all_to_all_net(strategy1)

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss.softmax_cross_entropy.shard(((8, 1), (8, 1)))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    scale_manager = DynamicLossScaleManager(32, 2, 2000)
    model = Model(net, loss, opt, loss_scale_manager=scale_manager)
    # NOTE(review): the original note states that when no GE exists,
    # `outputs = self._train_network(*next_element)` outputs the input tensors;
    # presumably that is why a TypeError is expected here -- confirm.
    try:
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    except TypeError:
        pass
    else:
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would silently disable this check.
        raise AssertionError("model.train was expected to raise TypeError")
95
96
def fixme_test_dataset_interface_sens_scalar():
    """Run the common loss-scale scenario with an ``((8, 1),)`` Transpose strategy.

    The ``fixme_`` prefix presumably keeps the test runner from collecting this
    function until the error quoted below is resolved.
    """
    # With error: "The type of sens node is not Tensor or Parameter, it is unsupported now."
    loss_scale_manager_common(((8, 1),))
101
102
class TrainOneStepCell(nn.Cell):
    """Wrap a network and an optimizer into a single step driven by an explicit ``sens``."""

    def __init__(self, network, optimizer):
        """Keep the network, its trainable parameters, the optimizer and the grad op."""
        super().__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        trainable = network.trainable_params()
        self.weights = ParameterTuple(trainable)
        self.optimizer = optimizer
        # Gradients are taken w.r.t. the parameter list, with the sensitivity passed in.
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)

    def construct(self, data, sens):
        """Run the forward pass, apply the optimizer to the gradients, return the loss."""
        loss = self.network(data)
        grad_fn = self.grad(self.network, self.weights)
        grads = grad_fn(data, sens)
        self.optimizer(grads)
        return loss
119
120
def loss_scale_manager_sens(strategy1, sens):
    """Run one ``TrainOneStepCell`` step of ``AllToAllNet`` with the given ``sens``.

    The auto-parallel context is reset and set to ``SEMI_AUTO_PARALLEL`` with
    8 devices before the network is built.

    Args:
        strategy1: shard strategy forwarded to the Transpose op of the network.
        sens: sensitivity value handed to the gradient operation.
    """
    device_num = 8
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num)

    inputs = Tensor(np.ones([32 * device_num, 128]), dtype=ms.float32)
    network = all_to_all_net(strategy1)
    # Momentum optimizer: learning rate 0.1, momentum 0.9.
    optimizer = Momentum(network.trainable_params(), 0.1, 0.9)

    step_cell = TrainOneStepCell(network, optimizer)
    step_cell.set_train()
    step_cell(inputs, sens)
133
134
def test_dataset_interface_sens_shape_not_equal_loss():
    """Feed a [256, 1024] ``sens`` and tolerate the resulting error, if any.

    ``ValueError``, ``TypeError`` and ``RuntimeError`` are all swallowed; the
    test also passes when no exception is raised.
    """
    mismatched_sens = Tensor(np.ones([256, 1024]), dtype=ms.float32)
    try:
        loss_scale_manager_sens(((8, 1),), mismatched_sens)
    except (ValueError, TypeError, RuntimeError):
        pass
146
147
def test_dataset_interface_sens_shape_equal_loss():
    """Run one training step with a [256, 256] ``sens`` and a ``((4, 2),)`` strategy."""
    sens_shape = [256, 256]
    matching_sens = Tensor(np.ones(sens_shape), dtype=ms.float32)
    loss_scale_manager_sens(((4, 2),), matching_sens)
152
153
def test_input_not_in_parameter_layotu_dict():
    """Call a locally defined MatMul + Transpose net directly under SEMI_AUTO_PARALLEL."""

    class Net(nn.Cell):
        """MatMul against a fixed [128, 256] weight followed by a (1, 0) Transpose."""

        def __init__(self, strategy1):
            super().__init__()
            self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
            initial_weight = Tensor(np.ones([128, 256]), dtype=ms.float32)
            self.matmul_weight = Parameter(initial_weight, name="weight")
            self.transpose1 = P.Transpose().shard(strategy1)

        def construct(self, x):
            product = self.matmul(x, self.matmul_weight)
            return self.transpose1(product, (1, 0))

    device_num = 8
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num)

    inputs = Tensor(np.ones([32 * device_num, 128]), dtype=ms.float32)
    network = Net(((8, 1),))
    network.set_train()
    network(inputs)
175