• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_DYNAMIC_SHAPE_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_DYNAMIC_SHAPE_H_
19 
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "pipeline/pynative/grad/top_cell.h"
#include "ops/ops_func_impl/simple_infer.h"
26 
27 namespace mindspore {
28 namespace pynative {
29 
30 struct NodeInfo {
31   // Is parameter or input or op's output
32   InputType grad_type;
33   // Just op output tensor has op_index
34   size_t op_index{0};
35   // For scalar compare
36   ValuePtr value{nullptr};
37   // For Input is tuple or list
38   std::vector<NodeInfo> seq_node;
39 };
40 
41 struct AbsCompareInfo {
42   AbsCompareInfo() = default;
AbsCompareInfoAbsCompareInfo43   AbsCompareInfo(abstract::AbstractBasePtrList input_abs, abstract::AbstractBasePtr out_abs)
44       : input_abs(std::move(input_abs)), out_abs(std::move(out_abs)) {}
45   abstract::AbstractBasePtrList input_abs{};
46   abstract::AbstractBasePtr out_abs{nullptr};
47   std::vector<NodeInfo> inputs;
48 };
49 
50 struct ValueCompareInfo {
51   // ValueSimpleInfo
52   ValueSimpleInfo input_value_simple_info;
53   std::vector<NodeInfo> inputs;
54 };
55 
56 struct DynamicDetectNodeInfo {
57   explicit DynamicDetectNodeInfo(PrimitivePtr op_prim, bool is_value_compare = true)
op_primDynamicDetectNodeInfo58       : op_prim(std::move(op_prim)), is_value_compare(is_value_compare) {}
DynamicDetectNodeInfoDynamicDetectNodeInfo59   DynamicDetectNodeInfo(PrimitivePtr op_prim, abstract::AbstractBasePtrList input_abs,
60                         abstract::AbstractBasePtr out_abs)
61       : op_prim(std::move(op_prim)), abs_compare_info(std::move(input_abs), std::move(out_abs)) {}
62 
63   PrimitivePtr op_prim{nullptr};
64   bool is_value_compare{false};
65   bool is_graph_node{false};
66   std::string graph_phase;
67   AbsCompareInfo abs_compare_info;
68   ValueCompareInfo value_compare_info;
69 };
70 using DynamicDetectNodeInfoPtr = std::shared_ptr<DynamicDetectNodeInfo>;
71 using CellIdWithDynamicNodesMap =
72   mindspore::HashMap<std::string, mindspore::HashMap<std::string, std::vector<DynamicDetectNodeInfoPtr>>>;
73 
74 class NodeDynamicDetect {
75  public:
76   NodeDynamicDetect() = default;
77   ~NodeDynamicDetect() = default;
Clear()78   void Clear() { cell_id_with_dynamic_detect_nodes_.clear(); }
79   bool CheckNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
80                         const DynamicDetectNodeInfoPtr &node);
81   bool IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr &top_cell, bool use_dynamic_shape_process);
82 
83  private:
84   bool IsNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs, const DynamicDetectNodeInfoPtr &node,
85                      size_t node_idx);
86   void SaveDynamicDetectNodeInfoInFirstTime(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
87                                             const DynamicDetectNodeInfoPtr &node, size_t node_idx);
88 
89   std::mutex async_mutex_;
90   CellIdWithDynamicNodesMap cell_id_with_dynamic_detect_nodes_;
91 };
92 using NodeDynamicDetectPtr = std::shared_ptr<NodeDynamicDetect>;
93 
94 class TopCellUnknownShapeDetect {
95  public:
96   TopCellUnknownShapeDetect() = default;
97   ~TopCellUnknownShapeDetect() = default;
98 
99   void SetDynamicInput(const py::object &obj, const py::args &args);
100   void TryChangeTopCellToUnknownShape(const std::string &obj_id, const abstract::BaseShapePtrList &arg_base_shape_vec,
101                                       bool is_auto_detect);
102   void UpdateArgsAbsToUnknownShapeAbs(const py::object &obj, const py::args &args);
103 
Clear()104   void Clear() {
105     obj_with_by_inputs_.clear();
106     obj_id_args_info_by_set_inputs_.clear();
107   }
108 
109  private:
110   // pre top cell is already unknown shape, args shape is current input, check whether the requirements are met through
111   // shape comparison.
112   bool CanFindMatchedUnknownShapeTopCell(const TopCellInfoPtr &pre_top_cell,
113                                          const abstract::BaseShapePtrList &cur_args_shape);
114   bool SetTopCellUnknownShape(const TopCellInfoPtr &cur_top_cell, const TopCellInfoPtr &pre_top_cell,
115                               const abstract::BaseShapePtrList &args_shape);
116   void ChangeTopCellToUnknownShape(const TopCellInfoPtr &top_cell,
117                                    const abstract::BaseShapePtrList &args_unknown_shape);
118   void UpdateUnknownShapeAbsCache(const std::vector<string> &input_arg_id_vec,
119                                   const std::vector<ValuePtr> &input_arg_value_vec,
120                                   const std::vector<abstract::BaseShapePtr> &args_base_shape);
121 
122   // Like TrainOneStep, it is a cell and run first, top cell create first, but set inputs set in main cell
123   // and run later, so need change top cell to unknown shape too.
124   void UpdatePossibleTopCellToUnknownShape(const TopCellInfoPtr &cur_top_cell,
125                                            const std::vector<string> &cur_arg_id_vec,
126                                            const abstract::BaseShapePtrList &cur_args_shape);
127 
128   // Obj id(cell or function) with set inputs
129   mindspore::HashSet<std::string> obj_with_by_inputs_;
130   // Obj id with its args base shape
131   mindspore::HashMap<std::string, abstract::BaseShapePtrList> obj_id_args_info_by_set_inputs_;
132 };
133 using TopCellUnknownShapeDetectPtr = std::shared_ptr<TopCellUnknownShapeDetect>;
134 
135 class DynamicShape {
136  public:
DynamicShape()137   DynamicShape()
138       : top_cell_dynamic_detect_ptr_(std::make_shared<TopCellUnknownShapeDetect>()),
139         node_dynamic_detect_ptr_(std::make_shared<NodeDynamicDetect>()) {}
140   ~DynamicShape() = default;
141 
set_enable_unknown_shape(bool enable_unknown_shape)142   void set_enable_unknown_shape(bool enable_unknown_shape) { enable_unknown_shape_ = enable_unknown_shape; }
enable_unknown_shape()143   inline bool enable_unknown_shape() const { return enable_unknown_shape_; }
144   py::object GetDynamicInput(const py::object &actual_input);
145   void SaveUnknownShapeAbsFromJit(const ValuePtr &v, const AbstractBasePtr &abs, size_t index);
146 
147   // For node dynamic struct check
CheckNodeDynamic(const TopCellInfoPtr & top_cell,const ValuePtrList & inputs,const DynamicDetectNodeInfoPtr & node)148   bool CheckNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
149                         const DynamicDetectNodeInfoPtr &node) {
150     return node_dynamic_detect_ptr_->CheckNodeDynamic(top_cell, inputs, node);
151   }
IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr & top_cell,bool use_dynamic_shape_process)152   bool IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr &top_cell, bool use_dynamic_shape_process) {
153     return node_dynamic_detect_ptr_->IsNeedSaveDynamicDetectNodes(top_cell, use_dynamic_shape_process);
154   }
155 
156   // For top cell unknown shape
SetDynamicInput(const py::object & obj,const py::args & args)157   void SetDynamicInput(const py::object &obj, const py::args &args) {
158     top_cell_dynamic_detect_ptr_->SetDynamicInput(obj, args);
159   }
TryChangeTopCellToUnknownShape(const std::string & obj_id,const abstract::BaseShapePtrList & arg_base_shape_vec,bool is_auto_detect)160   void TryChangeTopCellToUnknownShape(const std::string &obj_id, const abstract::BaseShapePtrList &arg_base_shape_vec,
161                                       bool is_auto_detect) {
162     top_cell_dynamic_detect_ptr_->TryChangeTopCellToUnknownShape(obj_id, arg_base_shape_vec, is_auto_detect);
163   }
UpdateArgsAbsToUnknownShapeAbs(const py::object & obj,const py::args & args)164   void UpdateArgsAbsToUnknownShapeAbs(const py::object &obj, const py::args &args) {
165     top_cell_dynamic_detect_ptr_->UpdateArgsAbsToUnknownShapeAbs(obj, args);
166   }
167 
Clear()168   void Clear() {
169     node_dynamic_detect_ptr_->Clear();
170     top_cell_dynamic_detect_ptr_->Clear();
171   }
172 
173  private:
174   bool enable_unknown_shape_{false};
175   TopCellUnknownShapeDetectPtr top_cell_dynamic_detect_ptr_{nullptr};
176   NodeDynamicDetectPtr node_dynamic_detect_ptr_{nullptr};
177 };
178 using DynamicShapePtr = std::shared_ptr<DynamicShape>;
179 }  // namespace pynative
180 }  // namespace mindspore
181 
182 #endif  // MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_DYNAMIC_SHAPE_H_
183