• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/predict_out_type_map.h"
18 #include <string>
19 #include <vector>
20 #include "ops/op_def.h"
21 
22 namespace mindspore {
23 namespace pynative {
24 namespace {
PredictOutTypeByOutputNum(const int64_t & output_num)25 inline TypePtr PredictOutTypeByOutputNum(const int64_t &output_num) {
26   static const std::vector<TypePtr> types({kTuple, kTensorType, kTupleTensor2, kTupleTensor3, kTupleTensor4,
27                                            kTupleTensor5, kTupleTensor6, kTupleTensor7, kTupleTensor8, kTupleTensor9});
28   constexpr int64_t kZero = 0;
29   constexpr int64_t kTen = 10;
30   if (output_num > kZero && output_num < kTen) {
31     return types[static_cast<size_t>(output_num)];
32   }
33   return kTuple;
34 }
35 }  // namespace
36 
PredictOutTypeByOpDef(const ops::OpDefPtr & op_def)37 TypePtr PredictOutTypeByOpDef(const ops::OpDefPtr &op_def) {
38   auto returns_num = op_def->returns_.size();
39   if (returns_num == 1) {
40     if (op_def->returns_[0].arg_dtype_ == ops::OP_DTYPE::DT_TENSOR) {
41       return kTensorType;
42     }
43 
44     if (op_def->returns_[0].arg_dtype_ == ops::OP_DTYPE::DT_LIST_TENSOR ||
45         op_def->returns_[0].arg_dtype_ == ops::OP_DTYPE::DT_TUPLE_TENSOR) {
46       return kTuple;
47     }
48 
49     return kTypeNone;
50   }
51 
52   static const std::vector<TypePtr> kSequenceTypes = {
53     kTuple,  // this is only a placeholder
54     kTuple,  // this is only a placeholder
55     kTupleTensor2, kTupleTensor3, kTupleTensor4, kTupleTensor5,
56     kTupleTensor6, kTupleTensor7, kTupleTensor8, kTupleTensor9,
57   };
58 
59   if (returns_num >= kSequenceTypes.size()) {
60     MS_LOG(EXCEPTION) << "For " << op_def->name_ << ", the number of output must be less than " << kSequenceTypes.size()
61                       << ", but got " << returns_num << ".";
62   }
63 
64   return kSequenceTypes[returns_num];
65 }
66 
PredictOutTypeByName(const std::string & op_name)67 TypePtr PredictOutTypeByName(const std::string &op_name) {
68   static PredictOutTypeMap ops_map{};
69   const auto iter = ops_map.find(op_name);
70   if (iter != ops_map.end()) {
71     return iter->second;
72   }
73   auto op_def = ops::GetOpDef(op_name);
74   if (op_def != nullptr) {
75     auto type = PredictOutTypeByOpDef(op_def);
76     MS_LOG(DEBUG) << "PredictOutTypeByOpDef: " << type->ToString();
77     return ops_map[op_name] = type;
78   }
79 
80   static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
81   if (operator_fns.find(op_name) == operator_fns.end()) {
82     return ops_map[op_name] = kTypeNone;
83   }
84   const auto pre_iter = out_type_prediction.find(op_name);
85   auto type = pre_iter == out_type_prediction.end() ? kTensorType : pre_iter->second;
86   return ops_map[op_name] = type;
87 }
88 
PredictOutType(const FrontendOpRunInfoPtr & op_run_info)89 TypePtr PredictOutType(const FrontendOpRunInfoPtr &op_run_info) {
90   const auto &op_name = op_run_info->base_op_run_info.op_name;
91   auto type = PredictOutTypeByName(op_name);
92   if (type == kTypeAny) {
93     const auto &op_prim = op_run_info->op_grad_info->op_prim;
94     if (const auto &attr = op_prim->GetAttr("output_num"); attr != nullptr) {
95       type = PredictOutTypeByOutputNum(GetValue<int64_t>(attr));
96     }
97   }
98   return type;
99 }
100 }  // namespace pynative
101 }  // namespace mindspore
102