1 /** 2 * Copyright 2020-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 25 #include "utils/hash_map.h" 26 #include "utils/hash_set.h" 27 #include "ir/anf.h" 28 #include "pybind_api/ir/primitive_py.h" 29 #include "abstract/abstract_value.h" 30 #include "include/common/utils/stub_tensor.h" 31 #include "mindspore/core/utils/simple_info.h" 32 #include "ops/op_def.h" 33 34 namespace mindspore { 35 namespace pynative { 36 namespace py = pybind11; 37 constexpr size_t kDefaultContainerSize = 5000; 38 enum class SensType { kNormal = 0, kTuple = 1, kDict = 2 }; 39 40 struct BaseOpRunInfo { 41 uint64_t py_prim_id_{0}; 42 bool has_dynamic_output = false; 43 bool is_mixed_precision_cast = false; 44 bool use_dynamic_shape_process = false; 45 bool need_earse_cache = false; 46 size_t stream_id{kDefaultStreamIndex}; 47 std::string op_name; 48 std::string next_op_name; 49 std::string device_target = "Unknown"; 50 #if defined(__APPLE__) 51 int next_input_index = 0; 52 #else 53 size_t next_input_index = 0; 54 #endif 55 std::vector<ValuePtr> expanded_input_values; 56 std::vector<InputType> input_types; 57 AbstractBasePtr abstract; 58 std::vector<size_t> output_indexes; 59 std::vector<int64_t> dyn_input_sizes; 60 std::vector<tensor::BaseTensorPtr> output_tensors; 61 }; 62 63 struct AsyncStatus { 64 bool disable_mix_precision{false}; 65 bool is_jit_compiling{false}; 66 size_t custom_bprop_cell_count{0}; 67 }; 68 69 struct OpGradInfo { 70 PrimitivePtr op_prim{nullptr}; 71 abstract::AbstractBasePtrList input_abs{}; 72 abstract::AbstractBasePtr out_abs{nullptr}; 73 std::vector<ValuePtr> input_value{}; 74 ValuePtr out_value{nullptr}; 75 std::vector<InputType> input_value_grad_type{}; 76 size_t output_size; 77 bool is_need_recompute{false}; 78 ValueSimpleInfoPtr output_value_simple_info{nullptr}; 79 }; 80 using OpGradInfoPtr = std::shared_ptr<OpGradInfo>; 81 82 struct GradParam { GradParamGradParam83 GradParam(const OpGradInfoPtr &op_grad_info, bool use_dynamic_shape_process) 84 : op_grad_info(op_grad_info), use_dynamic_shape_process(use_dynamic_shape_process) { 85 input_size = op_grad_info->input_value.size(); 86 } 87 88 OpGradInfoPtr op_grad_info; 89 90 // Dynamic shape or dynamic structure 91 bool use_dynamic_shape_process{false}; 92 93 // For other used 94 bool out_used_in_bporp_graph{false}; 95 bool is_control_flow{false}; 96 bool is_func_grad{false}; 97 size_t input_size{0}; 98 99 // For jit domain 100 bool has_added_v{false}; 101 bool is_jit_graph{false}; 102 bool jit_out_has_dict{false}; 103 bool is_jit_self_dynamic_shape{false}; 104 105 // For KPynativeWithFProp used 106 FuncGraphPtr fg{nullptr}; 107 // grad func graph for jit or fg 108 FuncGraphPtr source_fg{nullptr}; 109 // Op forward output used in bprop graph 110 std::string graph_cache_key; 111 // Used for pyexecute 112 CNodePtr cnode; 113 }; 114 115 using GradParamPtr = std::shared_ptr<GradParam>; 116 117 struct FrontendOpRunInfo { FrontendOpRunInfoFrontendOpRunInfo118 FrontendOpRunInfo() { op_grad_info = std::make_shared<OpGradInfo>(); } 119 OpGradInfoPtr op_grad_info; 120 121 BaseOpRunInfo base_op_run_info; 122 bool run_in_vm = false; 123 bool requires_grad = false; 124 bool output_get_by_infer_value = false; 125 bool should_be_cache = false; 126 bool is_jit_input = false; 127 bool is_view_op = false; 128 int mix_type{0}; 129 size_t input_size = 0; 130 // none_intit_inputs is the inputs those not defined in Primitive's __init__ function 131 size_t none_init_inputs_num = 0; 132 // real_out return to python; out_value in OpGradInfo may be fake value; 133 ValuePtr real_out{nullptr}; 134 std::string op_info; 135 std::string out_value_id; 136 std::string cell_obj_id; 137 // Hold tensorGradType 138 std::vector<std::string> input_value_id{}; 139 stub::StubNodePtr stub_output{nullptr}; 140 std::vector<Signature> signatures{}; 141 std::vector<ops::OP_DTYPE> source_type{}; 142 AsyncStatus async_status; 143 mindspore::HashSet<size_t> input_to_attr{}; 144 }; 145 using FrontendOpRunInfoPtr = std::shared_ptr<FrontendOpRunInfo>; 146 147 struct InputArgsInfo { 148 InputArgsInfo() = default; 149 ~InputArgsInfo() = default; InputArgsInfoInputArgsInfo150 InputArgsInfo(bool is_grad_topest_cell, bool is_high_order_top_cell) 151 : is_grad_topest_cell(is_grad_topest_cell), is_high_order_top_cell(is_high_order_top_cell) {} 152 153 bool is_grad_topest_cell; 154 bool is_high_order_top_cell; 155 156 bool is_need_recompute{false}; 157 bool has_custom_bprop{false}; 158 SensType sens_type{SensType::kNormal}; 159 PrimitivePyPtr custom_bprop_prim{nullptr}; 160 ValuePtr out_value{nullptr}; 161 std::string obj_id; 162 std::string cell_id; 163 std::string already_run_cell_id; 164 std::string input_args_id; 165 size_t input_size = 0; 166 std::vector<std::string> input_arg_id_vec; 167 std::vector<ValuePtr> input_arg_value_vec; 168 // Used for dynamic shape auto detect 169 std::vector<abstract::BaseShapePtr> input_arg_base_shape_vec; 170 171 // Free memory ResetInputArgsInfo172 void Reset() { 173 custom_bprop_prim = nullptr; 174 out_value = nullptr; 175 input_arg_value_vec.clear(); 176 } 177 }; 178 using InputArgsInfoPtr = std::shared_ptr<InputArgsInfo>; 179 180 class FastValue { 181 public: 182 FastValue() = default; 183 ~FastValue() = default; 184 FastValue(const int64_t & v)185 explicit FastValue(const int64_t &v) : int_value_(v), is_int_{true} {} FastValue(std::vector<int64_t> v)186 explicit FastValue(std::vector<int64_t> v) : vec_value_(std::move(v)), is_int_{false} {} 187 is_int()188 bool is_int() const { return is_int_; } int_value()189 int64_t int_value() const { return int_value_; } vec_value()190 const std::vector<int64_t> &vec_value() const { return vec_value_; } 191 192 private: 193 int64_t int_value_{0}; 194 std::vector<int64_t> vec_value_; 195 bool is_int_{false}; 196 }; 197 using FastValuePtr = std::shared_ptr<FastValue>; 198 199 struct SliceOpInfo { 200 SliceOpInfo() = default; 201 ~SliceOpInfo() = default; 202 std::string slice_op_name; 203 std::vector<size_t> data_indexs; 204 std::vector<FastValuePtr> slice_index_inputs; 205 }; 206 using SliceOpInfoPtr = std::shared_ptr<SliceOpInfo>; 207 208 struct GraphCallCondition { GraphCallConditionGraphCallCondition209 GraphCallCondition(bool is_control_flow, bool is_jit_graph, bool is_dynamic_shape_process, bool jit_out_has_dict, 210 bool is_func_grad) 211 : is_control_flow_(is_control_flow), 212 is_jit_graph_(is_jit_graph), 213 is_dynamic_shape_process_(is_dynamic_shape_process), 214 jit_out_has_dict_(jit_out_has_dict), 215 is_func_grad_(is_func_grad) {} 216 217 bool is_control_flow_; 218 bool is_jit_graph_; 219 bool is_dynamic_shape_process_; 220 bool jit_out_has_dict_; 221 bool is_func_grad_; 222 }; 223 } // namespace pynative 224 } // namespace mindspore 225 226 #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_ 227