• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Tensor
20from mindspore import context
21from mindspore.common.api import _cell_graph_executor
22from mindspore.common.parameter import Parameter
23from mindspore.ops import composite as C
24from mindspore.ops import operations as P
25from mindspore.common.initializer import initializer
26from mindspore.nn import TrainOneStepCell, Momentum
27from tests.ut.python.ops.test_math_ops import VirtualLoss
28
29
# Gradient transform that differentiates a cell with respect to all of its
# inputs (not its parameters); no sensitivity input is passed in.
grad_all = C.GradOperation(get_all=True, get_by_list=False, sens_param=False)
31
32
class NetWithLoss(nn.Cell):
    """Feed the output of ``network`` through ``VirtualLoss``.

    Args:
        network (nn.Cell): Cell whose single-input forward pass is wrapped.
    """

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        # Forward through the wrapped network, then apply the virtual loss.
        return self.loss(self.network(x))
42
43
class GradWrap(nn.Cell):
    """Compute the gradients of ``network`` with respect to its inputs.

    Relies on the module-level ``grad_all`` gradient transform.

    Args:
        network (nn.Cell): Cell to differentiate; called with a single input.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x):
        # grad_all(self.network) yields the gradient function; calling it on x
        # returns the gradients with respect to the network's inputs.
        return grad_all(self.network)(x)
51
def test_unique_column_split():
    """Compile a Unique -> Gather -> Gather -> Mul training graph on 8 devices.

    The embedding table lookup uses shard strategy ``((1, 8), (1,))``, i.e. the
    table is split along its second (column) dimension. The test only checks
    that ``TrainOneStepCell`` compiles; nothing is executed. The global
    auto-parallel context is reset afterwards, even if compilation fails, so
    the configuration does not leak into other tests in the same process.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.unique = P.Unique().shard(((1,),))
            self.mul = P.Mul()
            # Table of shape (2000, 128), split column-wise across 8 devices.
            self.embedding_lookup = P.Gather().shard(((1, 8), (1,)))
            self.embedding_table = Parameter(initializer('normal', [2000, 128]),
                                             name='embedding_table')
            self.gatherv2 = P.Gather().shard(((1, 8), (1,)))
            self.reshape = P.Reshape()
            self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight")

        def construct(self, indices):
            # Flatten the (32, 64) index tensor to 1-D for Unique.
            indices_flatten = self.reshape(indices, (-1,))
            # unique_idx maps every flattened position to its slot in unique_id.
            unique_id, unique_idx = self.unique(indices_flatten)
            # Look up only the distinct ids, then expand back with unique_idx.
            unique_id_weight = self.embedding_lookup(self.embedding_table, unique_id, 0)
            weight_flatten = self.gatherv2(unique_id_weight, unique_idx, 0)
            weight = self.reshape(weight_flatten, (32, 64, 128))
            # mul_weight is (32, 64, 1) and broadcasts over the last axis.
            return self.mul(weight, self.mul_weight)

    size = 8
    # NOTE(review): this test uses auto_parallel while test_unique_row_split
    # uses semi_auto_parallel; confirm the explicit shard strategies above are
    # meant to be exercised in this mode.
    context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="auto_parallel")
    try:
        x = Tensor(np.ones([32, 64]), dtype=ms.int32)
        net = Net()
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, x)
    finally:
        # Undo the process-wide parallel configuration so later tests start clean.
        context.reset_auto_parallel_context()
85
def test_unique_row_split():
    """Compile a Unique -> Gather -> Gather -> Mul training graph on 8 devices.

    The embedding table lookup uses shard strategy ``((8, 1), (1,))``, i.e. the
    table is split along its first (row) dimension. The second Gather is not
    split. The test only checks that ``TrainOneStepCell`` compiles under
    ``semi_auto_parallel``; nothing is executed. The global auto-parallel
    context is reset afterwards, even if compilation fails, so the
    configuration does not leak into other tests in the same process.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.unique = P.Unique().shard(((1,),))
            self.mul = P.Mul()
            # Table of shape (2000, 128), split row-wise across 8 devices.
            self.embedding_lookup = P.Gather().shard(((8, 1), (1,)))
            self.embedding_table = Parameter(initializer('normal', [2000, 128]),
                                             name='embedding_table')
            self.gatherv2 = P.Gather().shard(((1, 1), (1,)))
            self.reshape = P.Reshape()
            self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight")

        def construct(self, indices):
            # Flatten the (32, 64) index tensor to 1-D for Unique.
            indices_flatten = self.reshape(indices, (-1,))
            # unique_idx maps every flattened position to its slot in unique_id.
            unique_id, unique_idx = self.unique(indices_flatten)
            # Look up only the distinct ids, then expand back with unique_idx.
            unique_id_weight = self.embedding_lookup(self.embedding_table, unique_id, 0)
            weight_flatten = self.gatherv2(unique_id_weight, unique_idx, 0)
            weight = self.reshape(weight_flatten, (32, 64, 128))
            # mul_weight is (32, 64, 1) and broadcasts over the last axis.
            return self.mul(weight, self.mul_weight)

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="semi_auto_parallel")
    try:
        x = Tensor(np.ones([32, 64]), dtype=ms.int32)
        net = Net()
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, x)
    finally:
        # Undo the process-wide parallel configuration so later tests start clean.
        context.reset_auto_parallel_context()
119