• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_UTILS_H_
18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_UTILS_H_
19 
#include <algorithm>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/pynative_execute.h"
#include "kernel/pyboost/op_runner.h"
#include "kernel/pyboost/op_register.h"
#include "pipeline/pynative/forward/forward_task.h"
#include "pipeline/pynative/grad/function/func_builder.h"
#include "pipeline/jit/ps/parse/data_converter.h"
31 
32 namespace mindspore {
33 namespace pynative {
class PyNativeExecutor;
// Generic callback: takes a packed argument list (VectorRef) and returns packed results.
using CallBackFn = std::function<VectorRef(const VectorRef &arg_list)>;
// Kind of special constant-filled value to build (used by BuildSpecialValueGrad/BuildSpecialNode below).
enum class SpecialType { kZerosLikeType = 0, kOnesLikeType = 1 };
37 
38 namespace PyNativeAlgo {
39 // Common function
40 struct Common {
41   static AbstractBasePtr SetAbstractValueToAnyValue(const AbstractBasePtr &abs);
42   static AnfNodePtr ConvertValueSequenceToMakeTuple(const ValueNodePtr &node, const FuncGraphPtr &func_graph);
43   static std::string GetIdByValue(const ValuePtr &v);
44   static std::string GetCellId(const std::string &obj_id, const std::vector<std::string> &input_arg_id_vec,
45                                const std::vector<ValuePtr> &input_arg_value_vec);
46   static void SplitString(const std::string &str, std::vector<std::string> *id_vec);
47   static bool ValueHasDynamicShape(const ValuePtr &value);
48   static bool IsTensor(const ValuePtr &v, bool include_sequence = false);
49   static bool IsControlFlowGraph(const FuncGraphPtr &func_graph);
50   static ValuePtr FilterSensValues(const ValuePtr &value, bool dict_convert_to_tuple);
51   static tensor::BaseTensorPtr GetTensorFromParam(const AnfNodePtr &param_node);
52   static void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
53   static TypeId GetTypeFromAbstract(const abstract::AbstractBasePtr &abs);
54   static ShapeVector GetShapeFromAbstract(const abstract::AbstractBasePtr &abs);
55   static std::pair<TypePtr, TypeId> GetTypeFromValue(const ValuePtr &v);
56   static ShapeVector GetShapeFromValue(const ValuePtr &v);
57   static ValuePtr CreatOutputTensorValueByAbstract(const abstract::AbstractBasePtr &abs);
58   static void ReplaceCNodeWithValueNode(const FuncGraphPtr &bprop_graph);
59   static const std::shared_ptr<PyNativeExecutor> &GetPyNativeExecutor();
60   static void StubNodeToValue(const FrontendOpRunInfoPtr &op_run_info);
61   static tensor::BaseTensorPtr StubNodeToTensor(const ValuePtr &value);
62   static tensor::BaseTensorPtr ConvertStubNodeToTensor(const ValuePtr &v, bool need_contiguous, bool requires_grad);
63   static std::optional<tensor::BaseTensorPtr> ConvertStubNodeToTensor(const std::optional<ValuePtr> &v,
64                                                                       bool need_contiguous, bool requires_grad);
65   static ValueTuplePtr ConvertStubNodeToValueTuple(const ValueListPtr &v, bool need_contiguous, bool requires_grad);
66   static ValueTuplePtr ConvertStubNodeToValueTuple(const ValueTuplePtr &v, bool need_contiguous, bool requires_grad);
67   static std::optional<ValueTuplePtr> ConvertStubNodeToValueTuple(const std::optional<ValueTuplePtr> &v,
68                                                                   bool need_contiguous, bool requires_grad);
69   static void GetConstInputToAttr(const PrimitivePtr &op_prim, const std::string &op_name,
70                                   const std::string &device_target, bool is_dynamic_shape,
71                                   mindspore::HashSet<size_t> *input_to_attr_index);
72   static ValueNodePtr CreateValueNodeByValue(const ValuePtr &v, const abstract::AbstractBasePtr &abs = nullptr);
73   static ValuePtr CreateFakeValueWithoutDeviceAddress(const ValuePtr &value);
74   static tensor::TensorPtr CreateFakeTensorWithoutDeviceAddress(const tensor::TensorPtr &tensor);
IsParamCommon75   static inline bool IsParam(InputType grad_type) {
76     return grad_type == InputType::kParameter || grad_type == InputType::kInput;
77   }
IsParamRequiresGradCommon78   static inline bool IsParamRequiresGrad(const tensor::BaseTensorPtr &tensor) {
79     return tensor->param_info() != nullptr && tensor->param_info()->requires_grad();
80   }
81   static void ClearDeviceAddress(const ValuePtr &value);
IsConstantCommon82   static inline bool IsConstant(InputType grad_type) { return grad_type == InputType::kConstant; }
83   static InputType SetValueGradInfo(const ValuePtr &value, const TopCellInfoPtr &top_cell, InputType grad_type);
84   static InputType SetTensorGradInfo(const tensor::BaseTensorPtr &tensor, const TopCellInfoPtr &top_cell);
85   static void SetGraphInputAndWeightsInfo(const FrontendOpRunInfoPtr &op_run_info, const FuncGraphPtr &func_graph,
86                                           const TopCellInfoPtr &top_cell);
87   static void ProcessTupleParam(const FuncGraphPtr &bprop_graph, size_t position);
88   static void ProcessDictParam(const FuncGraphPtr &bprop_graph, size_t position);
89   static void FreeFuncGraphForwardNodes(const FuncGraphPtr &func_graph);
90   static tensor::BaseTensorPtr ConvertToContiguousTensor(const tensor::BaseTensorPtr &tensor, bool requires_grad);
91   static ValuePtr ConvertToContiguousValue(const ValuePtr &v, bool requires_grad);
92   static size_t GetValueSize(const ValuePtr &v);
93   static ValuePtr CreateTensorByConstantValue(const ValuePtr &value);
94 
95   template <typename T>
96   static std::string PrintDebugInfo(std::vector<T> items, const std::string &info_header = "",
97                                     bool is_print_tensor_data = false) {
98     static constexpr size_t end_char_size = 2;
99     std::ostringstream buf;
100     buf << info_header;
101     for (size_t i = 0; i < items.size(); ++i) {
102       if (items[i] == nullptr) {
103         MS_LOG(DEBUG) << "The " << i << "'th item is nullptr!";
104         continue;
105       }
106       if (items[i]->template isa<tensor::BaseTensor>() && is_print_tensor_data) {
107         auto tensor = items[i]->template cast<tensor::BaseTensorPtr>();
108         auto grad = std::make_shared<tensor::Tensor>(*tensor);
109         grad->data_sync();
110         buf << i << "th: "
111             << "ptr " << items[i].get() << ", " << grad->ToStringRepr() << ", ";
112       } else {
113         buf << i << "th: "
114             << "ptr " << items[i].get() << ", " << items[i]->ToString() << ", ";
115       }
116     }
117     return buf.str().erase(buf.str().size() - end_char_size);
118   }
119 };
120 
121 // Parser python
122 struct PyParser {
123   static std::string GetIdByPyObj(const py::object &obj);
124   static std::pair<std::vector<std::string>, std::vector<ValuePtr>> GetArgsIdAndValue(const py::args &args);
125   static void SetPrim(const FrontendOpRunInfoPtr &op_run_info, const py::object &prim_arg);
126   static void ParseOpInputByPythonObj(const FrontendOpRunInfoPtr &op_run_info, const py::list &op_inputs,
127                                       bool stub = false);
128   static std::string BuilidPyInputTypeString(const py::object &obj);
129 
IsSupportTensorCastPyParser130   static inline bool IsSupportTensorCast(const std::vector<ops::OP_DTYPE> &cast_types) {
131     for (const auto &type : cast_types) {
132       if (type == ops::DT_TENSOR) {
133         return true;
134       }
135     }
136     return false;
137   }
138   static void PrintTypeCastError(const ops::OpDefPtr &op_def, const py::list &op_inputs, size_t idx);
139 };
140 
// Data convert: conversions between Python objects, BaseRef/VectorRef and
// front-end Values, plus flattening/marking of op inputs for grad recording.
struct DataConvert {
  // Python <-> Value round-trip.
  static py::object ValueToPyObj(const ValuePtr &v);
  static ValuePtr PyObjToValue(const py::object &obj, bool stub = false);
  // Backend ref containers -> Value (is_out_sequence controls tuple wrapping).
  static ValuePtr BaseRefToValue(const BaseRef &value, bool requires_grad, bool is_out_sequence);
  static ValuePtr VectorRefToValue(const VectorRef &vec_ref, bool requires_grad, bool is_out_sequence);
  // Flattening helpers: expand nested sequences into a flat ValuePtr vector.
  static void FlattenValueSeqArg(const ValuePtr &v, bool is_only_flatten_tensor_seq, bool is_filter_tensor,
                                 std::vector<ValuePtr> *flatten_v);
  static void FlattenArgs(const std::vector<ValuePtr> &v_vec, std::vector<ValuePtr> *flatten_v, bool has_sens);
  static ValuePtrList FlattenTensorSeqInValue(const ValuePtr &v);
  static ValuePtrList FlattenTensorSeqInValueSeq(const ValuePtrList &v, bool only_flatten_tensor = true);
  // Input preparation for a single op run (index refers to the op input slot).
  static void GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const TopCellInfoPtr &top_cell);
  static void ConvertCSRTensorToTensorList(const FrontendOpRunInfoPtr &op_run_info,
                                           const tensor::CSRTensorPtr &csr_tensor, const TopCellInfoPtr &top_cell,
                                           size_t index);
  static void ConvertMapTensor(const FrontendOpRunInfoPtr &op_run_info, const tensor::MapTensorPtr &map_tensor,
                               const TopCellInfoPtr &top_cell, size_t index);
  static ValuePtr ConvertValueDictToValueTuple(const ValuePtr &v);
  static void PlantTensorTupleToVector(const FrontendOpRunInfoPtr &op_run_info, const ValueSequencePtr &value_seq,
                                       size_t index, const TopCellInfoPtr &top_cell);
  static void ConvertValueTensorId(const ValuePtr &value, std::vector<std::string> *converted_tensor_id);
  static void ConvertTupleValueToTensor(const FrontendOpRunInfoPtr &op_run_info, const ValueSequencePtr &value_seq,
                                        size_t index, const TopCellInfoPtr &top_cell);
  static void MarkInputs(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v, size_t index,
                         const TopCellInfoPtr &top_cell);
  // Returns whether the const input at input_index was folded into an attribute.
  static bool RunOpConvertConstInputToAttr(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v,
                                           size_t input_index);
};
169 
170 struct PyBoost {
171   static FrontendOpRunInfoPtr Init(const PrimitivePtr &prim, const py::list &args);
172   static void MakeOutputValue(const FrontendOpRunInfoPtr &op_run_info, const kernel::pyboost::OpPtr &op);
173   static void DoGrad(const kernel::pyboost::OpPtr &op, const FrontendOpRunInfoPtr &op_run_info,
174                      ValuePtrList &&op_inputs);
175   static void SetAnyValueForAbstract(const kernel::pyboost::OpPtr &op);
176   static void UpdateStubOutput(const FrontendOpRunInfoPtr &op_run_info, const AbstractBasePtr &abstract,
177                                const kernel::pyboost::OpPtr &op);
178   static void UpdateOpRunInfo(const kernel::pyboost::OpPtr &op, const FrontendOpRunInfoPtr &op_run_info);
179   static PrimitivePtr ConvertPrimitive(const py::object &obj);
180   static py::object RunPyFunction(const PrimitivePtr &prim, const py::list &args);
181   template <typename T>
OptionalToValuePyBoost182   static ValuePtr OptionalToValue(const std::optional<T> &val) {
183     if (!val.has_value()) {
184       return kNone;
185     }
186     return val.value();
187   }
188 
189   template <typename Tuple, size_t... N>
TupleToVectorPyBoost190   static std::vector<ValuePtr> TupleToVector(const Tuple &tuple, std::index_sequence<N...>) {
191     std::vector<ValuePtr> inputs;
192     ((void)inputs.emplace_back(OptionalToValue(std::get<N>(tuple))), ...);
193     return inputs;
194   }
195 
196   template <typename T>
OptionalToValuePyBoost197   static T OptionalToValue(const T &val) {
198     return val;
199   }
200 
201   template <size_t N, typename... T>
SetPyBoostCastForInputsPyBoost202   static auto SetPyBoostCastForInputs(const FrontendOpRunInfoPtr &op_run_info,
203                                       const std::vector<std::vector<size_t>> &same_type_table, T... t) {
204     MS_EXCEPTION_IF_NULL(op_run_info);
205     op_run_info->input_size = sizeof...(t);
206     if (op_run_info->op_grad_info->op_prim->name() == kCast) {
207       return std::make_tuple(t...);
208     }
209     const auto &pyboost_cast_operation = Common::GetPyNativeExecutor()->forward_executor()->pyboost_cast_operation();
210     const auto &ret = pyboost_cast_operation->DoMixPrecisionCast(op_run_info, t...);
211     if constexpr (N != 0) {
212       return pyboost_cast_operation->DoImplicitCast<N>(op_run_info, same_type_table, ret);
213     }
214     return ret;
215   }
216   static void DataSyncForGraph(const kernel::pyboost::OpPtr &op, ValuePtrList &&op_inputs);
217 };
218 
// Used for auto grad, like func_grad and ir grad: predicates for whether grad
// is needed, builders for special grad nodes/values and grad meta-data setters.
struct AutoGrad {
  // Whether the primitive participates in gradient computation at all.
  static bool IsPrimNeedGrad(const PrimitivePtr &prim);
  // Whether any of the given input values requires a gradient.
  static bool NeedGrad(const std::vector<ValuePtr> &input_values);
  static bool IsZerosLikeNode(const AnfNodePtr &node);
  static ValuePtr GetFakeZeroTensor();
  // Build a zeros-like / ones-like grad value or node (see SpecialType).
  static ValuePtr BuildSpecialValueGrad(const ValuePtr &value, const tensor::BaseTensorPtr &grad,
                                        autograd::FuncBuilder *func_builder, const SpecialType &type);
  static AnfNodePtr BuildSpecialNode(const KernelGraphPtr &tape, const ValuePtr &value,
                                     const abstract::AbstractBasePtr &abs, const SpecialType &type);
  static AnfNodePtr BuildSparseTensorNode(const KernelGraphPtr &tape, const ValuePtr &sparse_value,
                                          const AnfNodePtr &dout_value_node);
  static void SetGradMetaData(const ValuePtr &value, const VariablePtr &variable, const ParameterPtr &param = nullptr);
  static void SetGradInfoForInputs(const ValuePtr &value, const VariablePtr &variable,
                                   const ParameterPtr &param = nullptr);

  // Create fake bprop
  static void BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
  // Wraps call_graph into a callback; cache_key allows reuse of compiled graphs.
  static CallBackFn CreateGraphCallBack(const FuncGraphPtr &call_graph, const std::string &cache_key,
                                        const GraphCallCondition &graph_call_condition);
  static PrimitivePyPtr BuildBpropCutPrim(const PrimitivePtr &prim, bool is_need_recompute = false);
  static void CheckRecomputeInputs(const GradParamPtr &grad_param);
  static void ClearAutoGradStaticCache();
  static void CheckAndSetAbstract(const OpGradInfoPtr &op_grad_info);
  static void CacheOutputAbstract(const ValuePtr &v, const abstract::AbstractBasePtr &abs);
};
245 
// Some common functions used in both jit and PackFunc grad
struct GradCommon {
  static bool IsRealOp(const AnfNodePtr &cnode);
  // Collect the CNodes from the bprop graph that are actually used, excluding
  // the input slots listed in unused_inputs.
  static void GetUsedCNodeInBpropGraph(const CNodePtr &cnode, const mindspore::HashSet<size_t> &unused_inputs,
                                       AnfNodePtrList *node_list);
  static void SetForward(const AnfNodePtrList &node_list);
};
253 };  // namespace PyNativeAlgo
254 
// Dispatch an asynchronous frontend task (runtime::AsyncTask) for execution.
void DispatchOp(const std::shared_ptr<runtime::AsyncTask> &task);
256 }  // namespace pynative
257 }  // namespace mindspore
258 #endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_UTILS_H_
259