1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/pynative/forward/do_infer.h"
18 #include "pipeline/pynative/pynative_utils.h"
19 #include "frontend/operator/ops_front_infer_function.h"
20 #include "pybind_api/gil_scoped_long_running.h"
21 #include "include/common/profiler.h"
22 #include "ops/nn_op_name.h"
23 #include "ops/ops_frontend_func_impl.h"
24
25 namespace mindspore {
26 namespace pynative {
27 namespace {
28 constexpr size_t kCacheThreshold = 10000;
29
GetInferValueFromAbstract(const AbstractBasePtr & abs)30 ValuePtr GetInferValueFromAbstract(const AbstractBasePtr &abs) {
31 MS_EXCEPTION_IF_NULL(abs);
32 if (abs->isa<abstract::AbstractTensor>()) {
33 return abs->cast<abstract::AbstractTensorPtr>()->BuildValue();
34 } else if (abs->isa<abstract::AbstractSlice>()) {
35 return abs->cast<abstract::AbstractSlicePtr>()->BuildValue();
36 } else if (abs->isa<abstract::AbstractScalar>() || abs->isa<abstract::AbstractType>()) {
37 return abs->BuildValue();
38 } else if (abs->isa<abstract::AbstractTuple>()) {
39 auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
40 const auto &value = tuple_abs->BuildValue();
41 MS_EXCEPTION_IF_NULL(value);
42 if (value->isa<ValueAny>()) {
43 return value;
44 }
45 return tuple_abs->ElementsBuildValue<ValueTuple>();
46 } else if (abs->isa<abstract::AbstractList>()) {
47 auto list_abs = abs->cast<abstract::AbstractListPtr>();
48 const auto &value = list_abs->BuildValue();
49 if (value->isa<ValueAny>()) {
50 return value;
51 }
52 return list_abs->ElementsBuildValue<ValueList>();
53 } else if (abs->isa<abstract::AbstractRowTensor>()) {
54 return abs->cast<abstract::AbstractRowTensorPtr>()->BuildValue();
55 } else if (abs->isa<abstract::AbstractCOOTensor>()) {
56 return abs->cast<abstract::AbstractCOOTensorPtr>()->BuildValue();
57 } else if (abs->isa<abstract::AbstractCSRTensor>()) {
58 return abs->cast<abstract::AbstractCSRTensorPtr>()->BuildValue();
59 } else if (abs->isa<abstract::AbstractMapTensor>()) {
60 return kValueAny;
61 } else {
62 MS_LOG(DEBUG) << "Unsupported abstract type for primitive, the abs is " << abs->ToString();
63 return kValueAny;
64 }
65 }
66
CallPyInferFunc(const PrimitivePtr & primitive,const FrontendOpRunInfoPtr & op_run_info)67 void CallPyInferFunc(const PrimitivePtr &primitive, const FrontendOpRunInfoPtr &op_run_info) {
68 const AbstractBasePtrList &arg_spec = op_run_info->op_grad_info->input_abs;
69 auto py_infer_args = PreparePyInputs(arg_spec);
70 auto prim_py = dyn_cast<PrimitivePy>(primitive);
71 MS_EXCEPTION_IF_NULL(prim_py);
72 if (primitive->prim_type() == kPrimTypePyCheck) {
73 prim_py->RunCheck(py_infer_args);
74 return;
75 }
76 auto py_infer_result = prim_py->RunInfer(py_infer_args);
77 auto abs = abstract::PyInferRes2Abstract(prim_py, py_infer_result);
78 primitive->EndRecordAddAttr();
79 op_run_info->base_op_run_info.abstract = abs;
80 }
81
GetPyNativePrimitiveInferImpl(const PrimitivePtr & primitive)82 std::optional<abstract::StandardPrimitiveImplReg> GetPyNativePrimitiveInferImpl(const PrimitivePtr &primitive) {
83 auto iter = abstract::GetFrontendPrimitiveInferMap().find(primitive);
84 if (iter != abstract::GetFrontendPrimitiveInferMap().end()) {
85 return iter->second;
86 }
87
88 return abstract::GetPrimitiveInferImpl(primitive);
89 }
90 } // namespace
91
InferByOpDef(const FrontendOpRunInfoPtr & op_run_info)92 bool InferByOpDef(const FrontendOpRunInfoPtr &op_run_info) {
93 const auto &prim = op_run_info->op_grad_info->op_prim;
94 auto frontend_func_impl = mindspore::ops::GetOpFrontendFuncImplPtr(prim->name());
95 if (frontend_func_impl) {
96 op_run_info->base_op_run_info.abstract =
97 frontend_func_impl->InferAbstract(prim, op_run_info->op_grad_info->input_abs);
98 if (op_run_info->base_op_run_info.abstract != nullptr) {
99 MS_LOG(DEBUG) << "Pynative Infer by InferAbstract, got abstract: "
100 << op_run_info->base_op_run_info.abstract->ToString();
101 return true;
102 }
103 }
104
105 auto op_def = mindspore::ops::GetOpDef(prim->name());
106 if (op_def) {
107 (void)op_def->func_impl_.CheckValidation(prim, op_run_info->op_grad_info->input_abs);
108 auto shape = op_def->func_impl_.InferShape(prim, op_run_info->op_grad_info->input_abs);
109 auto type = op_def->func_impl_.InferType(prim, op_run_info->op_grad_info->input_abs);
110 op_run_info->base_op_run_info.abstract = mindspore::abstract::MakeAbstract(shape, type);
111 MS_LOG(DEBUG) << "Pynative Infer by OpDef, got abstract: " << op_run_info->base_op_run_info.abstract->ToString();
112 return true;
113 }
114
115 return false;
116 }
117
// Infer the output abstract of `op_run_info` by trying, in order:
//   1. the op definition (InferByOpDef),
//   2. a registered C++ infer implementation (optionally constant folding),
//   3. the per-primitive output-abstract cache,
//   4. the python infer function (requires the GIL).
// Attribute recording (BeginRecordAddAttr/EndRecordAddAttr) brackets every
// path so evaluator-added attrs are captured regardless of which path wins.
void InferOperation::PynativeInfer(const FrontendOpRunInfoPtr &op_run_info) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_LOG(DEBUG) << "Op " << op_run_info->base_op_run_info.op_name
                << " infer input: " << mindspore::ToString(op_run_info->op_grad_info->input_abs);
  const auto &prim = op_run_info->op_grad_info->op_prim;
  MS_EXCEPTION_IF_NULL(prim);
  // Clear any stale abstract from a previous run before inferring.
  op_run_info->base_op_run_info.abstract = nullptr;

  prim->BeginRecordAddAttr();

  if (InferByOpDef(op_run_info)) {
    prim->EndRecordAddAttr();
    return;
  }

  auto eval_impl = GetPyNativePrimitiveInferImpl(prim);
  if (eval_impl.has_value()) {
    // the WhileList ops should be constant fold in Pynative mode.
    if (!eval_impl->IsInWhiteList() && eval_impl->IsImplInferValue()) {
      auto value = eval_impl->InferValue(prim, op_run_info->op_grad_info->input_abs);
      // Only use the folded value when it is concrete (not ValueAny).
      if (value != nullptr && !value->isa<ValueAny>()) {
        op_run_info->base_op_run_info.abstract = value->ToAbstract();
        prim->EndRecordAddAttr();
        return;
      }
    }

    op_run_info->base_op_run_info.abstract =
      eval_impl->InferShapeAndType(nullptr, prim, op_run_info->op_grad_info->input_abs);
    prim->EndRecordAddAttr();
    return;
  }

  // Only cache the abstract when the primitive should call the python code.
  if (GetOutputAbstractByCache(op_run_info)) {
    prim->EndRecordAddAttr();
    return;
  }

  // call python infer
  py::gil_scoped_acquire acquire;
  CallPyInferFunc(prim, op_run_info);
  MS_EXCEPTION_IF_NULL(op_run_info->base_op_run_info.abstract);
  prim->EndRecordAddAttr();
}
163
DoInfer(const FrontendOpRunInfoPtr & op_run_info)164 void InferOperation::DoInfer(const FrontendOpRunInfoPtr &op_run_info) {
165 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeInfer,
166 op_run_info->base_op_run_info.op_name, true);
167 if (op_run_info->is_view_op) {
168 if (op_run_info->requires_grad) {
169 SetInputAbstract(op_run_info);
170 }
171 return;
172 }
173 SetInputAbstract(op_run_info);
174 if (!op_run_info->is_view_op) {
175 InferOutputAbstract(op_run_info);
176 }
177 }
178
SetInputAbstract(const FrontendOpRunInfoPtr & op_run_info)179 void InferOperation::SetInputAbstract(const FrontendOpRunInfoPtr &op_run_info) {
180 // Get input abstract by input value and set it to `op_run_info`.
181 MS_EXCEPTION_IF_NULL(op_run_info);
182 op_run_info->input_value_id.resize(op_run_info->input_size);
183 op_run_info->op_grad_info->input_abs.resize(op_run_info->input_size);
184 for (size_t i = 0; i < op_run_info->input_size; ++i) {
185 op_run_info->op_grad_info->input_abs[i] =
186 GetInputValueAbs(op_run_info, op_run_info->op_grad_info->input_value[i], i);
187 }
188 }
189
GetInputValueAbs(const FrontendOpRunInfoPtr & op_run_info,const ValuePtr & input_value,size_t input_index)190 AbstractBasePtr InferOperation::GetInputValueAbs(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &input_value,
191 size_t input_index) {
192 // Get tuple or list abs
193 MS_EXCEPTION_IF_NULL(input_value);
194 if (input_value->isa<tensor::BaseTensor>()) {
195 op_run_info->input_value_id[input_index] = PyNativeAlgo::Common::GetIdByValue(input_value);
196 }
197 if (input_value->isa<ValueSequence>()) {
198 const auto &tuple_value = input_value->cast<ValueSequencePtr>();
199 return GetInputTupleValueAbstract(op_run_info, tuple_value, input_index);
200 }
201 // Get non-tuple and non-list abs.
202 const auto &abs = GetAbstractByValue(input_value, input_index, op_run_info->input_value_id[input_index]);
203 MS_LOG(DEBUG) << "Get abstract of input " << input_index << " is " << abs->ToString() << ", id "
204 << op_run_info->input_value_id[input_index];
205 return abs;
206 }
207
GetInputTupleValueAbstract(const FrontendOpRunInfoPtr & op_run_info,const ValueSequencePtr & tuple_value,size_t input_index)208 AbstractBasePtr InferOperation::GetInputTupleValueAbstract(const FrontendOpRunInfoPtr &op_run_info,
209 const ValueSequencePtr &tuple_value, size_t input_index) {
210 // Create abstract list for tuple input.
211 MS_EXCEPTION_IF_NULL(tuple_value);
212 size_t tuple_value_size = tuple_value->size();
213 abstract::AbstractBasePtrList abs_list(tuple_value_size);
214 for (size_t i = 0; i < tuple_value_size; ++i) {
215 const auto &item = tuple_value->value()[i];
216 const auto &item_id = (item->isa<tensor::BaseTensor>() ? PyNativeAlgo::Common::GetIdByValue(item) : "");
217 abs_list[i] = GetAbstractByValue(item, input_index, item_id);
218 }
219 // Create output abstract by value type.
220 AbstractBasePtr abs;
221 if (tuple_value->isa<ValueTuple>()) {
222 abs = std::make_shared<abstract::AbstractTuple>(abs_list);
223 } else {
224 abs = std::make_shared<abstract::AbstractList>(abs_list);
225 }
226 MS_LOG(DEBUG) << "Get abstract of tuple input " << input_index << " is " << abs->ToString() << ", id "
227 << op_run_info->input_value_id[input_index];
228 return abs;
229 }
230
GetAbstractByValue(const ValuePtr & value,size_t input_index,const std::string & input_id)231 AbstractBasePtr InferOperation::GetAbstractByValue(const ValuePtr &value, size_t input_index,
232 const std::string &input_id) {
233 MS_EXCEPTION_IF_NULL(value);
234 if (value->isa<tensor::BaseTensor>()) {
235 auto cache_abs = GetNodeAbsById(input_id);
236 if (cache_abs != nullptr) {
237 MS_LOG(DEBUG) << "The abstract of input " << input_index << " hits cache.";
238 return cache_abs;
239 }
240 }
241
242 // Get abstract by input value.
243 const auto &abs = value->ToAbstract();
244 if (value->isa<tensor::BaseTensor>()) {
245 SetNodeAbsById(input_id, PyNativeAlgo::Common::SetAbstractValueToAnyValue(abs));
246 }
247 return abs;
248 }
249
// Infer the output abstract for the op, record whether the output shape is
// dynamic, and try to constant-fold the output from the inferred abstract.
// On success `real_out` holds the folded value and no kernel launch is needed
// (`output_get_by_infer_value` is set); otherwise the abstract may be cached.
void InferOperation::InferOutputAbstract(const FrontendOpRunInfoPtr &op_run_info) {
  // Step 1 : Infer output abstract.
  MS_EXCEPTION_IF_NULL(op_run_info);
  PynativeInfer(op_run_info);
  MS_EXCEPTION_IF_NULL(op_run_info->base_op_run_info.abstract);
  MS_LOG(DEBUG) << "Op " << op_run_info->base_op_run_info.op_name
                << " infer result: " << op_run_info->base_op_run_info.abstract->ToString();
  // Step 2: Check whether output shape is dynamic.
  const auto &shape = op_run_info->base_op_run_info.abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape);
  op_run_info->base_op_run_info.has_dynamic_output = shape->IsDynamic();
  MS_LOG(DEBUG) << "Op " << op_run_info->base_op_run_info.op_name << " is dynamic "
                << op_run_info->base_op_run_info.has_dynamic_output;

  // Step 3: Get infer value from output abstract.
  auto infer_value = GetInferValueFromAbstract(op_run_info->base_op_run_info.abstract);
  MS_EXCEPTION_IF_NULL(infer_value);
  if (!infer_value->ContainsValueAny()) {
    // Fully concrete value: the output is known without running the kernel.
    MS_LOG(DEBUG) << "Get output by constant folding, output is " << infer_value->ToString();
    op_run_info->output_get_by_infer_value = true;
    op_run_info->should_be_cache = false;
  } else if (op_run_info->op_grad_info->op_prim->const_prim()) {
    // Const prims produce no runtime output; a placeholder empty-string value
    // is used as the result.
    MS_LOG(DEBUG) << "Get output by const prim.";
    op_run_info->output_get_by_infer_value = true;
    op_run_info->should_be_cache = false;
    infer_value = MakeValue("");
  } else if (op_run_info->should_be_cache) {
    // Cache output abstract, the const infer value needs to infer every step.
    SaveOutputAbstractToCache(op_run_info);
  }
  op_run_info->real_out = infer_value;
}
282
GetOutputAbstractByCache(const FrontendOpRunInfoPtr & op_run_info) const283 bool InferOperation::GetOutputAbstractByCache(const FrontendOpRunInfoPtr &op_run_info) const {
284 MS_EXCEPTION_IF_NULL(op_run_info);
285 const auto &prim = op_run_info->op_grad_info->op_prim;
286 MS_EXCEPTION_IF_NULL(prim);
287
288 AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
289 auto prim_iter = prim_abs_list_.find(key);
290 if (prim_iter != prim_abs_list_.end()) {
291 MS_LOG(DEBUG) << "Output abstract cache matched prim " << prim->name();
292 const auto &input_abs_map = prim_iter->second;
293 auto abs_iter = input_abs_map.find(op_run_info->op_grad_info->input_abs);
294 if (abs_iter != input_abs_map.end()) {
295 MS_EXCEPTION_IF_NULL(abs_iter->second.abs);
296 MS_LOG(DEBUG) << "From output abstract cache get output abs " << abs_iter->second.abs->ToString();
297 op_run_info->base_op_run_info.abstract = abs_iter->second.abs;
298 prim->set_evaluate_added_attrs(abs_iter->second.attrs);
299 op_run_info->should_be_cache = false;
300 return true;
301 }
302 }
303 op_run_info->should_be_cache = true;
304 return false;
305 }
306
SaveOutputAbstractToCache(const FrontendOpRunInfoPtr & op_run_info)307 void InferOperation::SaveOutputAbstractToCache(const FrontendOpRunInfoPtr &op_run_info) {
308 MS_EXCEPTION_IF_NULL(op_run_info);
309 const auto &prim = op_run_info->op_grad_info->op_prim;
310 MS_EXCEPTION_IF_NULL(prim);
311 AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
312 auto &out = prim_abs_list_[key];
313 out[op_run_info->op_grad_info->input_abs].abs = op_run_info->base_op_run_info.abstract;
314 out[op_run_info->op_grad_info->input_abs].attrs = prim->evaluate_added_attrs();
315 }
316
SetNodeAbsCacheByValue(const FrontendOpRunInfoPtr & op_run_info)317 void InferOperation::SetNodeAbsCacheByValue(const FrontendOpRunInfoPtr &op_run_info) {
318 SetNodeAbsById(op_run_info->out_value_id,
319 PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_run_info->base_op_run_info.abstract));
320 // If value is a `value tuple` or `value list`, cache the abstract of each element value.
321 if (op_run_info->real_out->isa<ValueSequence>()) {
322 const auto &seq_value = op_run_info->real_out->cast<ValueSequencePtr>();
323 const auto &seq_abs = op_run_info->base_op_run_info.abstract->cast<abstract::AbstractSequencePtr>();
324 MS_EXCEPTION_IF_NULL(seq_abs);
325
326 const auto &value_elems = seq_value->value();
327 const auto &abs_elems = seq_abs->elements();
328 size_t num = value_elems.size();
329 if (num != abs_elems.size()) {
330 SaveSpecifiedOutputToCache(op_run_info->base_op_run_info.op_name, value_elems, abs_elems);
331 MS_LOG(DEBUG) << "The size of value " << num << " is not equal to the size of abstract " << abs_elems.size();
332 return;
333 }
334 for (size_t i = 0; i < num; ++i) {
335 SetNodeAbsById(PyNativeAlgo::Common::GetIdByValue(value_elems[i]), abs_elems[i]);
336 }
337 }
338 // If Just call run op and have no cell or function running, node_abs_cache_ will not be clear.
339 // So, set a threshold for clear it.
340 if (node_abs_cache_.size() > kCacheThreshold) {
341 std::unique_lock lock(abs_mutex_);
342 node_abs_cache_.clear();
343 }
344 }
345
SaveSpecifiedOutputToCache(const std::string & op_name,const ValuePtrList & value_list,const AbstractBasePtrList & abs_list)346 void InferOperation::SaveSpecifiedOutputToCache(const std::string &op_name, const ValuePtrList &value_list,
347 const AbstractBasePtrList &abs_list) {
348 if (value_list.empty() || abs_list.empty()) {
349 return;
350 }
351 // BatchNormal forward only use first output
352 if (op_name == kBatchNormOpName) {
353 SetNodeAbsById(PyNativeAlgo::Common::GetIdByValue(value_list[0]), abs_list[0]);
354 }
355 }
356
// Cache `abs` under `id`, first erasing its concrete value so the cached
// abstract carries shape/type information only.
void InferOperation::SetNodeAbsCacheById(const std::string &id, const abstract::AbstractBasePtr &abs) {
  SetNodeAbsById(id, PyNativeAlgo::Common::SetAbstractValueToAnyValue(abs));
}
360
// Replace (or insert) the cached abstract for `id` under the write lock.
void InferOperation::UpdateNodeAbsCacheById(const std::string &id, const abstract::AbstractBasePtr &abs) {
  std::unique_lock lock(abs_mutex_);
  // NOTE(review): the erase looks redundant since the operator[] assignment
  // below overwrites the mapped value; it only matters if node_abs_cache_ is
  // an insertion-ordered container -- confirm before simplifying.
  (void)node_abs_cache_.erase(id);
  node_abs_cache_[id] = abs;
}
366
GetNodeAbsById(const std::string & id) const367 AbstractBasePtr InferOperation::GetNodeAbsById(const std::string &id) const {
368 // GetNodeAbsById is used in NewGraph, need to release gil to avoid deadlock.
369 GilReleaseWithCheck release_gil;
370 std::shared_lock lock(abs_mutex_);
371 auto iter = node_abs_cache_.find(id);
372 if (iter == node_abs_cache_.end()) {
373 return nullptr;
374 }
375 return iter->second;
376 }
377
// Insert or overwrite the cached abstract for `id` under the write lock.
void InferOperation::SetNodeAbsById(const std::string &id, const abstract::AbstractBasePtr &abs) {
  std::unique_lock lock(abs_mutex_);
  node_abs_cache_[id] = abs;
}
382
// Constant-fold a primitive call entirely at infer time and return the folded
// result as a python object. Raises if the inferred abstract still contains
// any unknown value.
// Assumes args[0] is the primitive and args[1] its single input value --
// TODO(review): confirm this contract against the python-side caller.
py::object InferOperation::CallConstantFolding(const py::args &args) const {
  const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
  PyNativeAlgo::PyParser::SetPrim(op_run_info, args[0]);
  op_run_info->base_op_run_info.op_name = op_run_info->op_grad_info->op_prim->name();
  const auto &v = PyNativeAlgo::DataConvert::PyObjToValue(args[1]);
  (void)op_run_info->op_grad_info->input_abs.emplace_back(v->ToAbstract());
  PynativeInfer(op_run_info);
  auto infer_value = GetInferValueFromAbstract(op_run_info->base_op_run_info.abstract);
  if (infer_value->ContainsValueAny()) {
    MS_LOG(EXCEPTION) << "Can not get value from abstract";
  }
  return PyNativeAlgo::DataConvert::ValueToPyObj(infer_value);
}
396 } // namespace pynative
397 } // namespace mindspore
398