• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/ops_func_impl/simple_infer.h"

#include <utility>

19 namespace mindspore::ops {
Instance()20 SimpleInfer &SimpleInfer::Instance() noexcept {
21   static SimpleInfer instance;
22   return instance;
23 }
24 
GetFunc(const string & op_name)25 ops::OpFuncImplPtr SimpleInfer::GetFunc(const string &op_name) {
26   auto iter = simple_infer_fun_.find(op_name);
27   if (iter == simple_infer_fun_.end()) {
28     return nullptr;
29   }
30   return iter->second;
31 }
32 
Register(const std::string & op_name,ops::OpFuncImplPtr && func)33 void SimpleInfer::Register(const std::string &op_name, ops::OpFuncImplPtr &&func) {
34   MS_LOG(DEBUG) << "Reg simple infer for op " << op_name;
35   auto ret = simple_infer_fun_.try_emplace(op_name, func);
36   if (!ret.second) {
37     MS_LOG(WARNING) << "Duplicate simpler infer for " << op_name;
38   }
39 }
40 
DoSimpleInfer(const PrimitivePtr & primitive,const ValueSimpleInfoPtr & value_simple_info,const ops::OpFuncImplPtr & simple_infer_func,const ValuePtrList & input_values)41 void SimpleInfer::DoSimpleInfer(const PrimitivePtr &primitive, const ValueSimpleInfoPtr &value_simple_info,
42                                 const ops::OpFuncImplPtr &simple_infer_func, const ValuePtrList &input_values) {
43   value_simple_info->shape_vector_ = simple_infer_func->InferShape(primitive, input_values);
44   value_simple_info->dtype_vector_ = simple_infer_func->InferType(primitive, input_values);
45   value_simple_info->size_ = value_simple_info->shape_vector_.size();
46   if (value_simple_info->size_ != value_simple_info->dtype_vector_.size()) {
47     MS_LOG(EXCEPTION) << "Infer shape size " << value_simple_info->size_ << " is not equal to dtype size "
48                       << value_simple_info->dtype_vector_.size();
49   }
50 }
51 
InferBySimple(const PrimitivePtr & primitive,const ValuePtrList & input_values)52 ValueSimpleInfoPtr InferBySimple(const PrimitivePtr &primitive, const ValuePtrList &input_values) {
53   const auto &simple_infer_func = SimpleInfer::Instance().GetFunc(primitive->name());
54   if (simple_infer_func == nullptr) {
55     return nullptr;
56   }
57   auto value_simple_info = std::make_shared<ValueSimpleInfo>();
58   SimpleInfer::Instance().DoSimpleInfer(primitive, value_simple_info, simple_infer_func, input_values);
59   return value_simple_info;
60 }
61 }  // namespace mindspore::ops
62