# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.train.model import Model
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P


# Minimal in-memory stand-in for a MindSpore dataset: it exposes only the
# tuple-iterator protocol and the size/batch queries that Model.train touches
# when dataset_sink_mode=False.
class DatasetLenet():
    def __init__(self, data, label, length=3):
        self.data = data
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data, self.label

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        return 32

    def get_repeat_count(self):
        return 1

    def get_batch_size(self):
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        return self


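# For orientation only, a rough sketch of how this stub could be driven by
# hand (train_one_step is a hypothetical placeholder, not a MindSpore API);
# Model.train in non-sink mode consumes it through the same tuple-iterator
# protocol:
#
#     for data, label in dataset.create_tuple_iterator():
#         train_one_step(data, label)
#     dataset.reset()  # rewind before reusing the stub for another epoch

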
# Small MatMul + ReLU block with its own 64x64 weight; one instance is picked
# per branch of the control flow in Net.construct.
class MatMulCell(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU()
        self.weight = Parameter(initializer("ones", [64, 64]), name="param1")

    def construct(self, x):
        out = self.matmul(x, self.weight)
        out = self.relu(out)
        return out


# Network whose construct contains data-dependent control flow: the branch is
# chosen at run time from the reduced value of y, so all four sub-cells must be
# handled under the sharded (semi-auto parallel) configuration.
class Net(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.matmul = P.MatMul().shard(strategy1)
        self.weight = Parameter(initializer("ones", [64, 64]), name="param")
        self.cell1 = MatMulCell()
        self.cell2 = MatMulCell()
        self.cell3 = MatMulCell()
        self.cell4 = MatMulCell()
        self.relu = P.ReLU().shard(strategy2)
        self.reduce = P.ReduceSum()

    def construct(self, x, y):
        out = self.matmul(x, self.weight)
        # Dispatch to one of the four MatMulCell branches based on sum(y).
        if self.reduce(y) == 1.0:
            out = self.cell1(out)
        elif self.reduce(y) == 2.0:
            out = self.cell2(out)
        elif self.reduce(y) == 3.0:
            out = self.cell3(out)
        else:
            out = self.cell4(out)
        out = self.relu(out)
        out = out + x
        return out


def test_control_flow():
    # Semi-auto parallel over 8 logical devices.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    # MatMul strategy: x [128, 64] is split 2 x 4, the weight [64, 64] is split
    # 4 x 1 (the contracted dimension is split 4 on both sides); 2 * 4 = 8
    # matches device_num. The ReLU strategy splits its input 8 x 1.
    strategy1 = ((2, 4), (4, 1))
    strategy2 = ((8, 1),)
    net = Net(strategy1, strategy2)
    data = Tensor(np.ones([128, 64]), dtype=ms.float32)
    label = Tensor(np.ones([8, 8]), dtype=ms.float32)
    dataset = DatasetLenet(data, label, 3)
    opt = nn.Lamb(net.trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=opt)
    model.train(2, dataset, dataset_sink_mode=False)
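

# A minimal sketch for running this case standalone, assuming a local
# MindSpore install; graph mode is an assumption here, since the test itself
# does not set an execution mode.
if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE)
    test_control_flow()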