# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import numpy as np

import mindspore as ms
import mindspore.communication.management as distributedTool
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P

# Defaults; device_num and rank_id are overwritten from the communication
# group once it is initialized in setup_module().
device_num = 2
device_id = int(os.getenv('DEVICE_ID'))
rank_id = 0


def setup_module():
    """Initialize the distributed environment before the tests in this module run."""
    global device_num
    global rank_id
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=device_id)
    distributedTool.init()
    device_num = distributedTool.get_group_size()
    rank_id = distributedTool.get_rank()
    context.set_auto_parallel_context(device_num=device_num,
                                      global_rank=rank_id)


def teardown_module():
    """Release the communication resources after the tests finish."""
    distributedTool.release()


class Onehot(Cell):
    """OneHot -> Transpose -> Sub network used to exercise the OneHot shard strategy."""

    def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, strategy=None):
        super(Onehot, self).__init__()
        trans_stra = None
        if strategy:
            trans_stra = (strategy[0],)
        self.onehot = P.OneHot().shard(strategy=strategy)
        self.depth = depth
        self.on_value = Tensor(on_value, ms.float32)
        self.off_value = Tensor(off_value, ms.float32)
        self.transpose = P.Transpose().shard(strategy=trans_stra)
        self.sub = P.Sub().shard(strategy=((1, 1), (1, 1)))
        self.axis = axis

    def construct(self, input_, indices):
        # indices: (batch,) -> one-hot: (batch, depth) -> transpose: (depth, batch)
        x = self.onehot(indices, self.depth, self.on_value, self.off_value)
        x = self.transpose(x, (1, 0))
        x = self.sub(input_, x)
        return x


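# A worked illustration of the forward pass above (a minimal sketch; the concrete
# numbers are only an example and follow from the ops in construct): with depth == 3,
# indices [0, 2] and on/off values 1.0/0.0, the network computes
#   one_hot   -> [[1, 0, 0], [0, 0, 1]]        # shape (batch, depth)
#   transpose -> [[1, 0], [0, 0], [0, 1]]      # shape (depth, batch)
#   sub       -> input_ - transposed one-hot   # input_ also has shape (depth, batch)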
class DataGenerator():
    """Generate full tensors plus the per-rank slices implied by a layout strategy."""

    def get_parallel_blocks(self, input_, strategy):
        # Split the array into strategy[i] pieces along each axis i; callers
        # index the resulting list with rank_id to pick the local block.
        blocks = [input_]
        i = 0
        for stra in strategy:
            temp = []
            while blocks:
                block = blocks.pop(0)
                temp.extend(np.split(block, stra, axis=i))
            blocks.extend(temp)
            i += 1
        return blocks

    def generate_data(self, shape):
        data = np.random.rand(*shape)
        return data

    def input_data(self, shape):
        # Float input, split along dimension 0 (one slice per device).
        data = (self.generate_data(shape) * 2).astype(np.float32)
        stra = [1] * len(shape)
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])

    def label_data(self, shape, classes):
        # Integer labels in [0, classes - 1), split along dimension 0.
        data = (self.generate_data(shape) * (classes - 1)).astype(np.int32)
        stra = [1] * len(shape)
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])


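# Illustration of how get_parallel_blocks() shards data along dimension 0
# (a minimal sketch assuming device_num == 2; the shapes follow from np.split):
#
#   >>> gen = DataGenerator()
#   >>> data = np.arange(12).reshape(4, 3)
#   >>> [b.shape for b in gen.get_parallel_blocks(data, [2, 1])]
#   [(2, 3), (2, 3)]
#
# Each rank then keeps datas[rank_id], its own slice of dimension 0, while the
# full tensor is retained for the single-device reference run.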
class OneHotFactory:
    def __init__(self, batch_size, classes, on_value=1.0, off_value=0.0, axis=None, strategy=None):
        data_gen = DataGenerator()
        self.input_full, self.input_part = data_gen.input_data((classes, batch_size))
        self.label_full, self.label_part = data_gen.label_data((batch_size,), classes)
        self.depth = classes
        self.on_value = on_value
        self.off_value = off_value
        self.axis = axis
        self.strategy = strategy

    def forward_mindspore_single_impl(self):
        # Reference run on a single device, without any shard strategy.
        net = Onehot(axis=self.axis,
                     depth=self.depth,
                     on_value=self.on_value,
                     off_value=self.off_value)
        out = net(self.input_full, self.label_full)
        return out

    def forward_mindspore_parallel_impl(self):
        # Same network compiled under semi-auto parallel with the given strategy.
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        net = Onehot(axis=self.axis,
                     depth=self.depth,
                     on_value=self.on_value,
                     off_value=self.off_value, strategy=self.strategy)
        out = net.compile_and_run(self.input_full, self.label_full)
        return out

    def forward_cmp(self):
        # The parallel result must match the single-device reference elementwise.
        out_mindspore_single = self.forward_mindspore_single_impl().asnumpy()
        context.reset_auto_parallel_context()
        out_mindspore_parallel = self.forward_mindspore_parallel_impl().asnumpy()
        context.reset_auto_parallel_context()
        assert np.allclose(out_mindspore_single, out_mindspore_parallel, 0.0001, 0.0001)


def test_reid_onehot_forward_int32_128_depth1024_model_parallel():
    # OneHot output strategy (1, device_num): the class (depth) dimension is
    # split across devices; the scalar on/off values take empty strategies.
    fact = OneHotFactory(batch_size=128,
                         classes=1024,
                         on_value=1.000000,
                         off_value=0.000000,
                         axis=-1,
                         strategy=((1, device_num), (), ()))
    fact.forward_cmp()


def test_reid_onehot_forward_int32_1024_depth128_model_parallel():
    fact = OneHotFactory(batch_size=1024,
                         classes=128,
                         on_value=1.000000,
                         off_value=0.000000,
                         axis=-1,
                         strategy=((1, device_num), (), ()))
    fact.forward_cmp()