# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import numpy as np

import mindspore as ms
import mindspore.communication.management as distributedTool
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P

device_num = 2
device_id = int(os.getenv('DEVICE_ID'))
rank_id = 0


def setup_module():
    global device_num
    global rank_id
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=device_id)
    distributedTool.init()
    device_num = distributedTool.get_group_size()
    rank_id = distributedTool.get_rank()
    context.set_auto_parallel_context(device_num=device_num,
                                      global_rank=rank_id)


def teardown_module():
    distributedTool.release()


class Onehot(Cell):
    def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, strategy=None):
        super(Onehot, self).__init__()
        trans_stra = None
        if strategy:
            trans_stra = (strategy[0],)
        self.onehot = P.OneHot().shard(strategy=strategy)
        self.depth = depth
        self.on_value = Tensor(on_value, ms.float32)
        self.off_value = Tensor(off_value, ms.float32)
        self.transpose = P.Transpose().shard(strategy=trans_stra)
        self.sub = P.Sub().shard(strategy=((1, 1), (1, 1)))
        self.axis = axis

    def construct(self, input_, indices):
        x = self.onehot(indices, self.depth, self.on_value, self.off_value)
        x = self.transpose(x, (1, 0))
        x = self.sub(input_, x)
        return x


class DataGenerator():
    def get_parallel_blocks(self, input_, strategy):
        blocks = [input_]
        i = 0
        for stra in strategy:
            temp = []
            while blocks:
                block = blocks.pop(0)
                temp.extend(np.split(block, stra, axis=i))
            blocks.extend(temp)
            i += 1
        return blocks

    def generate_data(self, shape):
        data = np.random.rand(*shape)
        return data

    def input_data(self, shape):
        data = (self.generate_data(shape) * 2).astype(np.float32)
        stra = [1] * len(shape)
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])

    def label_data(self, shape, classes):
        data = (self.generate_data(shape) * (classes - 1)).astype(np.int32)
        stra = [1] * len(shape)
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])


class OneHotFactory:
    def __init__(self, batch_size, classes, on_value=1.0, off_value=0.0, axis=None, strategy=None):
        data_gen = DataGenerator()
        self.input_full, self.input_part = data_gen.input_data((classes, batch_size))
        self.label_full, self.label_part = data_gen.label_data((batch_size,), classes)
        self.depth = classes
        self.on_value = on_value
        self.off_value = off_value
        self.axis = axis
        self.strategy = strategy

    def forward_mindspore_single_impl(self):
        net = Onehot(axis=self.axis,
                     depth=self.depth,
                     on_value=self.on_value,
                     off_value=self.off_value)
        out = net(self.input_full, self.label_full)
        return out

    def forward_mindspore_parallel_impl(self):
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        net = Onehot(axis=self.axis,
                     depth=self.depth,
                     on_value=self.on_value,
                     off_value=self.off_value, strategy=self.strategy)
        out = net.compile_and_run(self.input_full, self.label_full)
        return out

    def forward_cmp(self):
        out_mindspore_single = self.forward_mindspore_single_impl().asnumpy()
        context.reset_auto_parallel_context()
        out_mindspore_parallel = self.forward_mindspore_parallel_impl().asnumpy()
        context.reset_auto_parallel_context()
        assert np.allclose(out_mindspore_single, out_mindspore_parallel, 0.0001, 0.0001)


def test_reid_onehot_forward_int32_128_depth1024_model_parallel():
    fact = OneHotFactory(batch_size=128,
                         classes=1024,
                         on_value=1.000000,
                         off_value=0.000000,
                         axis=-1,
                         strategy=((1, device_num), (), ()))
    fact.forward_cmp()


def test_reid_onehot_forward_int32_1024_depth128_model_parallel():
    fact = OneHotFactory(batch_size=1024,
                         classes=128,
                         on_value=1.000000,
                         off_value=0.000000,
                         axis=-1,
                         strategy=((1, device_num), (), ()))
    fact.forward_cmp()
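

# Minimal sketch, not part of the original test: assuming one process per
# Ascend device with DEVICE_ID set in the environment, the module can also be
# driven directly instead of through pytest (which would otherwise call the
# setup_module/teardown_module hooks automatically).
if __name__ == "__main__":
    setup_module()
    try:
        test_reid_onehot_forward_int32_128_depth1024_model_parallel()
        test_reid_onehot_forward_int32_1024_depth128_model_parallel()
    finally:
        teardown_module()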