/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_

#include <utility>
#include <vector>
#include <string>
#include <memory>

#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "ir/anf.h"
#include "pybind_api/ir/primitive_py.h"
#include "abstract/abstract_value.h"
#include "include/common/utils/stub_tensor.h"
#include "mindspore/core/utils/simple_info.h"
#include "ops/op_def.h"

namespace mindspore {
namespace pynative {
namespace py = pybind11;
constexpr size_t kDefaultContainerSize = 5000;
enum class SensType { kNormal = 0, kTuple = 1, kDict = 2 };

struct BaseOpRunInfo {
  uint64_t py_prim_id_{0};
  bool has_dynamic_output = false;
  bool is_mixed_precision_cast = false;
  bool use_dynamic_shape_process = false;
  bool need_earse_cache = false;
  size_t stream_id{kDefaultStreamIndex};
  std::string op_name;
  std::string next_op_name;
  std::string device_target = "Unknown";
#if defined(__APPLE__)
  int next_input_index = 0;
#else
  size_t next_input_index = 0;
#endif
  std::vector<ValuePtr> expanded_input_values;
  std::vector<InputType> input_types;
  AbstractBasePtr abstract;
  std::vector<size_t> output_indexes;
  std::vector<int64_t> dyn_input_sizes;
  std::vector<tensor::BaseTensorPtr> output_tensors;
};
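
// Illustrative sketch (not part of the original header): how a caller might fill
// in the launch-level fields. The variable names and values below are assumptions.
//
//   BaseOpRunInfo base_info;
//   base_info.op_name = "Add";
//   base_info.device_target = "Ascend";
//   base_info.stream_id = kDefaultStreamIndex;
//   base_info.expanded_input_values = {lhs_value, rhs_value};  // hypothetical ValuePtrs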

struct AsyncStatus {
  bool disable_mix_precision{false};
  bool is_jit_compiling{false};
  size_t custom_bprop_cell_count{0};
};

struct OpGradInfo {
  PrimitivePtr op_prim{nullptr};
  abstract::AbstractBasePtrList input_abs{};
  abstract::AbstractBasePtr out_abs{nullptr};
  std::vector<ValuePtr> input_value{};
  ValuePtr out_value{nullptr};
  std::vector<InputType> input_value_grad_type{};
  size_t output_size{0};
  bool is_need_recompute{false};
  ValueSimpleInfoPtr output_value_simple_info{nullptr};
};
using OpGradInfoPtr = std::shared_ptr<OpGradInfo>;

struct GradParam {
  GradParam(const OpGradInfoPtr &op_grad_info, bool use_dynamic_shape_process)
      : op_grad_info(op_grad_info), use_dynamic_shape_process(use_dynamic_shape_process) {
    input_size = op_grad_info->input_value.size();
  }

  OpGradInfoPtr op_grad_info;

  // Dynamic shape or dynamic structure
  bool use_dynamic_shape_process{false};

  // For other uses
  bool out_used_in_bporp_graph{false};
  bool is_control_flow{false};
  bool is_func_grad{false};
  size_t input_size{0};

  // For the jit domain
  bool has_added_v{false};
  bool is_jit_graph{false};
  bool jit_out_has_dict{false};
  bool is_jit_self_dynamic_shape{false};

  // Used by KPynativeWithFProp
  FuncGraphPtr fg{nullptr};
  // Grad func graph for jit or fg
  FuncGraphPtr source_fg{nullptr};
  // Op forward output used in the bprop graph
  std::string graph_cache_key;
  // Used for pyexecute
  CNodePtr cnode;
};

using GradParamPtr = std::shared_ptr<GradParam>;
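
// Illustrative sketch (assumption, not from the original header): constructing a
// GradParam. input_size is derived from op_grad_info->input_value in the
// constructor, so callers do not set it by hand.
//
//   auto grad_info = std::make_shared<OpGradInfo>();
//   grad_info->op_prim = prim;                    // hypothetical PrimitivePtr
//   grad_info->input_value = {input_a, input_b};  // hypothetical ValuePtrs
//   auto grad_param = std::make_shared<GradParam>(grad_info, false);
//   // grad_param->input_size == 2 after construction.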

struct FrontendOpRunInfo {
  FrontendOpRunInfo() { op_grad_info = std::make_shared<OpGradInfo>(); }
  OpGradInfoPtr op_grad_info;

  BaseOpRunInfo base_op_run_info;
  bool run_in_vm = false;
  bool requires_grad = false;
  bool output_get_by_infer_value = false;
  bool should_be_cache = false;
  bool is_jit_input = false;
  bool is_view_op = false;
  int mix_type{0};
  size_t input_size = 0;
  // none_init_inputs_num counts the inputs that are not defined in the Primitive's __init__ function
  size_t none_init_inputs_num = 0;
  // real_out is returned to Python; out_value in OpGradInfo may be a fake value
  ValuePtr real_out{nullptr};
  std::string op_info;
  std::string out_value_id;
  std::string cell_obj_id;
  // Holds the tensor grad type
  std::vector<std::string> input_value_id{};
  stub::StubNodePtr stub_output{nullptr};
  std::vector<Signature> signatures{};
  std::vector<ops::OP_DTYPE> source_type{};
  AsyncStatus async_status;
  mindspore::HashSet<size_t> input_to_attr{};
};
using FrontendOpRunInfoPtr = std::shared_ptr<FrontendOpRunInfo>;
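
// Illustrative sketch (assumption): the default constructor already allocates
// op_grad_info, so a fresh run-info can be populated directly.
//
//   auto run_info = std::make_shared<FrontendOpRunInfo>();
//   run_info->base_op_run_info.op_name = "Mul";
//   run_info->requires_grad = true;
//   run_info->op_grad_info->op_prim = prim;  // hypothetical PrimitivePtr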

struct InputArgsInfo {
  InputArgsInfo() = default;
  ~InputArgsInfo() = default;
  InputArgsInfo(bool is_grad_topest_cell, bool is_high_order_top_cell)
      : is_grad_topest_cell(is_grad_topest_cell), is_high_order_top_cell(is_high_order_top_cell) {}

  bool is_grad_topest_cell;
  bool is_high_order_top_cell;

  bool is_need_recompute{false};
  bool has_custom_bprop{false};
  SensType sens_type{SensType::kNormal};
  PrimitivePyPtr custom_bprop_prim{nullptr};
  ValuePtr out_value{nullptr};
  std::string obj_id;
  std::string cell_id;
  std::string already_run_cell_id;
  std::string input_args_id;
  size_t input_size = 0;
  std::vector<std::string> input_arg_id_vec;
  std::vector<ValuePtr> input_arg_value_vec;
  // Used for dynamic shape auto detect
  std::vector<abstract::BaseShapePtr> input_arg_base_shape_vec;

  // Free memory
  void Reset() {
    custom_bprop_prim = nullptr;
    out_value = nullptr;
    input_arg_value_vec.clear();
  }
};
using InputArgsInfoPtr = std::shared_ptr<InputArgsInfo>;
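
// Illustrative sketch (assumption): Reset() drops the owning references so the
// held values can be freed once they are no longer needed for grad bookkeeping.
//
//   auto args_info = std::make_shared<InputArgsInfo>(/*is_grad_topest_cell=*/true,
//                                                    /*is_high_order_top_cell=*/false);
//   args_info->input_arg_value_vec = {arg_value};  // hypothetical ValuePtr
//   args_info->Reset();  // releases custom_bprop_prim, out_value, and the arg values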

class FastValue {
 public:
  FastValue() = default;
  ~FastValue() = default;

  explicit FastValue(const int64_t &v) : int_value_(v), is_int_{true} {}
  explicit FastValue(std::vector<int64_t> v) : vec_value_(std::move(v)), is_int_{false} {}

  bool is_int() const { return is_int_; }
  int64_t int_value() const { return int_value_; }
  const std::vector<int64_t> &vec_value() const { return vec_value_; }

 private:
  int64_t int_value_{0};
  std::vector<int64_t> vec_value_;
  bool is_int_{false};
};
using FastValuePtr = std::shared_ptr<FastValue>;
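
// Illustrative sketch (assumption): FastValue acts as a small tagged union over a
// scalar index and an index vector; callers check is_int() before reading a side.
//
//   FastValue scalar(int64_t{3});
//   FastValue vec(std::vector<int64_t>{0, 2, 4});
//   if (scalar.is_int()) { /* use scalar.int_value() */ }
//   if (!vec.is_int()) { /* use vec.vec_value() */ }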

struct SliceOpInfo {
  SliceOpInfo() = default;
  ~SliceOpInfo() = default;
  std::string slice_op_name;
  std::vector<size_t> data_indexs;
  std::vector<FastValuePtr> slice_index_inputs;
};
using SliceOpInfoPtr = std::shared_ptr<SliceOpInfo>;
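
// Illustrative sketch (assumption): describing a StridedSlice-style call, where
// data_indexs selects the data operands and slice_index_inputs carries the index
// values. The begin/end/strides interpretation below is a guess, not documented above.
//
//   auto slice_info = std::make_shared<SliceOpInfo>();
//   slice_info->slice_op_name = "StridedSlice";
//   slice_info->data_indexs = {0};
//   slice_info->slice_index_inputs = {
//     std::make_shared<FastValue>(std::vector<int64_t>{0, 0}),   // begin
//     std::make_shared<FastValue>(std::vector<int64_t>{2, 2}),   // end
//     std::make_shared<FastValue>(std::vector<int64_t>{1, 1})};  // strides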

struct GraphCallCondition {
  GraphCallCondition(bool is_control_flow, bool is_jit_graph, bool is_dynamic_shape_process, bool jit_out_has_dict,
                     bool is_func_grad)
      : is_control_flow_(is_control_flow),
        is_jit_graph_(is_jit_graph),
        is_dynamic_shape_process_(is_dynamic_shape_process),
        jit_out_has_dict_(jit_out_has_dict),
        is_func_grad_(is_func_grad) {}

  bool is_control_flow_;
  bool is_jit_graph_;
  bool is_dynamic_shape_process_;
  bool jit_out_has_dict_;
  bool is_func_grad_;
};
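
// Illustrative sketch (assumption): the constructor takes the five flags in
// declaration order, so call sites typically annotate each argument.
//
//   GraphCallCondition call_condition(/*is_control_flow=*/false,
//                                     /*is_jit_graph=*/true,
//                                     /*is_dynamic_shape_process=*/false,
//                                     /*jit_out_has_dict=*/false,
//                                     /*is_func_grad=*/true);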
}  // namespace pynative
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_