• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/ops_info/onehot_info.h"
18 
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "ir/value.h"
24 #include "frontend/parallel/auto_parallel/costmodel.h"
25 #include "frontend/parallel/device_matrix.h"
26 #include "frontend/parallel/graph_util/generate_graph.h"
27 #include "frontend/parallel/strategy.h"
28 #include "utils/log_adapter.h"
29 
30 namespace mindspore {
31 namespace parallel {
GetAttrs()32 Status OneHotInfo::GetAttrs() {
33   auto iter = attrs_.find(AXIS);
34   if (iter != attrs_.end()) {
35     MS_EXCEPTION_IF_NULL(iter->second);
36     if (iter->second->isa<Int64Imm>()) {
37       axis_value_ptr_ = iter->second;
38       axis_ = iter->second->cast<Int64ImmPtr>()->value();
39     } else {
40       MS_LOG(ERROR) << name_ << ": The value of axis is not int64_t.";
41       return FAILED;
42     }
43   }
44 
45   if ((axis_ > 1) || (axis_ < -1)) {
46     MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range[-1, 1].";
47     return FAILED;
48   }
49   return SUCCESS;
50 }
51 
CheckStrategy(const StrategyPtr & strategy)52 Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
53   if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}) != SUCCESS) {
54     return FAILED;
55   }
56   auto stra = strategy->GetInputDim().at(0);
57   bool invalid = false;
58   for (size_t i = 1; i < stra.size(); ++i) {
59     if (stra.at(i) != 1) {
60       invalid = true;
61       break;
62     }
63   }
64   if ((inputs_shape_.at(0).size() > 1) && ((axis_ != -1) || invalid)) {
65     MS_LOG(ERROR) << "When input dimension is > 1, axis must be -1, and strategy must be data parallel.";
66     return FAILED;
67   }
68   return SUCCESS;
69 }
70 
InferDevMatrixShape()71 Status OneHotInfo::InferDevMatrixShape() {
72   Strategys stra = strategy_->GetInputDim();
73   Dimensions input_strategy = stra.at(0);
74 
75   if (axis_ == 0) {
76     // Here, only support 1-D input tensor, so the output is a 2-D tensor
77     // If input is a vector of length features, the output shape will be:
78     // [depth, features] if axis == 0
79     dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
80     dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
81   } else {
82     for (const auto &input_stra : input_strategy) {
83       dev_matrix_shape_.push_back(input_stra);
84     }
85   }
86   old_dev_matrix_back_ = dev_matrix_shape_.back();
87   if (old_dev_matrix_back_ == 1) {
88     repeated_num_in_dev_matrix_right_ = true;
89   } else {
90     repeated_num_in_dev_matrix_right_ = false;
91   }
92   return SUCCESS;
93 }
94 
InferTensorMap()95 Status OneHotInfo::InferTensorMap() {
96   Shape input_tensor_map_index, output_tensor_map_index;
97   size_t size = outputs_shape_[0].size();
98   if (axis_ == 0) {
99     for (size_t i = 0; i < size; ++i) {
100       output_tensor_map_index.push_back((int64_t)(i));
101     }
102     input_tensor_map_index.push_back(1);
103   } else {
104     for (size_t i = 0; i < size; ++i) {
105       output_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
106     }
107     for (size_t i = 0; i < size - 1; ++i) {
108       input_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
109     }
110   }
111   outputs_tensor_map_.push_back(output_tensor_map_index);
112 
113   inputs_tensor_map_.push_back(input_tensor_map_index);
114   return SUCCESS;
115 }
116 
// axis = -1
// (0,(1,16),(),()) reid:           dev_matrix=(1,16)   map_in=(1) map_out=(1,0)
// (0,(16,1),(),()) data parallel:  dev_matrix=(16,1)   map_in=(1) map_out=(1,0)
// (0,(2,8),(),())  16 devices, two machines, model parallel among devices in the same machine,
//                  data parallel between machines: dev_matrix=(2,8) map_in=(1) map_out=(1,0)
// (0,(2,4),(),())  16 devices:     dev_matrix=(2,4,2)  map_in=(1) map_out=(1,0)
// axis = 0
// (0,(16,1),(),()) reid:           dev_matrix=(1,16)   map_in=(1) map_out=(0,1)
// (0,(1,16),(),()) data parallel:  dev_matrix=(16,1)   map_in=(1) map_out=(0,1)
// (0,(8,2),(),())  16 devices, two machines, model parallel among devices in the same machine,
//                  data parallel between machines: dev_matrix=(2,8) map_in=(1) map_out=(0,1)
// (0,(4,2),(),())  16 devices:     dev_matrix=(2,4,2)  map_in=(1) map_out=(0,1)
// Collects the per-rank information needed to rewrite OneHot for a split
// classes axis: this rank's position along that axis (mod_rank_), the total
// class count taken from the constant 'depth' operand, and the class count
// each device owns. Returns FAILED when the CNode or its constant inputs do
// not have the expected OneHot layout.
Status OneHotInfo::ExtractInputInfo() {
  CheckGlobalDeviceManager();
  rank_ = g_device_manager->rank_index_in_stage();
  // Position of this rank within the device-matrix dimension that splits the
  // classes axis (old_dev_matrix_back_ devices share that axis).
  mod_rank_ = rank_ % old_dev_matrix_back_;
  if (!cnode_) {
    MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr";
    return FAILED;
  }
  // Expected CNode layout: (primitive, indices, depth, on_value, off_value).
  if (cnode_->inputs().size() != 5) {
    MS_LOG(ERROR) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive, real input size is "
                  << cnode_->inputs().size();
    return FAILED;
  }
  if (input_value_.size() != 4) {
    MS_LOG(ERROR) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive, and input value size "
                     "must be 4, real size is "
                  << input_value_.size();
    return FAILED;
  }
  // input_value_[1] holds the constant 'depth' operand; it is null when the
  // corresponding CNode input is not a value node.
  auto value_ptr = input_value_.at(1);
  if (value_ptr == nullptr) {
    MS_LOG(WARNING) << "Input 2 of cnode is not a value node, its type is " << cnode_->input(2)->type_name();
    return FAILED;
  }

  if (value_ptr->isa<Int64Imm>()) {
    total_class_number_ = value_ptr->cast<Int64ImmPtr>()->value();
  } else {
    MS_LOG(ERROR) << "OneHot Primitive depth type must be int64_t";
    return FAILED;
  }
  // NOTE(review): integer division — assumes total_class_number_ divides
  // evenly by old_dev_matrix_back_; confirm upstream validation rejects a
  // remainder.
  classes_each_device_ = total_class_number_ / old_dev_matrix_back_;

  return SUCCESS;
}
164 
// Builds the replacement subgraph that computes OneHot locally on each device
// when the classes axis is split across old_dev_matrix_back_ devices. Each
// device emits a one-hot slice of depth classes_each_device_; labels owned by
// other devices are remapped to -1 before the local OneHot.
Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  // Classes axis unsplit: the original OneHot is already correct as-is.
  if (old_dev_matrix_back_ == 1) {
    replace_graph_ = nullptr;
    return SUCCESS;
  }
  if (ExtractInputInfo() != SUCCESS) {
    MS_LOG(ERROR) << "ExtractInputInfo failed";
    return FAILED;
  }
  GenerateGraph gen_g = GenerateGraph(attrs_);
  Status status = gen_g.Init(cnode);
  if (status != SUCCESS) {
    MS_LOG(ERROR) << "GenerateGraph Init failed";
    return FAILED;
  }

  // floor_div: index of the shard owning each label (label / classes_each_device_).
  auto floor_div =
    gen_g.PushBack({gen_g.NewOpInst(FLOORDIV), gen_g.virtual_input_node(), CreateInt32Tensor(classes_each_device_)});
  // mul1/sub1: label's offset inside its shard (label - shard * classes_each_device_).
  auto mul1 = gen_g.PushBack({gen_g.NewOpInst(MUL), floor_div, CreateInt32Tensor(classes_each_device_)});
  auto sub1 = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), mul1});
  // equal/cast: mask that is 1 where this rank owns the label, 0 elsewhere.
  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)});
  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)});
  // mul2..sub2: mask*(offset+1) - 1 — local labels keep their in-shard offset,
  // foreign labels become -1 (presumably an out-of-range index the local
  // OneHot turns into an all-zero row — confirm OneHot's out-of-range rule).
  auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast});
  auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(ADD), mul2, CreateInt32Tensor(1)});
  auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add});
  auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)});
  Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_);
  OperatorAttrs attrs_onehot = {attr_onehot_axis};
  // Local OneHot with depth shrunk to this device's class count; on_value and
  // off_value are forwarded unchanged from the original CNode.
  auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt64Imm(classes_each_device_),
                                cnode->input(3), cnode->input(4)});
  // Both floor_div and sub1 consume the operator's first real input (the indices).
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)};
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, onehot));

  return SUCCESS;
}
201 
replace_graph(const CNodePtr & cnode)202 ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) {
203   if (ComputeReplaceGraph(cnode) != SUCCESS) {
204     MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
205     return nullptr;
206   }
207   return replace_graph_;
208 }
209 
Init(const StrategyPtr & strategy)210 Status OneHotInfo::Init(const StrategyPtr &strategy) {
211   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
212     MS_LOG(ERROR) << name_ << ": Init failed.";
213     return FAILED;
214   }
215   Status status = ComputeReplaceGraph(cnode_);
216   if (status != SUCCESS) {
217     MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
218     return status;
219   }
220   MS_LOG(INFO) << name_ << ": Init success.";
221   return SUCCESS;
222 }
223 
InitForCostModel(const StrategyPtr & strategy)224 Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) {
225   if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
226     MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
227     return FAILED;
228   }
229   MS_LOG(INFO) << name_ << ": Init for cost model success.";
230   return SUCCESS;
231 }
232 
GenerateOpStrategies(int64_t stage_id)233 std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {
234   Shapes splittable_inputs = {{1, 1}, {}, {}};
235   std::vector<StrategyPtr> sp_vector;
236   if (inputs_shape_.size() != 3) {
237     MS_LOG(EXCEPTION) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();
238   }
239   if (outputs_shape_.size() != 1) {
240     MS_LOG(EXCEPTION) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size();
241   }
242   if (GenerateStrategiesForIndependentInputs(stage_id, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)},
243                                              splittable_inputs, &sp_vector) != SUCCESS) {
244     MS_LOG(EXCEPTION) << name_ << ": GenerateStrategies failed.";
245   }
246 
247   return sp_vector;
248 }
249 
SetCostUnderStrategy(const StrategyPtr & strategy)250 Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
251 
GenerateBatchStrategies()252 std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
253   Dimensions strategy = {stage_device_size_, 1};
254   Dimensions empty_strategy;
255   Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
256   return std::make_shared<Strategys>(strategy_v);
257 }
258 }  // namespace parallel
259 }  // namespace mindspore
260