• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/ops_info/arithmetic_info.h"
18 
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 
23 #include "frontend/parallel/device_matrix.h"
24 #include "frontend/parallel/strategy.h"
25 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
26 
27 namespace mindspore {
28 namespace parallel {
// Aligns the rank of `smaller_size_shape` to that of `bigger_size_shape` by
// prepending 1s, and returns the padded copy. The same helper is used for
// input shapes and for strategy dimensions (where a padded 1 means "no split").
//
// If `smaller_size_shape` already has an equal or larger rank it is returned
// unchanged. The previous implementation computed the pad count with an
// unsigned subtraction, which wrapped around in that case and produced an
// effectively unbounded insertion loop.
Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) {
  if (bigger_size_shape.size() <= smaller_size_shape.size()) {
    return smaller_size_shape;
  }
  const size_t insert_num = bigger_size_shape.size() - smaller_size_shape.size();
  // One bulk insert instead of repeated front insertions (which are O(n) each).
  (void)smaller_size_shape.insert(smaller_size_shape.begin(), insert_num, 1);
  return smaller_size_shape;
}
36 
InferExpendShape()37 Shapes ArithmeticBase::InferExpendShape() {
38   Shape input_a_shape = inputs_shape_.at(0);
39   Shape input_b_shape = inputs_shape_.at(1);
40   Shapes input_shapes;
41   size_t input_a_size = input_a_shape.size();
42   size_t input_b_size = input_b_shape.size();
43   if (input_a_size > input_b_size) {
44     input_shapes.push_back(input_a_shape);
45     input_shapes.push_back(ExpendShape(input_a_shape, input_b_shape));
46   } else if (input_a_size < input_b_size) {
47     input_shapes.push_back(ExpendShape(input_b_shape, input_a_shape));
48     input_shapes.push_back(input_b_shape);
49   } else {
50     input_shapes.push_back(input_a_shape);
51     input_shapes.push_back(input_b_shape);
52   }
53   return input_shapes;
54 }
55 
ExpendStrategy(const StrategyPtr & strategy)56 Strategys ExpendStrategy(const StrategyPtr &strategy) {
57   Strategys expend_strategy;
58   Strategys stra = strategy->GetInputDim();
59   Dimensions sub_a_strategy = stra.at(0);
60   Dimensions sub_b_strategy = stra.at(1);
61   size_t input_a_size = sub_a_strategy.size();
62   size_t input_b_size = sub_b_strategy.size();
63   if (input_a_size > input_b_size) {
64     expend_strategy.push_back(sub_a_strategy);
65     expend_strategy.push_back(ExpendShape(sub_a_strategy, sub_b_strategy));
66   } else if (input_a_size < input_b_size) {
67     expend_strategy.push_back(ExpendShape(sub_b_strategy, sub_a_strategy));
68     expend_strategy.push_back(sub_b_strategy);
69   } else {
70     expend_strategy = stra;
71   }
72   return expend_strategy;
73 }
74 
CheckStrategy(const StrategyPtr & strategy)75 Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
76   if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
77     MS_LOG(ERROR) << name_ << " : Invalid strategy.";
78     return FAILED;
79   }
80   Shapes input_shapes = InferExpendShape();
81   Strategys expend_strategy = ExpendStrategy(strategy);
82   Dimensions sub_a_strategy = expend_strategy.at(0);
83   Dimensions sub_b_strategy = expend_strategy.at(1);
84   Shape input_a_shape = input_shapes.at(0);
85   Shape input_b_shape = input_shapes.at(1);
86 
87   for (size_t i = 0; i < input_a_shape.size(); ++i) {
88     if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != 1) && (input_b_shape[i] != 1)) {
89       MS_LOG(ERROR) << name_ << " : Invalid strategy.";
90       return FAILED;
91     }
92   }
93   return SUCCESS;
94 }
95 
InferDevMatrixShape()96 Status ArithmeticBase::InferDevMatrixShape() {
97   Strategys expend_strategy = ExpendStrategy(strategy_);
98   Dimensions sub_a_strategy = expend_strategy.at(0);
99   Dimensions sub_b_strategy = expend_strategy.at(1);
100   Shape dev_shape;
101   for (size_t i = 0; i < sub_a_strategy.size(); ++i) {
102     if (sub_a_strategy[i] != sub_b_strategy[i]) {
103       dev_shape.push_back(sub_a_strategy[i] * sub_b_strategy[i]);
104     } else {
105       dev_shape.push_back(sub_a_strategy[i]);
106     }
107   }
108   dev_matrix_shape_ = dev_shape;
109 
110   return SUCCESS;
111 }
112 
SetExpendTensorMap(const Shape & strategy,const Shape & dev_matrix_shape)113 TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) {
114   TensorMap tensor_map_index;
115   for (size_t i = 0; i < strategy.size(); ++i) {
116     if (strategy[i] == dev_matrix_shape[i]) {
117       tensor_map_index.push_back((int64_t)(LAST_INDEX(strategy.size()) - i));
118     } else {
119       tensor_map_index.push_back(-1);
120     }
121   }
122   return tensor_map_index;
123 }
124 
SetTensorMap(const Shape & strategy_expend,const Shape & dev_matrix_shape,const Shape & strategy)125 TensorMap SetTensorMap(const Shape &strategy_expend, const Shape &dev_matrix_shape, const Shape &strategy) {
126   TensorMap expend_map = SetExpendTensorMap(strategy_expend, dev_matrix_shape);
127   size_t dev_matrix_size = dev_matrix_shape.size();
128   size_t strategy_size = strategy.size();
129   if (dev_matrix_size != strategy_size) {
130     (void)expend_map.erase(expend_map.begin(),
131                            expend_map.begin() + static_cast<different_type>(dev_matrix_size - strategy_size));
132   }
133   return expend_map;
134 }
135 
ReComputeBatchSplitFlagList()136 void ArithmeticBase::ReComputeBatchSplitFlagList() {
137   Shapes expend_shapes = InferExpendShape();
138   Shape expend_a_shape = expend_shapes.at(0);
139   Shape expend_b_shape = expend_shapes.at(1);
140   if (expend_a_shape.size() != expend_b_shape.size()) {
141     MS_LOG(EXCEPTION) << name_ << " : Recompute batch split flag list is wrong.";
142   }
143   if (expend_a_shape.empty()) {
144     split_flag_list_[0] = false;
145     split_flag_list_[1] = false;
146     return;
147   }
148   (expend_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false);
149   (expend_b_shape.at(0) != 1) ? (split_flag_list_[1] = true) : (split_flag_list_[1] = false);
150 }
151 
InferTensorMap()152 Status ArithmeticBase::InferTensorMap() {
153   Shape tensor_map_index;
154   Strategys expend_strategy = ExpendStrategy(strategy_);
155   Dimensions sub_a_expend_strategy = expend_strategy.at(0);
156   Dimensions sub_b_expend_strategy = expend_strategy.at(1);
157   Strategys stra = strategy_->GetInputDim();
158   Dimensions sub_a_strategy = stra.at(0);
159   Dimensions sub_b_strategy = stra.at(1);
160   for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
161     tensor_map_index.push_back((int64_t)(LAST_INDEX(sub_a_expend_strategy.size()) - i));
162   }
163 
164   Shape dev_shape;
165   for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
166     if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) {
167       dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]);
168     } else {
169       dev_shape.push_back(sub_a_expend_strategy[i]);
170     }
171   }
172   inputs_tensor_map_.push_back(SetTensorMap(sub_a_expend_strategy, dev_shape, sub_a_strategy));
173   inputs_tensor_map_.push_back(SetTensorMap(sub_b_expend_strategy, dev_shape, sub_b_strategy));
174   outputs_tensor_map_.push_back(tensor_map_index);
175 
176   return SUCCESS;
177 }
178 
SetCostUnderStrategy(const StrategyPtr & strategy)179 Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
180 
GenerateOpStrategies(int64_t stage_id)181 std::vector<StrategyPtr> ArithmeticBase::GenerateOpStrategies(int64_t stage_id) {
182   Shape input0_split(inputs_shape_[0].size(), 1);
183   Shape input1_split(inputs_shape_[1].size(), 1);
184   Shapes splittable_inputs = {input0_split, input1_split};
185 
186   std::vector<StrategyPtr> sp_vector;
187   if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
188     MS_LOG(EXCEPTION) << name_ << " : Generate strategies with broadcast failed.";
189   }
190   MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success.";
191 
192   return sp_vector;
193 }
194 
Init(const StrategyPtr & strategy)195 Status ArithmeticBase::Init(const StrategyPtr &strategy) {
196   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
197     MS_LOG(ERROR) << name_ << " : Init failed.";
198     return FAILED;
199   }
200   MS_LOG(INFO) << name_ << " : Init success.";
201   return SUCCESS;
202 }
203 
InitForCostModel(const StrategyPtr & strategy)204 Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) {
205   if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
206     MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
207     return FAILED;
208   }
209 
210   MS_LOG(INFO) << name_ << " : Init for cost model success.";
211   return SUCCESS;
212 }
213 }  // namespace parallel
214 }  // namespace mindspore
215