• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/forward/do_infer.h"
18 #include "pipeline/pynative/pynative_utils.h"
19 #include "frontend/operator/ops_front_infer_function.h"
20 #include "pybind_api/gil_scoped_long_running.h"
21 #include "include/common/profiler.h"
22 #include "ops/nn_op_name.h"
23 #include "ops/ops_frontend_func_impl.h"
24 
25 namespace mindspore {
26 namespace pynative {
27 namespace {
28 constexpr size_t kCacheThreshold = 10000;
29 
GetInferValueFromAbstract(const AbstractBasePtr & abs)30 ValuePtr GetInferValueFromAbstract(const AbstractBasePtr &abs) {
31   MS_EXCEPTION_IF_NULL(abs);
32   if (abs->isa<abstract::AbstractTensor>()) {
33     return abs->cast<abstract::AbstractTensorPtr>()->BuildValue();
34   } else if (abs->isa<abstract::AbstractSlice>()) {
35     return abs->cast<abstract::AbstractSlicePtr>()->BuildValue();
36   } else if (abs->isa<abstract::AbstractScalar>() || abs->isa<abstract::AbstractType>()) {
37     return abs->BuildValue();
38   } else if (abs->isa<abstract::AbstractTuple>()) {
39     auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
40     const auto &value = tuple_abs->BuildValue();
41     MS_EXCEPTION_IF_NULL(value);
42     if (value->isa<ValueAny>()) {
43       return value;
44     }
45     return tuple_abs->ElementsBuildValue<ValueTuple>();
46   } else if (abs->isa<abstract::AbstractList>()) {
47     auto list_abs = abs->cast<abstract::AbstractListPtr>();
48     const auto &value = list_abs->BuildValue();
49     if (value->isa<ValueAny>()) {
50       return value;
51     }
52     return list_abs->ElementsBuildValue<ValueList>();
53   } else if (abs->isa<abstract::AbstractRowTensor>()) {
54     return abs->cast<abstract::AbstractRowTensorPtr>()->BuildValue();
55   } else if (abs->isa<abstract::AbstractCOOTensor>()) {
56     return abs->cast<abstract::AbstractCOOTensorPtr>()->BuildValue();
57   } else if (abs->isa<abstract::AbstractCSRTensor>()) {
58     return abs->cast<abstract::AbstractCSRTensorPtr>()->BuildValue();
59   } else if (abs->isa<abstract::AbstractMapTensor>()) {
60     return kValueAny;
61   } else {
62     MS_LOG(DEBUG) << "Unsupported abstract type for primitive, the abs is " << abs->ToString();
63     return kValueAny;
64   }
65 }
66 
CallPyInferFunc(const PrimitivePtr & primitive,const FrontendOpRunInfoPtr & op_run_info)67 void CallPyInferFunc(const PrimitivePtr &primitive, const FrontendOpRunInfoPtr &op_run_info) {
68   const AbstractBasePtrList &arg_spec = op_run_info->op_grad_info->input_abs;
69   auto py_infer_args = PreparePyInputs(arg_spec);
70   auto prim_py = dyn_cast<PrimitivePy>(primitive);
71   MS_EXCEPTION_IF_NULL(prim_py);
72   if (primitive->prim_type() == kPrimTypePyCheck) {
73     prim_py->RunCheck(py_infer_args);
74     return;
75   }
76   auto py_infer_result = prim_py->RunInfer(py_infer_args);
77   auto abs = abstract::PyInferRes2Abstract(prim_py, py_infer_result);
78   primitive->EndRecordAddAttr();
79   op_run_info->base_op_run_info.abstract = abs;
80 }
81 
GetPyNativePrimitiveInferImpl(const PrimitivePtr & primitive)82 std::optional<abstract::StandardPrimitiveImplReg> GetPyNativePrimitiveInferImpl(const PrimitivePtr &primitive) {
83   auto iter = abstract::GetFrontendPrimitiveInferMap().find(primitive);
84   if (iter != abstract::GetFrontendPrimitiveInferMap().end()) {
85     return iter->second;
86   }
87 
88   return abstract::GetPrimitiveInferImpl(primitive);
89 }
90 }  // namespace
91 
InferByOpDef(const FrontendOpRunInfoPtr & op_run_info)92 bool InferByOpDef(const FrontendOpRunInfoPtr &op_run_info) {
93   const auto &prim = op_run_info->op_grad_info->op_prim;
94   auto frontend_func_impl = mindspore::ops::GetOpFrontendFuncImplPtr(prim->name());
95   if (frontend_func_impl) {
96     op_run_info->base_op_run_info.abstract =
97       frontend_func_impl->InferAbstract(prim, op_run_info->op_grad_info->input_abs);
98     if (op_run_info->base_op_run_info.abstract != nullptr) {
99       MS_LOG(DEBUG) << "Pynative Infer by InferAbstract, got abstract: "
100                     << op_run_info->base_op_run_info.abstract->ToString();
101       return true;
102     }
103   }
104 
105   auto op_def = mindspore::ops::GetOpDef(prim->name());
106   if (op_def) {
107     (void)op_def->func_impl_.CheckValidation(prim, op_run_info->op_grad_info->input_abs);
108     auto shape = op_def->func_impl_.InferShape(prim, op_run_info->op_grad_info->input_abs);
109     auto type = op_def->func_impl_.InferType(prim, op_run_info->op_grad_info->input_abs);
110     op_run_info->base_op_run_info.abstract = mindspore::abstract::MakeAbstract(shape, type);
111     MS_LOG(DEBUG) << "Pynative Infer by OpDef, got abstract: " << op_run_info->base_op_run_info.abstract->ToString();
112     return true;
113   }
114 
115   return false;
116 }
117 
// Infer the output abstract of the op described by `op_run_info` by trying,
// in order:
//   1) C++ OpDef-based infer (InferByOpDef),
//   2) a registered frontend/core infer implementation (with optional
//      constant folding via InferValue),
//   3) the per-primitive output-abstract cache,
//   4) the primitive's Python infer function (requires the GIL).
// Every exit path must call prim->EndRecordAddAttr() to close the
// attr-recording scope opened by BeginRecordAddAttr() below.
void InferOperation::PynativeInfer(const FrontendOpRunInfoPtr &op_run_info) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_LOG(DEBUG) << "Op " << op_run_info->base_op_run_info.op_name
                << " infer input: " << mindspore::ToString(op_run_info->op_grad_info->input_abs);
  const auto &prim = op_run_info->op_grad_info->op_prim;
  MS_EXCEPTION_IF_NULL(prim);
  // Reset any stale result from a previous run before inferring.
  op_run_info->base_op_run_info.abstract = nullptr;

  prim->BeginRecordAddAttr();

  if (InferByOpDef(op_run_info)) {
    prim->EndRecordAddAttr();
    return;
  }

  auto eval_impl = GetPyNativePrimitiveInferImpl(prim);
  if (eval_impl.has_value()) {
    // Constant-fold ops that are NOT in the white list and that provide an
    // InferValue implementation.
    if (!eval_impl->IsInWhiteList() && eval_impl->IsImplInferValue()) {
      auto value = eval_impl->InferValue(prim, op_run_info->op_grad_info->input_abs);
      if (value != nullptr && !value->isa<ValueAny>()) {
        // Folding produced a concrete value; its abstract is the result.
        op_run_info->base_op_run_info.abstract = value->ToAbstract();
        prim->EndRecordAddAttr();
        return;
      }
    }

    op_run_info->base_op_run_info.abstract =
      eval_impl->InferShapeAndType(nullptr, prim, op_run_info->op_grad_info->input_abs);
    prim->EndRecordAddAttr();
    return;
  }

  // Only cache the abstract when the primitive should call the python code.
  if (GetOutputAbstractByCache(op_run_info)) {
    prim->EndRecordAddAttr();
    return;
  }

  // Fall back to the Python infer function; the GIL must be held for it.
  py::gil_scoped_acquire acquire;
  CallPyInferFunc(prim, op_run_info);
  MS_EXCEPTION_IF_NULL(op_run_info->base_op_run_info.abstract);
  prim->EndRecordAddAttr();
}
163 
DoInfer(const FrontendOpRunInfoPtr & op_run_info)164 void InferOperation::DoInfer(const FrontendOpRunInfoPtr &op_run_info) {
165   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeInfer,
166                                      op_run_info->base_op_run_info.op_name, true);
167   if (op_run_info->is_view_op) {
168     if (op_run_info->requires_grad) {
169       SetInputAbstract(op_run_info);
170     }
171     return;
172   }
173   SetInputAbstract(op_run_info);
174   if (!op_run_info->is_view_op) {
175     InferOutputAbstract(op_run_info);
176   }
177 }
178 
SetInputAbstract(const FrontendOpRunInfoPtr & op_run_info)179 void InferOperation::SetInputAbstract(const FrontendOpRunInfoPtr &op_run_info) {
180   // Get input abstract by input value and set it to `op_run_info`.
181   MS_EXCEPTION_IF_NULL(op_run_info);
182   op_run_info->input_value_id.resize(op_run_info->input_size);
183   op_run_info->op_grad_info->input_abs.resize(op_run_info->input_size);
184   for (size_t i = 0; i < op_run_info->input_size; ++i) {
185     op_run_info->op_grad_info->input_abs[i] =
186       GetInputValueAbs(op_run_info, op_run_info->op_grad_info->input_value[i], i);
187   }
188 }
189 
GetInputValueAbs(const FrontendOpRunInfoPtr & op_run_info,const ValuePtr & input_value,size_t input_index)190 AbstractBasePtr InferOperation::GetInputValueAbs(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &input_value,
191                                                  size_t input_index) {
192   // Get tuple or list abs
193   MS_EXCEPTION_IF_NULL(input_value);
194   if (input_value->isa<tensor::BaseTensor>()) {
195     op_run_info->input_value_id[input_index] = PyNativeAlgo::Common::GetIdByValue(input_value);
196   }
197   if (input_value->isa<ValueSequence>()) {
198     const auto &tuple_value = input_value->cast<ValueSequencePtr>();
199     return GetInputTupleValueAbstract(op_run_info, tuple_value, input_index);
200   }
201   // Get non-tuple and non-list abs.
202   const auto &abs = GetAbstractByValue(input_value, input_index, op_run_info->input_value_id[input_index]);
203   MS_LOG(DEBUG) << "Get abstract of input " << input_index << " is " << abs->ToString() << ", id "
204                 << op_run_info->input_value_id[input_index];
205   return abs;
206 }
207 
GetInputTupleValueAbstract(const FrontendOpRunInfoPtr & op_run_info,const ValueSequencePtr & tuple_value,size_t input_index)208 AbstractBasePtr InferOperation::GetInputTupleValueAbstract(const FrontendOpRunInfoPtr &op_run_info,
209                                                            const ValueSequencePtr &tuple_value, size_t input_index) {
210   // Create abstract list for tuple input.
211   MS_EXCEPTION_IF_NULL(tuple_value);
212   size_t tuple_value_size = tuple_value->size();
213   abstract::AbstractBasePtrList abs_list(tuple_value_size);
214   for (size_t i = 0; i < tuple_value_size; ++i) {
215     const auto &item = tuple_value->value()[i];
216     const auto &item_id = (item->isa<tensor::BaseTensor>() ? PyNativeAlgo::Common::GetIdByValue(item) : "");
217     abs_list[i] = GetAbstractByValue(item, input_index, item_id);
218   }
219   // Create output abstract by value type.
220   AbstractBasePtr abs;
221   if (tuple_value->isa<ValueTuple>()) {
222     abs = std::make_shared<abstract::AbstractTuple>(abs_list);
223   } else {
224     abs = std::make_shared<abstract::AbstractList>(abs_list);
225   }
226   MS_LOG(DEBUG) << "Get abstract of tuple input " << input_index << " is " << abs->ToString() << ", id "
227                 << op_run_info->input_value_id[input_index];
228   return abs;
229 }
230 
GetAbstractByValue(const ValuePtr & value,size_t input_index,const std::string & input_id)231 AbstractBasePtr InferOperation::GetAbstractByValue(const ValuePtr &value, size_t input_index,
232                                                    const std::string &input_id) {
233   MS_EXCEPTION_IF_NULL(value);
234   if (value->isa<tensor::BaseTensor>()) {
235     auto cache_abs = GetNodeAbsById(input_id);
236     if (cache_abs != nullptr) {
237       MS_LOG(DEBUG) << "The abstract of input " << input_index << " hits cache.";
238       return cache_abs;
239     }
240   }
241 
242   // Get abstract by input value.
243   const auto &abs = value->ToAbstract();
244   if (value->isa<tensor::BaseTensor>()) {
245     SetNodeAbsById(input_id, PyNativeAlgo::Common::SetAbstractValueToAnyValue(abs));
246   }
247   return abs;
248 }
249 
// Infer the output abstract for the op, detect dynamic-shape output, and
// attempt constant folding: if the inferred abstract already carries a
// concrete value, that value becomes the op's real output and the op does
// not need to be executed.
void InferOperation::InferOutputAbstract(const FrontendOpRunInfoPtr &op_run_info) {
  // Step 1 : Infer output abstract.
  MS_EXCEPTION_IF_NULL(op_run_info);
  PynativeInfer(op_run_info);
  MS_EXCEPTION_IF_NULL(op_run_info->base_op_run_info.abstract);
  MS_LOG(DEBUG) << "Op " << op_run_info->base_op_run_info.op_name
                << " infer result: " << op_run_info->base_op_run_info.abstract->ToString();
  // Step 2: Check whether output shape is dynamic.
  const auto &shape = op_run_info->base_op_run_info.abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape);
  op_run_info->base_op_run_info.has_dynamic_output = shape->IsDynamic();
  MS_LOG(DEBUG) << "Op " << op_run_info->base_op_run_info.op_name << " is dynamic "
                << op_run_info->base_op_run_info.has_dynamic_output;

  // Step 3: Get infer value from output abstract.
  auto infer_value = GetInferValueFromAbstract(op_run_info->base_op_run_info.abstract);
  MS_EXCEPTION_IF_NULL(infer_value);
  if (!infer_value->ContainsValueAny()) {
    // Fully concrete value: the output is known at infer time, so skip
    // execution and do not cache the abstract.
    MS_LOG(DEBUG) << "Get output by constant folding, output is " << infer_value->ToString();
    op_run_info->output_get_by_infer_value = true;
    op_run_info->should_be_cache = false;
  } else if (op_run_info->op_grad_info->op_prim->const_prim()) {
    // Const primitive: output also comes from infer. NOTE(review): the real
    // output is replaced with an empty-string value here — confirm callers
    // expect this placeholder.
    MS_LOG(DEBUG) << "Get output by const prim.";
    op_run_info->output_get_by_infer_value = true;
    op_run_info->should_be_cache = false;
    infer_value = MakeValue("");
  } else if (op_run_info->should_be_cache) {
    // Cache output abstract, the const infer value needs to infer every step.
    SaveOutputAbstractToCache(op_run_info);
  }
  op_run_info->real_out = infer_value;
}
282 
GetOutputAbstractByCache(const FrontendOpRunInfoPtr & op_run_info) const283 bool InferOperation::GetOutputAbstractByCache(const FrontendOpRunInfoPtr &op_run_info) const {
284   MS_EXCEPTION_IF_NULL(op_run_info);
285   const auto &prim = op_run_info->op_grad_info->op_prim;
286   MS_EXCEPTION_IF_NULL(prim);
287 
288   AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
289   auto prim_iter = prim_abs_list_.find(key);
290   if (prim_iter != prim_abs_list_.end()) {
291     MS_LOG(DEBUG) << "Output abstract cache matched prim " << prim->name();
292     const auto &input_abs_map = prim_iter->second;
293     auto abs_iter = input_abs_map.find(op_run_info->op_grad_info->input_abs);
294     if (abs_iter != input_abs_map.end()) {
295       MS_EXCEPTION_IF_NULL(abs_iter->second.abs);
296       MS_LOG(DEBUG) << "From output abstract cache get output abs " << abs_iter->second.abs->ToString();
297       op_run_info->base_op_run_info.abstract = abs_iter->second.abs;
298       prim->set_evaluate_added_attrs(abs_iter->second.attrs);
299       op_run_info->should_be_cache = false;
300       return true;
301     }
302   }
303   op_run_info->should_be_cache = true;
304   return false;
305 }
306 
SaveOutputAbstractToCache(const FrontendOpRunInfoPtr & op_run_info)307 void InferOperation::SaveOutputAbstractToCache(const FrontendOpRunInfoPtr &op_run_info) {
308   MS_EXCEPTION_IF_NULL(op_run_info);
309   const auto &prim = op_run_info->op_grad_info->op_prim;
310   MS_EXCEPTION_IF_NULL(prim);
311   AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
312   auto &out = prim_abs_list_[key];
313   out[op_run_info->op_grad_info->input_abs].abs = op_run_info->base_op_run_info.abstract;
314   out[op_run_info->op_grad_info->input_abs].attrs = prim->evaluate_added_attrs();
315 }
316 
// Cache the op's output abstract keyed by its output value id; for sequence
// outputs the element abstracts are cached individually as well. Also bounds
// the cache size.
void InferOperation::SetNodeAbsCacheByValue(const FrontendOpRunInfoPtr &op_run_info) {
  SetNodeAbsById(op_run_info->out_value_id,
                 PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_run_info->base_op_run_info.abstract));
  // If value is a `value tuple` or `value list`, cache the abstract of each element value.
  if (op_run_info->real_out->isa<ValueSequence>()) {
    const auto &seq_value = op_run_info->real_out->cast<ValueSequencePtr>();
    const auto &seq_abs = op_run_info->base_op_run_info.abstract->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(seq_abs);

    const auto &value_elems = seq_value->value();
    const auto &abs_elems = seq_abs->elements();
    size_t num = value_elems.size();
    if (num != abs_elems.size()) {
      // Element counts differ; fall back to op-specific caching of selected
      // outputs (currently only BatchNorm — see SaveSpecifiedOutputToCache).
      SaveSpecifiedOutputToCache(op_run_info->base_op_run_info.op_name, value_elems, abs_elems);
      MS_LOG(DEBUG) << "The size of value " << num << " is not equal to the size of abstract " << abs_elems.size();
      return;
    }
    for (size_t i = 0; i < num; ++i) {
      SetNodeAbsById(PyNativeAlgo::Common::GetIdByValue(value_elems[i]), abs_elems[i]);
    }
  }
  // If only RunOp is called (no cell or function is running), node_abs_cache_
  // is never cleared elsewhere, so clear it once it grows past a threshold.
  if (node_abs_cache_.size() > kCacheThreshold) {
    std::unique_lock lock(abs_mutex_);
    node_abs_cache_.clear();
  }
}
345 
SaveSpecifiedOutputToCache(const std::string & op_name,const ValuePtrList & value_list,const AbstractBasePtrList & abs_list)346 void InferOperation::SaveSpecifiedOutputToCache(const std::string &op_name, const ValuePtrList &value_list,
347                                                 const AbstractBasePtrList &abs_list) {
348   if (value_list.empty() || abs_list.empty()) {
349     return;
350   }
351   // BatchNormal forward only use first output
352   if (op_name == kBatchNormOpName) {
353     SetNodeAbsById(PyNativeAlgo::Common::GetIdByValue(value_list[0]), abs_list[0]);
354   }
355 }
356 
// Cache the abstract for `id` after stripping its concrete value via
// SetAbstractValueToAnyValue, so cached entries stay shape/type-only.
void InferOperation::SetNodeAbsCacheById(const std::string &id, const abstract::AbstractBasePtr &abs) {
  SetNodeAbsById(id, PyNativeAlgo::Common::SetAbstractValueToAnyValue(abs));
}
360 
UpdateNodeAbsCacheById(const std::string & id,const abstract::AbstractBasePtr & abs)361 void InferOperation::UpdateNodeAbsCacheById(const std::string &id, const abstract::AbstractBasePtr &abs) {
362   std::unique_lock lock(abs_mutex_);
363   (void)node_abs_cache_.erase(id);
364   node_abs_cache_[id] = abs;
365 }
366 
GetNodeAbsById(const std::string & id) const367 AbstractBasePtr InferOperation::GetNodeAbsById(const std::string &id) const {
368   // GetNodeAbsById is used in NewGraph, need to release gil to avoid deadlock.
369   GilReleaseWithCheck release_gil;
370   std::shared_lock lock(abs_mutex_);
371   auto iter = node_abs_cache_.find(id);
372   if (iter == node_abs_cache_.end()) {
373     return nullptr;
374   }
375   return iter->second;
376 }
377 
// Insert or overwrite the cached abstract for `id` under an exclusive lock.
void InferOperation::SetNodeAbsById(const std::string &id, const abstract::AbstractBasePtr &abs) {
  std::unique_lock lock(abs_mutex_);
  node_abs_cache_[id] = abs;
}
382 
// Run infer for a primitive (args[0]) on a single python argument (args[1])
// and return the constant-folded result as a python object. Raises when the
// value cannot be determined at infer time.
py::object InferOperation::CallConstantFolding(const py::args &args) const {
  const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
  PyNativeAlgo::PyParser::SetPrim(op_run_info, args[0]);
  op_run_info->base_op_run_info.op_name = op_run_info->op_grad_info->op_prim->name();
  const auto &v = PyNativeAlgo::DataConvert::PyObjToValue(args[1]);
  (void)op_run_info->op_grad_info->input_abs.emplace_back(v->ToAbstract());
  PynativeInfer(op_run_info);
  auto infer_value = GetInferValueFromAbstract(op_run_info->base_op_run_info.abstract);
  // Fix: guard against a null infer value before dereferencing it, consistent
  // with InferOutputAbstract.
  MS_EXCEPTION_IF_NULL(infer_value);
  if (infer_value->ContainsValueAny()) {
    MS_LOG(EXCEPTION) << "Can not get value from abstract";
  }
  return PyNativeAlgo::DataConvert::ValueToPyObj(infer_value);
}
396 }  // namespace pynative
397 }  // namespace mindspore
398