# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test sparse feature bprop """
import pytest
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Adam


# Gradient meta-operator: grad_all(net)(inputs) returns the gradients of
# net with respect to every input (get_all=True).
grad_all = C.GradOperation(get_all=True)
32
@pytest.fixture(name="test_context")
def _test_context():
    """Enable sparse mode for one test and restore defaults afterwards."""
    context.set_context(enable_sparse=True)
    yield
    # Teardown: clear any auto-parallel configuration and the sparse flag.
    context.reset_auto_parallel_context()
    context.set_context(enable_sparse=False)
39
40
class GradWrap(nn.Cell):
    """Cell wrapper whose forward pass yields the gradients of the wrapped
    network with respect to its input."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        grad_fn = grad_all(self.network)
        return grad_fn(x)
48
def test_bprop_with_sparse_feature_allreduce(test_context):
    """Compile the backward graph of AllReduce -> SparseGatherV2 under
    hybrid parallelism; the test passes if graph compilation succeeds."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            shape = [8, 8] if shape is None else shape
            self.all_reduce = AllReduce()
            self.gatherv2 = P.SparseGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            reduced = self.all_reduce(x)
            return self.gatherv2(reduced, self.index, self.axis)

    net = GradWrap(Net())
    net.set_train()
    inputs = Tensor(np.ones([64, 64]), dtype=ms.float32)
    _cell_graph_executor.compile(net, inputs)
73
74
def test_bprop_with_sparse_feature_mirror(test_context):
    """Compile one training step of a sharded EmbeddingLookup under
    semi-auto parallelism; the test passes if graph compilation succeeds."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")

    class Net(nn.Cell):
        def __init__(self, shape=None):
            super(Net, self).__init__()
            shape = [8, 8] if shape is None else shape
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.embeddinglookup = nn.EmbeddingLookup(64, 64, param_init='ones')
            # Apply a shard strategy to the underlying lookup primitive.
            self.embeddinglookup.embeddinglookup.shard(((1, 1), (8, 1)))

        def construct(self, x, b):
            # x and b are required by the training-cell signature but unused.
            return self.embeddinglookup(self.index)

    net = Net()
    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_train()
    inputs = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
    labels = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
    _cell_graph_executor.compile(train_net, inputs, labels)
103
104
def test_bprop_with_sparse_feature_dataparallel(test_context):
    """Compile one training step of SparseGatherV2 over a trainable weight
    under data parallelism; the test passes if graph compilation succeeds."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="data_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            shape = [8, 8] if shape is None else shape
            self.weight = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), "w")
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis
            self.gatherv2 = P.SparseGatherV2()

        def construct(self, x, b):
            # x and b are required by the training-cell signature but unused.
            return self.gatherv2(self.weight, self.index, self.axis)

    net = Net()
    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_train()
    inputs = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
    labels = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
    _cell_graph_executor.compile(train_net, inputs, labels)
135