1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "ops/ops_func_impl/simple_infer.h"

#include <utility>
18
19 namespace mindspore::ops {
Instance()20 SimpleInfer &SimpleInfer::Instance() noexcept {
21 static SimpleInfer instance;
22 return instance;
23 }
24
GetFunc(const string & op_name)25 ops::OpFuncImplPtr SimpleInfer::GetFunc(const string &op_name) {
26 auto iter = simple_infer_fun_.find(op_name);
27 if (iter == simple_infer_fun_.end()) {
28 return nullptr;
29 }
30 return iter->second;
31 }
32
Register(const std::string & op_name,ops::OpFuncImplPtr && func)33 void SimpleInfer::Register(const std::string &op_name, ops::OpFuncImplPtr &&func) {
34 MS_LOG(DEBUG) << "Reg simple infer for op " << op_name;
35 auto ret = simple_infer_fun_.try_emplace(op_name, func);
36 if (!ret.second) {
37 MS_LOG(WARNING) << "Duplicate simpler infer for " << op_name;
38 }
39 }
40
DoSimpleInfer(const PrimitivePtr & primitive,const ValueSimpleInfoPtr & value_simple_info,const ops::OpFuncImplPtr & simple_infer_func,const ValuePtrList & input_values)41 void SimpleInfer::DoSimpleInfer(const PrimitivePtr &primitive, const ValueSimpleInfoPtr &value_simple_info,
42 const ops::OpFuncImplPtr &simple_infer_func, const ValuePtrList &input_values) {
43 value_simple_info->shape_vector_ = simple_infer_func->InferShape(primitive, input_values);
44 value_simple_info->dtype_vector_ = simple_infer_func->InferType(primitive, input_values);
45 value_simple_info->size_ = value_simple_info->shape_vector_.size();
46 if (value_simple_info->size_ != value_simple_info->dtype_vector_.size()) {
47 MS_LOG(EXCEPTION) << "Infer shape size " << value_simple_info->size_ << " is not equal to dtype size "
48 << value_simple_info->dtype_vector_.size();
49 }
50 }
51
InferBySimple(const PrimitivePtr & primitive,const ValuePtrList & input_values)52 ValueSimpleInfoPtr InferBySimple(const PrimitivePtr &primitive, const ValuePtrList &input_values) {
53 const auto &simple_infer_func = SimpleInfer::Instance().GetFunc(primitive->name());
54 if (simple_infer_func == nullptr) {
55 return nullptr;
56 }
57 auto value_simple_info = std::make_shared<ValueSimpleInfo>();
58 SimpleInfer::Instance().DoSimpleInfer(primitive, value_simple_info, simple_infer_func, input_values);
59 return value_simple_info;
60 }
61 } // namespace mindspore::ops
62