/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/onehot_info.h"

#include <memory>
#include <utility>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/strategy.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
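// Extract the 'axis' attribute of the OneHot primitive and check that it lies in [-1, 1].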
Status OneHotInfo::GetAttrs() {
  auto iter = attrs_.find(AXIS);
  if (iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {
      axis_value_ptr_ = iter->second;
      axis_ = iter->second->cast<Int64ImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << ": The value of axis is not int64_t.";
      return FAILED;
    }
  }

  if ((axis_ > 1) || (axis_ < -1)) {
    MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range [-1, 1].";
    return FAILED;
  }
  return SUCCESS;
}

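// Validate the strategy. Note that the first sub-strategy is written against the output shape
// rather than the input. When the input has more than one dimension, only axis = -1 with a pure
// data-parallel split (all non-batch split factors equal to 1) is supported.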
Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}) != SUCCESS) {
    return FAILED;
  }
  auto stra = strategy->GetInputDim().at(0);
  bool invalid = false;
  for (size_t i = 1; i < stra.size(); ++i) {
    if (stra.at(i) != 1) {
      invalid = true;
      break;
    }
  }
  if ((inputs_shape_.at(0).size() > 1) && ((axis_ != -1) || invalid)) {
    MS_LOG(ERROR) << "When the input dimension is larger than 1, the axis must be -1 and the strategy must be data "
                     "parallel.";
    return FAILED;
  }
  return SUCCESS;
}

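// Derive the device matrix from the strategy. With axis = 0 the output is laid out as
// [depth, features], so the two split factors are swapped. In both branches the last element of
// the device matrix is the split count of the class (depth) dimension; it is cached in
// old_dev_matrix_back_ for later use by the replace graph.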
Status OneHotInfo::InferDevMatrixShape() {
  Strategys stra = strategy_->GetInputDim();
  Dimensions input_strategy = stra.at(0);

  if (axis_ == 0) {
    // Here, only a 1-D input tensor is supported, so the output is a 2-D tensor.
    // If the input is a vector of length features, the output shape will be
    // [depth, features] when axis == 0.
    dev_matrix_shape_.push_back(input_strategy[1]);  // the depth dimension is un-splittable
    dev_matrix_shape_.push_back(input_strategy[0]);  // the features dimension is splittable
  } else {
    for (const auto &input_stra : input_strategy) {
      dev_matrix_shape_.push_back(input_stra);
    }
  }
  old_dev_matrix_back_ = dev_matrix_shape_.back();
  repeated_num_in_dev_matrix_right_ = (old_dev_matrix_back_ == 1);
  return SUCCESS;
}

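// Map tensor dimensions onto the device matrix. For axis = 0 the output dimensions map in
// ascending order and the 1-D input maps to the features axis of the device matrix; for
// axis = -1 the usual descending mapping is used, with the input covering every output
// dimension except the class dimension.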
Status OneHotInfo::InferTensorMap() {
  Shape input_tensor_map_index, output_tensor_map_index;
  size_t size = outputs_shape_[0].size();
  if (axis_ == 0) {
    for (size_t i = 0; i < size; ++i) {
      output_tensor_map_index.push_back((int64_t)(i));
    }
    input_tensor_map_index.push_back(1);
  } else {
    for (size_t i = 0; i < size; ++i) {
      output_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
    }
    for (size_t i = 0; i < size - 1; ++i) {
      input_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
    }
  }
  outputs_tensor_map_.push_back(output_tensor_map_index);
  inputs_tensor_map_.push_back(input_tensor_map_index);
  return SUCCESS;
}

// Examples (strategy -> dev_matrix / tensor maps):
// axis = -1
//   (0, (1,16), (), ())  reid:          dev_matrix=(1,16)  map_in=(1) map_out=(1,0)
//   (0, (16,1), (), ())  data parallel: dev_matrix=(16,1)  map_in=(1) map_out=(1,0)
//   (0, (2,8),  (), ())  16 devices on two machines, model parallel among devices in the same machine,
//                        data parallel between machines: dev_matrix=(2,8) map_in=(1) map_out=(1,0)
//   (0, (2,4),  (), ())  16 devices:    dev_matrix=(2,4,2) map_in=(1) map_out=(1,0)
// axis = 0
//   (0, (16,1), (), ())  reid:          dev_matrix=(1,16)  map_in=(1) map_out=(0,1)
//   (0, (1,16), (), ())  data parallel: dev_matrix=(16,1)  map_in=(1) map_out=(0,1)
//   (0, (8,2),  (), ())  16 devices on two machines, model parallel among devices in the same machine,
//                        data parallel between machines: dev_matrix=(2,8) map_in=(1) map_out=(0,1)
//   (0, (4,2),  (), ())  16 devices:    dev_matrix=(2,4,2) map_in=(1) map_out=(0,1)

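// Gather the runtime information needed by the replace graph: this device's rank within the
// class dimension and the total class number (the 'depth' input of OneHot), from which the
// number of classes handled by each device is derived.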
Status OneHotInfo::ExtractInputInfo() {
  CheckGlobalDeviceManager();
  rank_ = g_device_manager->rank_index_in_stage();
  mod_rank_ = rank_ % old_dev_matrix_back_;
  if (!cnode_) {
    MS_LOG(ERROR) << "Failure: OneHot cnode_ is nullptr";
    return FAILED;
  }
  if (cnode_->inputs().size() != 5) {
    MS_LOG(ERROR) << "Failure: There should be 5 inputs for the CNode corresponding to the OneHot primitive, but the "
                     "real input size is "
                  << cnode_->inputs().size();
    return FAILED;
  }
  if (input_value_.size() != 4) {
    MS_LOG(ERROR) << "Failure: There should be 5 inputs for the CNode corresponding to the OneHot primitive, and the "
                     "input value size must be 4, but the real size is "
                  << input_value_.size();
    return FAILED;
  }
  auto value_ptr = input_value_.at(1);
  if (value_ptr == nullptr) {
    MS_LOG(WARNING) << "Input 2 of the cnode is not a value node; its type is " << cnode_->input(2)->type_name();
    return FAILED;
  }

  if (value_ptr->isa<Int64Imm>()) {
    total_class_number_ = value_ptr->cast<Int64ImmPtr>()->value();
  } else {
    MS_LOG(ERROR) << "The depth of the OneHot primitive must be int64_t";
    return FAILED;
  }
  classes_each_device_ = total_class_number_ / old_dev_matrix_back_;

  return SUCCESS;
}

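// When the class dimension is split across devices, replace the original OneHot with a small
// subgraph that remaps each global class index to a device-local index (or -1 for classes owned
// by other devices) and then applies OneHot with the per-device depth.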
Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  if (old_dev_matrix_back_ == 1) {
    replace_graph_ = nullptr;
    return SUCCESS;
  }
  if (ExtractInputInfo() != SUCCESS) {
    MS_LOG(ERROR) << "ExtractInputInfo failed";
    return FAILED;
  }
  GenerateGraph gen_g = GenerateGraph(attrs_);
  Status status = gen_g.Init(cnode);
  if (status != SUCCESS) {
    MS_LOG(ERROR) << "GenerateGraph Init failed";
    return FAILED;
  }

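  // Remap a global class index to a device-local one:
  //   floor_div = index / classes_each_device_                (which device owns the index)
  //   sub1      = index - floor_div * classes_each_device_    (local offset inside that block)
  //   cast      = (floor_div == mod_rank_) ? 1 : 0
  //   result    = cast * (sub1 * cast + 1) - 1                -> local offset if owned, else -1
  // OneHot turns the -1 index into an all-off row, so only the owning device produces the
  // one-hot bit for its slice of the class dimension.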
  auto floor_div =
    gen_g.PushBack({gen_g.NewOpInst(FLOORDIV), gen_g.virtual_input_node(), CreateInt32Tensor(classes_each_device_)});
  auto mul1 = gen_g.PushBack({gen_g.NewOpInst(MUL), floor_div, CreateInt32Tensor(classes_each_device_)});
  auto sub1 = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), mul1});
  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)});
  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)});
  auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast});
  auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(ADD), mul2, CreateInt32Tensor(1)});
  auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add});
  auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)});
  Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_);
  OperatorAttrs attrs_onehot = {attr_onehot_axis};
  auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt64Imm(classes_each_device_),
                                cnode->input(3), cnode->input(4)});
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)};
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, onehot));

  return SUCCESS;
}

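// Return the replace graph for the given cnode; nullptr means either that no replacement is
// needed (the class dimension is not split) or that construction failed.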
ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) {
  if (ComputeReplaceGraph(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
    return nullptr;
  }
  return replace_graph_;
}

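// Full initialization: run the common auto-repeat-calculation setup, then build the replace
// graph for this operator's cnode.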
Status OneHotInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  Status status = ComputeReplaceGraph(cnode_);
  if (status != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
    return status;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

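// Lightweight initialization used during strategy search; no replace graph is built here.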
Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}

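// Enumerate candidate strategies. Only the first (output) shape may be split, in either of its
// two dimensions; the scalar on-value and off-value inputs take empty strategies.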
std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {
  Shapes splittable_inputs = {{1, 1}, {}, {}};
  std::vector<StrategyPtr> sp_vector;
  if (inputs_shape_.size() != 3) {
    MS_LOG(EXCEPTION) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();
  }
  if (outputs_shape_.size() != 1) {
    MS_LOG(EXCEPTION) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size();
  }
  if (GenerateStrategiesForIndependentInputs(stage_id, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)},
                                             splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategies failed.";
  }

  return sp_vector;
}

Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

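// Batch-parallel fallback: split only the batch dimension of the output across all devices in
// the stage; the scalar on-value and off-value inputs get empty strategies.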
std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
  Dimensions strategy = {stage_device_size_, 1};
  Dimensions empty_strategy;
  Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
  return std::make_shared<Strategys>(strategy_v);
}
}  // namespace parallel
}  // namespace mindspore