1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_DYNAMIC_SHAPE_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_DYNAMIC_SHAPE_H_ 19 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 #include <string> 24 #include "pipeline/pynative/grad/top_cell.h" 25 #include "ops/ops_func_impl/simple_infer.h" 26 27 namespace mindspore { 28 namespace pynative { 29 30 struct NodeInfo { 31 // Is parameter or input or op's output 32 InputType grad_type; 33 // Just op output tensor has op_index 34 size_t op_index{0}; 35 // For scalar compare 36 ValuePtr value{nullptr}; 37 // For Input is tuple or list 38 std::vector<NodeInfo> seq_node; 39 }; 40 41 struct AbsCompareInfo { 42 AbsCompareInfo() = default; AbsCompareInfoAbsCompareInfo43 AbsCompareInfo(abstract::AbstractBasePtrList input_abs, abstract::AbstractBasePtr out_abs) 44 : input_abs(std::move(input_abs)), out_abs(std::move(out_abs)) {} 45 abstract::AbstractBasePtrList input_abs{}; 46 abstract::AbstractBasePtr out_abs{nullptr}; 47 std::vector<NodeInfo> inputs; 48 }; 49 50 struct ValueCompareInfo { 51 // ValueSimpleInfo 52 ValueSimpleInfo input_value_simple_info; 53 std::vector<NodeInfo> inputs; 54 }; 55 56 struct DynamicDetectNodeInfo { 57 explicit DynamicDetectNodeInfo(PrimitivePtr op_prim, bool is_value_compare = true) op_primDynamicDetectNodeInfo58 : op_prim(std::move(op_prim)), is_value_compare(is_value_compare) {} DynamicDetectNodeInfoDynamicDetectNodeInfo59 DynamicDetectNodeInfo(PrimitivePtr op_prim, abstract::AbstractBasePtrList input_abs, 60 abstract::AbstractBasePtr out_abs) 61 : op_prim(std::move(op_prim)), abs_compare_info(std::move(input_abs), std::move(out_abs)) {} 62 63 PrimitivePtr op_prim{nullptr}; 64 bool is_value_compare{false}; 65 bool is_graph_node{false}; 66 std::string graph_phase; 67 AbsCompareInfo abs_compare_info; 68 ValueCompareInfo value_compare_info; 69 }; 70 using DynamicDetectNodeInfoPtr = std::shared_ptr<DynamicDetectNodeInfo>; 71 using CellIdWithDynamicNodesMap = 72 mindspore::HashMap<std::string, mindspore::HashMap<std::string, std::vector<DynamicDetectNodeInfoPtr>>>; 73 74 class NodeDynamicDetect { 75 public: 76 NodeDynamicDetect() = default; 77 ~NodeDynamicDetect() = default; Clear()78 void Clear() { cell_id_with_dynamic_detect_nodes_.clear(); } 79 bool CheckNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs, 80 const DynamicDetectNodeInfoPtr &node); 81 bool IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr &top_cell, bool use_dynamic_shape_process); 82 83 private: 84 bool IsNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs, const DynamicDetectNodeInfoPtr &node, 85 size_t node_idx); 86 void SaveDynamicDetectNodeInfoInFirstTime(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs, 87 const DynamicDetectNodeInfoPtr &node, size_t node_idx); 88 89 std::mutex async_mutex_; 90 CellIdWithDynamicNodesMap cell_id_with_dynamic_detect_nodes_; 91 }; 92 using NodeDynamicDetectPtr = std::shared_ptr<NodeDynamicDetect>; 93 94 class TopCellUnknownShapeDetect { 95 public: 96 TopCellUnknownShapeDetect() = default; 97 ~TopCellUnknownShapeDetect() = default; 98 99 void SetDynamicInput(const py::object &obj, const py::args &args); 100 void TryChangeTopCellToUnknownShape(const std::string &obj_id, const abstract::BaseShapePtrList &arg_base_shape_vec, 101 bool is_auto_detect); 102 void UpdateArgsAbsToUnknownShapeAbs(const py::object &obj, const py::args &args); 103 Clear()104 void Clear() { 105 obj_with_by_inputs_.clear(); 106 obj_id_args_info_by_set_inputs_.clear(); 107 } 108 109 private: 110 // pre top cell is already unknown shape, args shape is current input, check whether the requirements are met through 111 // shape comparison. 112 bool CanFindMatchedUnknownShapeTopCell(const TopCellInfoPtr &pre_top_cell, 113 const abstract::BaseShapePtrList &cur_args_shape); 114 bool SetTopCellUnknownShape(const TopCellInfoPtr &cur_top_cell, const TopCellInfoPtr &pre_top_cell, 115 const abstract::BaseShapePtrList &args_shape); 116 void ChangeTopCellToUnknownShape(const TopCellInfoPtr &top_cell, 117 const abstract::BaseShapePtrList &args_unknown_shape); 118 void UpdateUnknownShapeAbsCache(const std::vector<string> &input_arg_id_vec, 119 const std::vector<ValuePtr> &input_arg_value_vec, 120 const std::vector<abstract::BaseShapePtr> &args_base_shape); 121 122 // Like TrainOneStep, it is a cell and run first, top cell create first, but set inputs set in main cell 123 // and run later, so need change top cell to unknown shape too. 124 void UpdatePossibleTopCellToUnknownShape(const TopCellInfoPtr &cur_top_cell, 125 const std::vector<string> &cur_arg_id_vec, 126 const abstract::BaseShapePtrList &cur_args_shape); 127 128 // Obj id(cell or function) with set inputs 129 mindspore::HashSet<std::string> obj_with_by_inputs_; 130 // Obj id with its args base shape 131 mindspore::HashMap<std::string, abstract::BaseShapePtrList> obj_id_args_info_by_set_inputs_; 132 }; 133 using TopCellUnknownShapeDetectPtr = std::shared_ptr<TopCellUnknownShapeDetect>; 134 135 class DynamicShape { 136 public: DynamicShape()137 DynamicShape() 138 : top_cell_dynamic_detect_ptr_(std::make_shared<TopCellUnknownShapeDetect>()), 139 node_dynamic_detect_ptr_(std::make_shared<NodeDynamicDetect>()) {} 140 ~DynamicShape() = default; 141 set_enable_unknown_shape(bool enable_unknown_shape)142 void set_enable_unknown_shape(bool enable_unknown_shape) { enable_unknown_shape_ = enable_unknown_shape; } enable_unknown_shape()143 inline bool enable_unknown_shape() const { return enable_unknown_shape_; } 144 py::object GetDynamicInput(const py::object &actual_input); 145 void SaveUnknownShapeAbsFromJit(const ValuePtr &v, const AbstractBasePtr &abs, size_t index); 146 147 // For node dynamic struct check CheckNodeDynamic(const TopCellInfoPtr & top_cell,const ValuePtrList & inputs,const DynamicDetectNodeInfoPtr & node)148 bool CheckNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs, 149 const DynamicDetectNodeInfoPtr &node) { 150 return node_dynamic_detect_ptr_->CheckNodeDynamic(top_cell, inputs, node); 151 } IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr & top_cell,bool use_dynamic_shape_process)152 bool IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr &top_cell, bool use_dynamic_shape_process) { 153 return node_dynamic_detect_ptr_->IsNeedSaveDynamicDetectNodes(top_cell, use_dynamic_shape_process); 154 } 155 156 // For top cell unknown shape SetDynamicInput(const py::object & obj,const py::args & args)157 void SetDynamicInput(const py::object &obj, const py::args &args) { 158 top_cell_dynamic_detect_ptr_->SetDynamicInput(obj, args); 159 } TryChangeTopCellToUnknownShape(const std::string & obj_id,const abstract::BaseShapePtrList & arg_base_shape_vec,bool is_auto_detect)160 void TryChangeTopCellToUnknownShape(const std::string &obj_id, const abstract::BaseShapePtrList &arg_base_shape_vec, 161 bool is_auto_detect) { 162 top_cell_dynamic_detect_ptr_->TryChangeTopCellToUnknownShape(obj_id, arg_base_shape_vec, is_auto_detect); 163 } UpdateArgsAbsToUnknownShapeAbs(const py::object & obj,const py::args & args)164 void UpdateArgsAbsToUnknownShapeAbs(const py::object &obj, const py::args &args) { 165 top_cell_dynamic_detect_ptr_->UpdateArgsAbsToUnknownShapeAbs(obj, args); 166 } 167 Clear()168 void Clear() { 169 node_dynamic_detect_ptr_->Clear(); 170 top_cell_dynamic_detect_ptr_->Clear(); 171 } 172 173 private: 174 bool enable_unknown_shape_{false}; 175 TopCellUnknownShapeDetectPtr top_cell_dynamic_detect_ptr_{nullptr}; 176 NodeDynamicDetectPtr node_dynamic_detect_ptr_{nullptr}; 177 }; 178 using DynamicShapePtr = std::shared_ptr<DynamicShape>; 179 } // namespace pynative 180 } // namespace mindspore 181 182 #endif // MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_DYNAMIC_SHAPE_H_ 183