/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
#include "ir/value.h"
#include "utils/hash_map.h"

namespace mindspore {
namespace parallel {
constexpr char BATCH_DIMS[] = "batch_dims";
enum GatherMode {
  BATCH = 0,
  NORMAL,
  MANUAL,
  SHARD_BATCH_AND_AXIS,
  SHARD_AXIS_0_DYNAMIC,
  SHARD_AXIS_0_STATIC,
  SHARD_AXIS_1,
  INVALID
};

class GatherUtil;
using GatherUtilPtr = std::shared_ptr<GatherUtil>;

class GatherUtil {
 public:
  GatherUtil(std::string name, Shapes inputs_shape, Shapes outputs_shape, int64_t axis)
      : name_(std::move(name)),
        inputs_shape_(std::move(inputs_shape)),
        outputs_shape_(std::move(outputs_shape)),
        axis_(axis) {
    inputs_shape_clone_ = inputs_shape_;
    outputs_shape_clone_ = outputs_shape_;
  }
  virtual ~GatherUtil() = default;
  virtual Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) = 0;
  virtual Status InferForwardCommunication() { return SUCCESS; }
  virtual Status InferTensorInfo() = 0;
  virtual Status InferDevMatrixShape() = 0;
  virtual Status InferTensorMap() = 0;
  virtual Status InferReplaceGraph(const CNodePtr &cnode) { return SUCCESS; }
  virtual Status InferReplaceOps() { return SUCCESS; }

  void set_param_strategy(const Shape &a) { param_strategy_ = a; }
  void set_indices_strategy(const Shape &a) { indices_strategy_ = a; }
  void set_gather_mode(const GatherMode &a) { gather_mode_ = a; }
  void set_inputs_divisor(const Shapes &a) { inputs_divisor_ = a; }
  void set_outputs_divisor(const Shapes &a) { outputs_divisor_ = a; }
  void set_dynamic_shape_flag(bool a) { dynamic_shape_flag_ = a; }
  GatherMode gather_mode() const { return gather_mode_; }
  Shape dev_matrix_shape() const { return dev_matrix_shape_; }
  void set_dev_matrix_shape(const Shape &a) { dev_matrix_shape_ = a; }
  TensorMaps inputs_tensor_map() const { return inputs_tensor_map_; }
  TensorMaps outputs_tensor_map() const { return outputs_tensor_map_; }
  void set_inputs_tensor_map(const TensorMaps &a) { inputs_tensor_map_ = a; }
  void set_outputs_tensor_map(const TensorMaps &a) { outputs_tensor_map_ = a; }
  std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
  std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
  ForwardOp forward_op() const { return forward_op_; }
  ForwardOp replace_op() const { return replace_op_; }
  ReplaceGraphPtr replace_graph() const { return replace_graph_; }
  bool repeated_num_in_dev_matrix_right() const { return repeated_num_in_dev_matrix_right_; }
  Shape out_dev_matrix_shape() const { return out_dev_matrix_shape_; }
  std::string GatherModeToString() const { return gather_mode_string_[gather_mode_]; }
  void DivisorsReplaceShapes();
  void ResumeShapes();

 protected:
  std::string name_;
  Shapes inputs_shape_;
  Shapes outputs_shape_;
  Shapes inputs_shape_clone_;
  Shapes outputs_shape_clone_;
  Shapes inputs_divisor_;
  Shapes outputs_divisor_;
  int64_t axis_;

  Shape param_strategy_;
  Shape indices_strategy_;
  GatherMode gather_mode_ = INVALID;
  Shape dev_matrix_shape_;
  TensorMaps inputs_tensor_map_;
  TensorMaps outputs_tensor_map_;
  std::vector<TensorInfo> inputs_tensor_info_;
  std::vector<TensorInfo> outputs_tensor_info_;
  ForwardOp forward_op_;
  ForwardOp replace_op_;
  ReplaceGraphPtr replace_graph_;

  Status InferTensorInfoNoSplitAxis();
  bool repeated_num_in_dev_matrix_right_ = true;  // only for shard axis
  Shape out_dev_matrix_shape_;                    // only for shard axis
  bool dynamic_shape_flag_ = false;

 private:
  const std::vector<std::string> gather_mode_string_ = {
    "batch", "normal", "manual", "shard_batch_and_axis", "shard_axis_0_dynamic", "shard_axis_0_static",
    "shard_axis_1", "invalid"};
};

// batch mode: batch_dims > 0
// constraint:
// 1) the axis cannot be split
// 2) out_strategy cannot be set
// param shape: [A, B, C, D, E]
// indices shape: [A, B, F, G]
// batch_dims = 2
// axis = 3
// out = gather(param, indices, axis)
// out shape: [A, B, C, F, G, E]
// parameter's strategy: [a, b, c, 1, e], indices' strategy: [a, b, f, g]
// output's strategy: [a, b, c, f, g, e]
// dev_matrix: [a, b, f, g, c, 1, e]
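// Worked example (illustrative values, not from the implementation): with 8 devices,
// param shape [4, 8, 16, 32, 64], indices shape [4, 8, 5, 6], batch_dims = 2 and axis = 3,
// parameter's strategy [2, 2, 1, 1, 2] and indices' strategy [2, 2, 1, 1] give
// out shape [4, 8, 16, 5, 6, 64], output's strategy [2, 2, 1, 1, 1, 2],
// dev_matrix [2, 2, 1, 1, 1, 1, 2].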
class BatchImpl : public GatherUtil {
 public:
  BatchImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
      : GatherUtil(name, inputs_shape, outputs_shape, axis) {}
  ~BatchImpl() override = default;
  Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override;
  Status InferDevMatrixShape() override;
  void set_batch_dims(int64_t batch_dims) { batch_dims_ = batch_dims; }
  Status InferTensorMap() override;
  Status InferTensorInfo() override { return InferTensorInfoNoSplitAxis(); }

 private:
  int64_t batch_dims_ = 0;
};

// normal mode: batch_dims = 0, and the axis is not split
// constraint:
// 1) out_strategy cannot be set
// param shape: [C, D, E]
// indices shape: [F, G]
// batch_dims = 0
// axis = 1
// out = gather(param, indices, axis)
// out shape: [C, F, G, E]
// parameter's strategy: [c, 1, e], indices' strategy: [f, g]
// output's strategy: [c, f, g, e]
// dev_matrix: [f, g, c, 1, e]
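// Worked example (illustrative values): with 8 devices, param shape [8, 16, 32],
// indices shape [4, 6] and axis = 1, parameter's strategy [2, 1, 2] and indices'
// strategy [2, 1] give out shape [8, 4, 6, 32], output's strategy [2, 2, 1, 2],
// dev_matrix [2, 1, 2, 1, 2].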
class NormalImpl : public GatherUtil {
 public:
  NormalImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
      : GatherUtil(name, inputs_shape, outputs_shape, axis) {}
  ~NormalImpl() override = default;
  Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status InferTensorInfo() override { return InferTensorInfoNoSplitAxis(); }
};

// manual mode: the primitive has the "manual_split" attr, axis = 0, batch_dims = 0
// constraint:
// 1) the field dimension of indices is the last dimension;
// 2) repeated calculation is not supported
// 3) parameter's dim >= 1, indices' dim >= 1
// 4) out_strategy cannot be set
// param shape: [A, B, ..., C]
// indices shape: [D, ..., E, F]
// batch_dims = 0
// axis = 0
// out = gather(param, indices, axis)
// out shape: [D, ..., E, F, B, ..., C]
// parameter's strategy: [a, b, ..., c], indices' strategy: [1, ..., 1, a]
// output's strategy: [1, ..., 1, a, b, ..., c]
// dev_matrix: [a, b, ..., c]
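// Worked example (illustrative values, assuming an even manual_split): with 4 devices,
// param shape [8, 16], indices shape [2, 4] and axis = 0, parameter's strategy [2, 2]
// (a = 2) and indices' strategy [1, 2] give out shape [2, 4, 16], output's strategy
// [1, 2, 2], dev_matrix [2, 2]. Here param_split_shapes_ would hold the per-slice row
// counts (e.g. {4, 4}) and index_offsets_ the matching row offsets (e.g. {0, 4}).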
class ManualImpl : public GatherUtil {
 public:
  ManualImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
      : GatherUtil(name, inputs_shape, outputs_shape, axis) {}
  ~ManualImpl() override = default;
  Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status InferTensorInfo() override;
  Status InferReplaceGraph(const CNodePtr &cnode) override;
  Status InferReplaceOps() override;

  void set_param_split_shapes(const Shape &a) { param_split_shapes_ = a; }
  void set_index_offsets(const Shape &a) { index_offsets_ = a; }
  void set_target(const std::string &a) { target_ = a; }
  void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &a) { attrs_ = a; }
  void set_replace_op_name(const std::string &a) { replace_op_name_ = a; }

 protected:
  Status InferOffset();
  std::string target_ = DEVICE;
  mindspore::HashMap<std::string, ValuePtr> attrs_;
  std::string replace_op_name_;
  int64_t index_offset_ = 0;

 private:
  Shape param_split_shapes_;
  Shape index_offsets_;
};

class GatherManualImpl : public ManualImpl {
 public:
  GatherManualImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
      : ManualImpl(name, inputs_shape, outputs_shape, axis) {}
  ~GatherManualImpl() override = default;
  Status InferReplaceGraph(const CNodePtr &cnode) override;
};

// SHARD_AXIS_0_DYNAMIC, SHARD_AXIS_0_STATIC and SHARD_AXIS_1 mode: batch_dims = 0, and the axis is split
// constraint:
// 1) parameter's dim is 1 or 2, indices' dim >= 1
// 2) indices cannot be split
// 3) axis = 0 or axis = 1
// 4) if axis = 1, repeated calculation is not supported
// 5) if axis = 0 and param_shape[1] is split, repeated calculation is not supported
// param shape: [A, B]
// indices shape: [C, D]
// batch_dims = 0
// axis = 0
// out = gather(param, indices, axis)
// out shape: [C, D, B]
// parameter's strategy: [a, b], indices' strategy: [1, 1]
// output's strategy:
// 1) if allreduce is used: [1, 1, b]
// 2) if reducescatter is used: [a, 1, b]
// dev_matrix: [a, b]
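// Worked example (illustrative values): with 4 devices, param shape [8, 16], indices
// shape [4, 6] and axis = 0, parameter's strategy [2, 2] and indices' strategy [1, 1]
// give out shape [4, 6, 16] and dev_matrix [2, 2]. Roughly, each rank owns the row
// range [bias_, bias_ + slice_size_) of the parameter (here slice_size_ = 4, tracked
// by the members below), gathers from its local slice only, and the partial results
// are combined across ranks with AllReduce, or with ReduceScatter when
// axis_split_forward_allreduce_ is false.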
class ShardAxisImpl : public GatherUtil {
 public:
  ShardAxisImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
      : GatherUtil(name, inputs_shape, outputs_shape, axis) {}
  ~ShardAxisImpl() override = default;
  Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status InferTensorInfo() override;
  virtual Status InferBias();
  void set_target(const std::string &a) { target_ = a; }
  void set_dynamic_shape_indices(bool a) { dynamic_shape_indices_ = a; }
  void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &a) { attrs_ = a; }
  void set_replace_op_name(const std::string &a) { replace_op_name_ = a; }
  void set_axis_split_forward_allreduce(bool a) { axis_split_forward_allreduce_ = a; }

  // ShardBatchAndAxisImpl and ShardAxisImpl
  Status InferForwardCommunication() override;
  Status InferReplaceOps() override;
  Status InferReplaceGraph(const CNodePtr &cnode) override;
  void set_assigned_parallel(bool is_assigned_parallel) { is_assigned_parallel_ = is_assigned_parallel; }

 protected:
  // used for split axis
  Status CheckSplitAxisStrategy(const Shape &param_strategy, const Shape &indices_strategy);
  void SetAttribute(const Shape &param_strategy);
  Status InferGroup();
  std::string target_ = DEVICE;
  std::string replace_op_name_;
  bool dynamic_shape_indices_ = false;
  bool is_assigned_parallel_ = false;
  bool axis_split_forward_allreduce_ = false;  // when the axis is split, the forward uses reducescatter by default
  int64_t repeated_calculation_num_ = 1;
  Group group_;
  mindspore::HashMap<std::string, ValuePtr> attrs_;
  int64_t bias_ = 0;
  int64_t slice_size_ = 0;
};

// shard_batch_and_axis mode: axis = 0, batch_dims = 0, and only the first dimension of the parameter and the first
// dimension of the indices are split
// constraint:
// 1) the dim of param is 2, and the dim of indices is 2;
// 2) only the first dimension of the parameter and the first dimension of the indices may be split; other dims
//    cannot be split
// 3) repeated calculation is not supported
// param shape: [A, B]
// indices shape: [C, D]
// batch_dims = 0
// axis = 0
// out = gather(param, indices, axis)
// out shape: [C, D, B]
// parameter's strategy: [a, 1], indices' strategy: [c, 1]
// output's strategy:
// 1) if allreduce is used: [c, 1, 1]
// 2) if reducescatter is used: [a*c, 1, 1]
// dev_matrix:
// 1) if allreduce is used: [c, a]
// 2) if reducescatter is used: [a*c]
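// Worked example (illustrative values): with 4 devices, param shape [8, 16] and
// indices shape [4, 6], parameter's strategy [2, 1] (a = 2) and indices' strategy
// [2, 1] (c = 2) give out shape [4, 6, 16]; with allreduce the output's strategy is
// [2, 1, 1] and dev_matrix [2, 2], with reducescatter the output's strategy is
// [4, 1, 1] and dev_matrix [4].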
class ShardBatchAndAxisImpl : public ShardAxisImpl {
 public:
  ShardBatchAndAxisImpl(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, int64_t axis)
      : ShardAxisImpl(name, inputs_shape, outputs_shape, axis) {}
  ~ShardBatchAndAxisImpl() override = default;
  Status CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) override {
    return SUCCESS;
  }  // no need to check
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status InferTensorInfo() override { return InferTensorInfoNoSplitAxis(); }  // out_dev_matrix_shape is not needed
  Status InferBias() override;
};

class GatherInfo : public OperatorInfo {
 public:
  GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherCost>()),
        replace_op_name_(replace_op_name) {}
  ~GatherInfo() override = default;
  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
              const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
              const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {}) override;
  Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;

  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
  std::shared_ptr<Strategies> GenerateBatchStrategies() override;
  const std::vector<int64_t> &param_split_shapes() const { return param_split_shapes_; }
  const std::vector<int64_t> &index_offsets() const { return index_offsets_; }
  GatherMode GetGatherMode(const Shape &param_strategy, const Shape &indices_strategy) const;

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status CheckOutputStrategy(const StrategyPtr &out_strategy) override;
  Status CheckStrategyForDynamicShape(const StrategyPtr &strategy) override;
  Status InferMirrorOps() override;
  Status InferForwardCommunication() override;
  Status InferTensorInfo() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status GetAttrs() override;
  virtual void DealWithBatchDimsMirrorOp() noexcept;
  virtual void GetBatchDims() noexcept;
  virtual GatherUtilPtr MakeManualUtil();
  int64_t axis_ = 0;

 private:
  int64_t batch_dims_ = 0;
  Status GetManualSplitAttr();
  Status GetManualSplitWithoutOffsetAttr();
  Status ComputeReplaceOp();
  bool ShardBatchAndAxis(const Shape &param_strategy, const Shape &indices_strategy) const;

  std::string target_ = DEVICE;
  int64_t bias_ = 0;
  std::string replace_op_name_ = GATHERV2;
  bool manual_split_ = false;
  bool dynamic_shape_indices_ = false;
  std::vector<int64_t> param_split_shapes_;  // manual split
  std::vector<int64_t> index_offsets_;       // manual split
  GatherMode gather_mode_ = INVALID;
  GatherUtilPtr gather_util_;
};

class IndexSelectInfo final : public GatherInfo {
 public:
  IndexSelectInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs, const std::string &replace_op_name = INDEX_SELECT)
      : GatherInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
  ~IndexSelectInfo() override = default;
};

class SparseGatherV2Info final : public GatherInfo {
 public:
  SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2)
      : GatherInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
  ~SparseGatherV2Info() override = default;

 protected:
  void DealWithBatchDimsMirrorOp() noexcept override {}
  void GetBatchDims() noexcept override {}
  GatherUtilPtr MakeManualUtil() override;
};

class EmbeddingLookupInfo final : public GatherInfo {
 public:
  EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                      const PrimitiveAttrs &attrs)
      : GatherInfo(name, inputs_shape, outputs_shape, attrs) {}
  ~EmbeddingLookupInfo() override = default;

 protected:
  void DealWithBatchDimsMirrorOp() noexcept override {}
  void GetBatchDims() noexcept override {}
  GatherUtilPtr MakeManualUtil() override;
};
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_