/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/predict_out_type_map.h"
#include <string>
#include <vector>
#include "ops/op_def.h"

namespace mindspore {
namespace pynative {
namespace {
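// Predict the output type from the number of outputs: one output is a single tensor,
// 2 to 9 outputs map to the corresponding fixed-size tuple-of-tensors type, and any
// other count falls back to a generic tuple.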
inline TypePtr PredictOutTypeByOutputNum(const int64_t &output_num) {
  static const std::vector<TypePtr> types({kTuple, kTensorType, kTupleTensor2, kTupleTensor3, kTupleTensor4,
                                           kTupleTensor5, kTupleTensor6, kTupleTensor7, kTupleTensor8, kTupleTensor9});
  constexpr int64_t kZero = 0;
  constexpr int64_t kTen = 10;
  if (output_num > kZero && output_num < kTen) {
    return types[static_cast<size_t>(output_num)];
  }
  return kTuple;
}
}  // namespace

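// Predict the output type from an operator definition: a single declared return maps to a
// tensor, tuple, or none according to its dtype; 2 to 9 returns map to the corresponding
// fixed-size tuple-of-tensors type; larger counts raise an exception.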
TypePtr PredictOutTypeByOpDef(const ops::OpDefPtr &op_def) {
  auto returns_num = op_def->returns_.size();
  if (returns_num == 1) {
    if (op_def->returns_[0].arg_dtype_ == ops::OP_DTYPE::DT_TENSOR) {
      return kTensorType;
    }

    if (op_def->returns_[0].arg_dtype_ == ops::OP_DTYPE::DT_LIST_TENSOR ||
        op_def->returns_[0].arg_dtype_ == ops::OP_DTYPE::DT_TUPLE_TENSOR) {
      return kTuple;
    }

    return kTypeNone;
  }

  static const std::vector<TypePtr> kSequenceTypes = {
    kTuple,  // index 0: no declared returns
    kTuple,  // index 1: unreachable, a single return is handled above
    kTupleTensor2, kTupleTensor3, kTupleTensor4, kTupleTensor5,
    kTupleTensor6, kTupleTensor7, kTupleTensor8, kTupleTensor9,
  };

  if (returns_num >= kSequenceTypes.size()) {
    MS_LOG(EXCEPTION) << "For " << op_def->name_ << ", the number of outputs must be less than "
                      << kSequenceTypes.size() << ", but got " << returns_num << ".";
  }

  return kSequenceTypes[returns_num];
}

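// Predict the output type from the operator name, caching the result per name. The lookup
// order is: the local cache, the OpDef registry, then the registered operator map combined
// with the static out_type_prediction table; names that are not registered operators are
// cached as kTypeNone.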
TypePtr PredictOutTypeByName(const std::string &op_name) {
  static PredictOutTypeMap ops_map{};
  const auto iter = ops_map.find(op_name);
  if (iter != ops_map.end()) {
    return iter->second;
  }
  auto op_def = ops::GetOpDef(op_name);
  if (op_def != nullptr) {
    auto type = PredictOutTypeByOpDef(op_def);
    MS_LOG(DEBUG) << "PredictOutTypeByOpDef: " << type->ToString();
    return ops_map[op_name] = type;
  }

  static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
  if (operator_fns.find(op_name) == operator_fns.end()) {
    return ops_map[op_name] = kTypeNone;
  }
  const auto pre_iter = out_type_prediction.find(op_name);
  auto type = pre_iter == out_type_prediction.end() ? kTensorType : pre_iter->second;
  return ops_map[op_name] = type;
}

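// Predict the output type of an op about to run in PyNative mode. When the name-based
// prediction is kTypeAny, refine it with the primitive's "output_num" attribute if present.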
TypePtr PredictOutType(const FrontendOpRunInfoPtr &op_run_info) {
  const auto &op_name = op_run_info->base_op_run_info.op_name;
  auto type = PredictOutTypeByName(op_name);
  if (type == kTypeAny) {
    const auto &op_prim = op_run_info->op_grad_info->op_prim;
    if (const auto &attr = op_prim->GetAttr("output_num"); attr != nullptr) {
      type = PredictOutTypeByOutputNum(GetValue<int64_t>(attr));
    }
  }
  return type;
}
}  // namespace pynative
}  // namespace mindspore