• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_INFER_MANAGER_H_
18 #define MINDSPORE_LITE_SRC_RUNTIME_INFER_MANAGER_H_
19 
20 #include <map>
21 #include <vector>
22 #include <set>
23 #include <string>
24 #include "src/common/prim_util.h"
25 #include "src/common/common.h"
26 #include "src/tensor.h"
27 #include "nnacl/tensor_c.h"
28 #include "nnacl/infer/infer.h"
29 #include "include/api/kernel.h"
30 
31 namespace mindspore::lite {
32 int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
33                      OpParameter *parameter);
34 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
35 int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
36                      const void *primitive, std::set<std::string> &&providers, int schema_version,
37                      const kernel::Kernel *kernel = nullptr);
38 #endif
39 class InferManager {
40  public:
GetInstance()41   static InferManager *GetInstance() {
42     static InferManager instance;
43     return &instance;
44   }
45   virtual ~InferManager() = default;
46 
InsertInferShapeFunc(int prim_type,InferShape func)47   void InsertInferShapeFunc(int prim_type, InferShape func) { infer_shape_funcs_[prim_type] = func; }
48 
GetInferShapeFunc(int prim_type)49   InferShape GetInferShapeFunc(int prim_type) {
50     auto iter = infer_shape_funcs_.find(prim_type);
51     if (iter == infer_shape_funcs_.end()) {
52       return nullptr;
53     }
54     return iter->second;
55   }
56 
57  private:
58   InferManager() = default;
59 
60   std::map<int, InferShape> infer_shape_funcs_;
61 };
62 
63 class RegistryInferShape {
64  public:
RegistryInferShape(int prim_type,InferShape func)65   RegistryInferShape(int prim_type, InferShape func) {
66     InferManager::GetInstance()->InsertInferShapeFunc(prim_type, func);
67   }
68   ~RegistryInferShape() = default;
69 };
70 
// Declares an internal-linkage static object g_<op>InferShape whose constructor registers `func`
// for `prim_type` with InferManager during static initialization.
// Preprocessor macros ignore namespaces, so the class name is fully qualified: the expansion must
// resolve even when the macro is used outside namespace mindspore::lite.
// The expansion ends with ';' (unchanged from before) so existing call sites keep compiling.
#define REG_INFER_SHAPE(op, prim_type, func) \
  static ::mindspore::lite::RegistryInferShape g_##op##InferShape(prim_type, func);
72 }  // namespace mindspore::lite
73 
74 #endif  // MINDSPORE_LITE_SRC_RUNTIME_INFER_MANAGER_H_
75