/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
#include "ir/value.h"
#include "utils/hash_map.h"

namespace mindspore {
namespace parallel {
// Attr key used to read the "batch_dims" attribute from the primitive.
constexpr char BATCH_DIMS[] = "batch_dims";

// Sharding modes for Gather. The order MUST stay in sync with
// GatherUtil::gather_mode_string_, which is indexed by these values.
enum GatherMode {
  BATCH = 0,             // batch_dims > 0, axis not split
  NORMAL,                // batch_dims == 0, axis not split
  MANUAL,                // primitive carries the "manual_split" attr
  SHARD_BATCH_AND_AXIS,  // split first dim of param (axis) and first dim of indices
  SHARD_AXIS_0_DYNAMIC,  // axis = 0 is split, dynamic-shape indices
  SHARD_AXIS_0_STATIC,   // axis = 0 is split, static-shape indices
  SHARD_AXIS_1,          // axis = 1 is split
  INVALID
};

class GatherUtil;
using GatherUtilPtr = std::shared_ptr<GatherUtil>;

47 class GatherUtil {
48  public:
GatherUtil(std::string name,Shapes inputs_shape,Shapes outputs_shape,int64_t axis)49   GatherUtil(std::string name, Shapes inputs_shape, Shapes outputs_shape, int64_t axis)
50       : name_(std::move(name)),
51         inputs_shape_(std::move(inputs_shape)),
52         outputs_shape_(std::move(outputs_shape)),
53         axis_(axis) {
54     inputs_shape_clone_ = inputs_shape_;
55     outputs_shape_clone_ = outputs_shape_;
56   }
57   virtual ~GatherUtil() = default;
58   virtual Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) = 0;
InferForwardCommunication()59   virtual Status InferForwardCommunication() { return SUCCESS; }
60   virtual Status InferTensorInfo() = 0;
61   virtual Status InferDevMatrixShape() = 0;
62   virtual Status InferTensorMap() = 0;
InferReplaceGraph(const CNodePtr & cnode)63   virtual Status InferReplaceGraph(const CNodePtr &cnode) { return SUCCESS; }
InferReplaceOps()64   virtual Status InferReplaceOps() { return SUCCESS; }
65 
set_param_strategy(const Shape & a)66   void set_param_strategy(const Shape &a) { param_strategy_ = a; }
set_indices_strategy(const Shape & a)67   void set_indices_strategy(const Shape &a) { indices_strategy_ = a; }
set_gather_mode(const GatherMode & a)68   void set_gather_mode(const GatherMode &a) { gather_mode_ = a; }
set_inputs_divisor(const Shapes & a)69   void set_inputs_divisor(const Shapes &a) { inputs_divisor_ = a; }
set_outputs_divisor(const Shapes & a)70   void set_outputs_divisor(const Shapes &a) { outputs_divisor_ = a; }
set_dynamic_shape_flag(bool a)71   void set_dynamic_shape_flag(bool a) { dynamic_shape_flag_ = a; }
gather_mode()72   GatherMode gather_mode() const { return gather_mode_; }
dev_matrix_shape()73   Shape dev_matrix_shape() const { return dev_matrix_shape_; }
set_dev_matrix_shape(const Shape & a)74   void set_dev_matrix_shape(const Shape &a) { dev_matrix_shape_ = a; }
inputs_tensor_map()75   TensorMaps inputs_tensor_map() const { return inputs_tensor_map_; }
outputs_tensor_map()76   TensorMaps outputs_tensor_map() const { return outputs_tensor_map_; }
set_inputs_tensor_map(const TensorMaps & a)77   void set_inputs_tensor_map(const TensorMaps &a) { inputs_tensor_map_ = a; }
set_outputs_tensor_map(const TensorMaps & a)78   void set_outputs_tensor_map(const TensorMaps &a) { outputs_tensor_map_ = a; }
inputs_tensor_info()79   std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
outputs_tensor_info()80   std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
forward_op()81   ForwardOp forward_op() const { return forward_op_; }
replace_op()82   ForwardOp replace_op() const { return replace_op_; }
replace_graph()83   ReplaceGraphPtr replace_graph() const { return replace_graph_; }
repeated_num_in_dev_matrix_right()84   bool repeated_num_in_dev_matrix_right() const { return repeated_num_in_dev_matrix_right_; }
out_dev_matrix_shape()85   Shape out_dev_matrix_shape() const { return out_dev_matrix_shape_; }
GatherModeToString()86   std::string GatherModeToString() const { return gather_mode_string_[gather_mode_]; }
87   void DivisorsReplaceShapes();
88   void ResumeShapes();
89 
90  protected:
91   std::string name_;
92   Shapes inputs_shape_;
93   Shapes outputs_shape_;
94   Shapes inputs_shape_clone_;
95   Shapes outputs_shape_clone_;
96   Shapes inputs_divisor_;
97   Shapes outputs_divisor_;
98   int64_t axis_;
99 
100   Shape param_strategy_;
101   Shape indices_strategy_;
102   GatherMode gather_mode_ = INVALID;
103   Shape dev_matrix_shape_;
104   TensorMaps inputs_tensor_map_;
105   TensorMaps outputs_tensor_map_;
106   std::vector<TensorInfo> inputs_tensor_info_;
107   std::vector<TensorInfo> outputs_tensor_info_;
108   ForwardOp forward_op_;
109   ForwardOp replace_op_;
110   ReplaceGraphPtr replace_graph_;
111 
112   Status InferTensorInfoNoSplitAxis();
113   bool repeated_num_in_dev_matrix_right_ = true;  // only for shard axis
114   Shape out_dev_matrix_shape_;                    // only for shard axis
115   bool dynamic_shape_flag_ = False;
116 
117  private:
118   const std::vector<std::string> gather_mode_string_ = {
119     "batch",        "normal", "manual", "shard_batch_and_axis", "shard_axis_0_dynamic", "shard_axis_0_static",
120     "shard_axis_1", "invalid"};
121 };
122 
123 // batch mode: batch_dims > 1
124 // constraint:
125 //   1) axis can not be split
126 //   2) can not set out_strategy
127 // param  shape: [A, B, C, D,E]
128 // indices shape: [A, B, F, G]
129 // batch_dims = 2
130 // axis = 3
131 // out = gather(param,  indices,  axis)
132 // out shape: [A, B, C, F, G, E]
133 // parameter's strategy: [a, b, c, 1, e], indices' strategy: [a, b, f, g]
134 // output's strategy: [a, b, c, f, g, e]
135 // dev_matrix: [a, b, f, g, c, 1, e]
136 class BatchImpl : public GatherUtil {
137  public:
BatchImpl(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,int64_t axis)138   BatchImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
139       : GatherUtil(name, inputs_shape, outputs_shape, axis) {}
140   ~BatchImpl() override = default;
141   Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override;
142   Status InferDevMatrixShape() override;
set_batch_dims(int64_t batch_dims)143   void set_batch_dims(int64_t batch_dims) { batch_dims_ = batch_dims; }
144   Status InferTensorMap() override;
InferTensorInfo()145   Status InferTensorInfo() override { return InferTensorInfoNoSplitAxis(); }
146 
147  private:
148   int64_t batch_dims_ = 0;
149 };
150 
151 // normal mode: batch_dims = 0, and the axis has not be split
152 // constraint:
153 //   1) can not set out_strategy
154 // param  shape: [C, D,E]
155 // indices shape: [F, G]
156 // batch_dims = 0
157 // axis = 1
158 // out = gather(param,  indices,  axis)
159 // out shape: [C, F, G, E]
160 // parameter's strategy: [c, 1, e], indices' strategy: [f, g]
161 // output's strategy: [c, f, g, e]
162 // dev_matrix: [f, g, c, 1, e]
163 class NormalImpl : public GatherUtil {
164  public:
NormalImpl(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,int64_t axis)165   NormalImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
166       : GatherUtil(name, inputs_shape, outputs_shape, axis) {}
167   ~NormalImpl() override = default;
168   Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override;
169   Status InferDevMatrixShape() override;
170   Status InferTensorMap() override;
InferTensorInfo()171   Status InferTensorInfo() override { return InferTensorInfoNoSplitAxis(); }
172 };
173 
174 // manual mode: the primitive has the "manual_split" attr, axis = 0, batch_dims = 0
175 // constraint:
176 //   1) the field dimension of indices is the last dimension;
177 //   2) can not support repeated calculation
178 //   3) parameter's dim >= 1, indices' dim >= 1
179 //   4) can not set out_strategy
180 // param  shape: [A, B, ..., C]
181 // indices shape: [D, ..., E, F]
182 // batch_dims = 0
183 // axis = 0
184 // out = gather(param,  indices,  axis)
185 // out shape: [D, ..., E, F, B, ..., C]
186 // parameter's strategy: [a, b, ..., c], indices' strategy: [1, ..., 1, a]
187 // output's strategy: [1, ..., 1, a, b, ..., c]
188 // dev_matrix: [a, b, ..., c]
189 class ManualImpl : public GatherUtil {
190  public:
ManualImpl(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,int64_t axis)191   ManualImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
192       : GatherUtil(name, inputs_shape, outputs_shape, axis) {}
193   ~ManualImpl() override = default;
194   Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override;
195   Status InferDevMatrixShape() override;
196   Status InferTensorMap() override;
197   Status InferTensorInfo() override;
198   Status InferReplaceGraph(const CNodePtr &cnode) override;
199   Status InferReplaceOps() override;
200 
set_param_split_shapes(const Shape & a)201   void set_param_split_shapes(const Shape &a) { param_split_shapes_ = a; }
set_index_offsets(const Shape & a)202   void set_index_offsets(const Shape &a) { index_offsets_ = a; }
set_target(const std::string & a)203   void set_target(const std::string &a) { target_ = a; }
set_attrs(const mindspore::HashMap<std::string,ValuePtr> & a)204   void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &a) { attrs_ = a; }
set_replace_op_name(const std::string & a)205   void set_replace_op_name(const std::string &a) { replace_op_name_ = a; }
206 
207  protected:
208   Status InferOffset();
209   std::string target_ = DEVICE;
210   mindspore::HashMap<std::string, ValuePtr> attrs_;
211   std::string replace_op_name_;
212   int64_t index_offset_ = 0;
213 
214  private:
215   Shape param_split_shapes_;
216   Shape index_offsets_;
217 };
218 
219 class GatherManualImpl : public ManualImpl {
220  public:
GatherManualImpl(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,int64_t axis)221   GatherManualImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
222       : ManualImpl(name, inputs_shape, outputs_shape, axis) {}
223   ~GatherManualImpl() override = default;
224   Status InferReplaceGraph(const CNodePtr &cnode) override;
225 };
226 
227 // SHARD_AXIS_0_DYNAMIC, SHARD_AXIS_0_STATIC and SHARD_AXIS_1 mode: batch_dims = 0, and split axis
228 // constraint:
229 //   1) parameter's dim is 1 or 2, indices' dim >= 1
230 //   2) indices can't be split
231 //   3) axis = 0 or axis = 1
232 //   4) if axis = 1, can not support repeated calculation
233 //   5) if axis = 0, and param_shape[1] is split, can not support repeated calculation
234 // param  shape: [A, B]
235 // indices shape: [C, D]
236 // batch_dims = 0
237 // axis = 0
238 // out = gather(param,  indices,  axis)
239 // out shape: [A, B, C]
240 // parameter's strategy: [a, b], indices' strategy: [1, 1]
241 // output's strategy:
242 //   1) if use allreduce: [1, 1, b]
243 //   2) if use reducescatter: [a, 1, b]
244 // dev_matrix: [a, b]
245 class ShardAxisImpl : public GatherUtil {
246  public:
ShardAxisImpl(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,int64_t axis)247   ShardAxisImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
248       : GatherUtil(name, inputs_shape, outputs_shape, axis) {}
249   ~ShardAxisImpl() override = default;
250   Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override;
251   Status InferDevMatrixShape() override;
252   Status InferTensorMap() override;
253   Status InferTensorInfo() override;
254   virtual Status InferBias();
set_target(const std::string & a)255   void set_target(const std::string &a) { target_ = a; }
set_dynamic_shape_indices(bool a)256   void set_dynamic_shape_indices(bool a) { dynamic_shape_indices_ = a; }
set_attrs(const mindspore::HashMap<std::string,ValuePtr> & a)257   void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &a) { attrs_ = a; }
set_replace_op_name(const std::string & a)258   void set_replace_op_name(const std::string &a) { replace_op_name_ = a; }
set_axis_split_forward_allreduce(bool a)259   void set_axis_split_forward_allreduce(bool a) { axis_split_forward_allreduce_ = a; }
260 
261   // ShardBatchAndAxisImpl and ShardAxisImpl
262   Status InferForwardCommunication() override;
263   Status InferReplaceOps() override;
264   Status InferReplaceGraph(const CNodePtr &cnode) override;
set_assigned_parallel(bool is_assigned_parallel)265   void set_assigned_parallel(bool is_assigned_parallel) { is_assigned_parallel_ = is_assigned_parallel; }
266 
267  protected:
268   // use for split axis
269   Status CheckSplitAxisStrategy(const Shape &param_strategy, const Shape &indices_strategy);
270   void SetAttribute(const Shape &param_strategy);
271   Status InferGroup();
272   std::string target_ = DEVICE;
273   std::string replace_op_name_;
274   bool dynamic_shape_indices_ = false;
275   bool is_assigned_parallel_ = false;
276   bool axis_split_forward_allreduce_ = false;  // when axis is split, use reducescatter as default in forward
277   int64_t repeated_calculation_num_ = 1;
278   Group group_;
279   mindspore::HashMap<std::string, ValuePtr> attrs_;
280   int64_t bias_ = 0;
281   int64_t slice_size_ = 0;
282 };
283 
284 // shard_batch_and_axis mode: axis = 0, batch_dims = 0, and only split the first dimension of parameter and the first
285 // dimension of indices constraint:
286 //   1) the dim of param is 2, and the dim of indices is 2;
287 //   2) only split the first dimension of parameter and the first dimension of indices, other dims can not be split
288 //   3) do not support repeat calculation
289 // param  shape: [A, B]
290 // indices shape: [C, D]
291 // batch_dims = 0
292 // axis = 0
293 // out = gather(param,  indices,  axis)
294 // out shape: [C, D, B]
295 // parameter's strategy: [a, 1], indices' strategy: [c, 1]
296 // output's strategy:
297 //   1) if use allreduce: [c, 1, 1]
298 //   2) if use reducescatter: [a*c, 1, 1]
299 // dev_matrix:
300 //   1) if use allreduce: [c, a]
301 //   2) if use reducescatter: [a*c]
302 class ShardBatchAndAxisImpl : public ShardAxisImpl {
303  public:
ShardBatchAndAxisImpl(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,int64_t axis)304   ShardBatchAndAxisImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
305       : ShardAxisImpl(name, inputs_shape, outputs_shape, axis) {}
306   ~ShardBatchAndAxisImpl() override = default;
CheckStrategy(const Shape & param_strategy,const Shape & indices_strategy)307   Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override {
308     return SUCCESS;
309   }  // no need check
310   Status InferDevMatrixShape() override;
311   Status InferTensorMap() override;
InferTensorInfo()312   Status InferTensorInfo() override { return InferTensorInfoNoSplitAxis(); }  // do not need to use out_dev_matrix_shape
313   Status InferBias() override;
314 };
315 
316 class GatherInfo : public OperatorInfo {
317  public:
318   GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
319              const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
OperatorInfo(name,inputs_shape,outputs_shape,attrs,std::make_shared<GatherCost> ())320       : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherCost>()),
321         replace_op_name_(replace_op_name) {}
322   ~GatherInfo() override = default;
323   Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
324               const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
325               const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {}) override;
326   Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
327 
328   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
329   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
330   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
331   std::shared_ptr<Strategies> GenerateBatchStrategies() override;
param_split_shapes()332   const std::vector<int64_t> &param_split_shapes() const { return param_split_shapes_; }
index_offsets()333   const std::vector<int64_t> &index_offsets() const { return index_offsets_; }
334   GatherMode GetGatherMode(const Shape &param_strategy, const Shape &indices_strategy) const;
335 
336  protected:
337   Status CheckStrategy(const StrategyPtr &strategy) override;
338   Status CheckOutputStrategy(const StrategyPtr &out_strategy) override;
339   Status CheckStrategyForDynamicShape(const StrategyPtr &strategy) override;
340   Status InferMirrorOps() override;
341   Status InferForwardCommunication() override;
342   Status InferTensorInfo() override;
343   Status InferDevMatrixShape() override;
344   Status InferTensorMap() override;
345   Status GetAttrs() override;
346   virtual void DealWithBatchDimsMirrorOp() noexcept;
347   virtual void GetBatchDims() noexcept;
348   virtual GatherUtilPtr MakeManualUtil();
349   int64_t axis_ = 0;
350 
351  private:
352   int64_t batch_dims_ = 0;
353   Status GetManualSplitAttr();
354   Status GetManualSplitWithoutOffsetAttr();
355   Status ComputeReplaceOp();
356   bool ShardBatchAndAxis(const Shape &param_strategy, const Shape &indices_strategy) const;
357 
358   std::string target_ = DEVICE;
359   int64_t bias_ = 0;
360   std::string replace_op_name_ = GATHERV2;
361   bool manual_split_ = false;
362   bool dynamic_shape_indices_ = false;
363   std::vector<int64_t> param_split_shapes_;  // manual split
364   std::vector<int64_t> index_offsets_;       // manual split
365   GatherMode gather_mode_ = INVALID;
366   GatherUtilPtr gather_util_;
367 };
368 
369 class IndexSelectInfo final : public GatherInfo {
370  public:
371   IndexSelectInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
372                   const PrimitiveAttrs &attrs, const std::string &replace_op_name = INDEX_SELECT)
GatherInfo(name,inputs_shape,outputs_shape,attrs,replace_op_name)373       : GatherInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
374   ~IndexSelectInfo() override = default;
375 };
376 
377 class SparseGatherV2Info final : public GatherInfo {
378  public:
379   SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
380                      const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2)
GatherInfo(name,inputs_shape,outputs_shape,attrs,replace_op_name)381       : GatherInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
382   ~SparseGatherV2Info() override = default;
383 
384  protected:
DealWithBatchDimsMirrorOp()385   void DealWithBatchDimsMirrorOp() noexcept override {}
GetBatchDims()386   void GetBatchDims() noexcept override {}
387   GatherUtilPtr MakeManualUtil() override;
388 };
389 
390 class EmbeddingLookupInfo final : public GatherInfo {
391  public:
EmbeddingLookupInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)392   EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
393                       const PrimitiveAttrs &attrs)
394       : GatherInfo(name, inputs_shape, outputs_shape, attrs) {}
395   ~EmbeddingLookupInfo() override = default;
396 
397  protected:
DealWithBatchDimsMirrorOp()398   void DealWithBatchDimsMirrorOp() noexcept override {}
GetBatchDims()399   void GetBatchDims() noexcept override {}
400   GatherUtilPtr MakeManualUtil() override;
401 };
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_