• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#ifndef MINDSPORE_LITE_SRC_SUB_GRAPH_KERNEL_H_
#define MINDSPORE_LITE_SRC_SUB_GRAPH_KERNEL_H_

#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "src/lite_kernel.h"
#include "src/executor.h"
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "src/cpu_info.h"
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
#include "nnacl/constant_of_shape_parameter.h"
#endif
34 
35 namespace mindspore::kernel {
36 // store origin data and allocator of input tensor of subgraph for PreProcess and PostProcess
37 struct DataStore {
38   void *data_ = nullptr;
39   Allocator *allocator_ = nullptr;
40   bool own_data_ = true;
41   static DataStore *CreateDataStore(void *data = nullptr, bool own_data = true, Allocator *data_allocator = nullptr,
42                                     Allocator *allocator = nullptr) {
43     DataStore *data_store = nullptr;
44     if (allocator == nullptr) {
45       data_store = static_cast<DataStore *>(malloc(sizeof(DataStore)));
46     } else {
47       data_store = static_cast<DataStore *>(allocator->Malloc(sizeof(DataStore)));
48     }
49     if (data_store == nullptr) {
50       MS_LOG(ERROR) << "Malloc data_store failed";
51       return nullptr;
52     }
53     data_store->data_ = data;
54     data_store->own_data_ = own_data;
55     data_store->allocator_ = data_allocator;
56     return data_store;
57   }
58 };
59 
60 class SubGraphKernel : public LiteKernel {
61  public:
SubGraphKernel(std::vector<LiteKernel * > in_kernels,std::vector<LiteKernel * > out_kernels,std::vector<LiteKernel * > nodes,Kernel * kernel)62   SubGraphKernel(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
63                  std::vector<LiteKernel *> nodes, Kernel *kernel)
64       : LiteKernel(std::shared_ptr<Kernel>(kernel)),
65         nodes_(std::move(nodes)),
66         in_nodes_(std::move(in_kernels)),
67         out_nodes_(std::move(out_kernels)) {
68     subgraph_type_ = kCpuFP32SubGraph;
69     desc_.data_type = kNumberTypeFloat32;
70   }
71 
~SubGraphKernel()72   ~SubGraphKernel() override {
73     for (auto *node : nodes_) {
74       delete node;
75     }
76     nodes_.clear();
77   }
78 
IsReady(const std::vector<lite::Tensor * > & scope_tensors)79   bool IsReady(const std::vector<lite::Tensor *> &scope_tensors) override {
80     return std::all_of(this->in_nodes_.begin(), this->in_nodes_.end(),
81                        [&](LiteKernel *kernel) { return kernel->IsReady(scope_tensors); });
82   }
83 
84   // called while compiling graph. Call node->Prepare() by default.
85   int Prepare() override;
86   // called before Run
Execute()87   int Execute() override { return Execute(nullptr, nullptr); }
88 
89   int Execute(const KernelCallBack &before, const KernelCallBack &after) override;
90 
91   // called after Run
92   int ReSize() override;
93 
94   void InitOutTensorInitRefCount(const std::vector<LiteKernel *> *mask_kernels) override;
95 
96   void InitInputTensorInitRefCount();
97 
Init()98   int Init() override { return mindspore::lite::RET_OK; }
99 
100   std::string ToString() const override;
101 
nodes()102   std::vector<LiteKernel *> &nodes() { return this->nodes_; }
103 
104   void DropNode(LiteKernel *node);
105 
in_nodes()106   std::vector<LiteKernel *> in_nodes() { return this->in_nodes_; }
107 
out_nodes()108   std::vector<LiteKernel *> out_nodes() { return this->out_nodes_; }
109 
SetSchemaVersion(int schema_version)110   void SetSchemaVersion(int schema_version) { schema_version_ = schema_version; }
111 
112  protected:
113   std::vector<LiteKernel *> nodes_{};
114   // entry nodes in nodes
115   std::vector<LiteKernel *> in_nodes_{};
116   // exit nodes in nodes
117   std::vector<LiteKernel *> out_nodes_{};
118   mindspore::lite::Executor *executor_ = nullptr;
119   int schema_version_ = lite::SCHEMA_VERSION::SCHEMA_CUR;
120 };
121 
122 class CpuSubGraph : public SubGraphKernel {
123  public:
CpuSubGraph(std::vector<LiteKernel * > in_kernels,std::vector<LiteKernel * > out_kernels,std::vector<LiteKernel * > nodes,Kernel * kernel)124   CpuSubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
125               std::vector<LiteKernel *> nodes, Kernel *kernel)
126       : SubGraphKernel(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
127     subgraph_type_ = kCpuFP32SubGraph;
128     desc_.arch = kernel::KERNEL_ARCH::kCPU;
129   }
130 
~CpuSubGraph()131   ~CpuSubGraph() override { delete this->executor_; }
132   int Prepare() override;
Init()133   int Init() override { return SubGraphKernel::Init(); }
Execute()134   int Execute() override { return Execute(nullptr, nullptr); }
135   int Execute(const KernelCallBack &before, const KernelCallBack &after) override;
136 };
137 
138 class CpuFp32SubGraph : public CpuSubGraph {
139  public:
CpuFp32SubGraph(std::vector<LiteKernel * > in_kernels,std::vector<LiteKernel * > out_kernels,std::vector<LiteKernel * > nodes,Kernel * kernel)140   CpuFp32SubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
141                   std::vector<LiteKernel *> nodes, Kernel *kernel)
142       : CpuSubGraph(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
143     subgraph_type_ = kCpuFP32SubGraph;
144     static std::atomic_int index = {0};
145     this->set_name("CpuFP32SubGraph" + std::to_string(index++));
146     desc_.data_type = kNumberTypeFloat32;
147   }
148   ~CpuFp32SubGraph() override = default;
149 };
150 
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
// CPU subgraph running in float16 (ARM builds with FP16 support only).
class CpuFp16SubGraph : public CpuSubGraph {
 public:
  CpuFp16SubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
                  std::vector<LiteKernel *> nodes, Kernel *kernel)
      : CpuSubGraph(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
    subgraph_type_ = kCpuFP16SubGraph;
    // Each FP16 subgraph gets a unique, monotonically increasing name suffix.
    // Brace-init: copy-initializing std::atomic from int is ill-formed before
    // C++17; this also matches the CpuFp32SubGraph style.
    static std::atomic_int index = {0};
    this->set_name("CpuFP16SubGraph" + std::to_string(index++));
    desc_.data_type = kNumberTypeFloat16;
  }

  ~CpuFp16SubGraph() override = default;

  // Caches whether both the device and the package were built with FP16
  // support, then delegates to the base Init.
  int Init() override {
    const auto *context = this->Context();
    MS_ASSERT(context != nullptr);
    support_fp16_ = context->device_and_pkg_support_fp16();
    return CpuSubGraph::Init();
  }

  // After the common Prepare, rewrite nodes that carry an explicit FP32
  // destination type (Cast, ConstantOfShape) so they emit FP16 inside this
  // subgraph.
  int Prepare() override {
    auto ret = CpuSubGraph::Prepare();
    if (ret != RET_OK) {
      return ret;
    }
    for (auto &node : this->nodes_) {
      if (node->type() == schema::PrimitiveType_Cast) {
        // Cast stores its destination dtype as the (scalar int32) second input.
        auto inputs = node->in_tensors();
        MS_ASSERT(inputs.size() >= 2);
        auto dst_tensor = inputs[1];
        MS_ASSERT(dst_tensor != nullptr);
        MS_ASSERT(dst_tensor->data_type() == kNumberTypeInt32);
        MS_ASSERT(dst_tensor->data() != nullptr);
        MS_ASSERT(dst_tensor->ElementsNum() == 1);
        auto *dst_data = reinterpret_cast<int32_t *>(dst_tensor->data());
        if (dst_data[0] == kNumberTypeFloat32) {
          dst_data[0] = kNumberTypeFloat16;
        }
        auto outputs = node->out_tensors();
        MS_ASSERT(outputs.size() == 1);
        auto output = outputs.front();
        MS_ASSERT(output != nullptr);
        if (output->data_type() == kNumberTypeFloat32) {
          output->set_data_type(kNumberTypeFloat16);
        }
      } else if (node->type() == schema::PrimitiveType_ConstantOfShape) {
        auto *param = reinterpret_cast<ConstantOfShapeParameter *>(node->op_parameter());
        MS_ASSERT(param != nullptr);
        // Fix: the original wrapped the whole comparison in the cast —
        // static_cast<TypeId>(data_type_ == kNumberTypeFloat32) — casting the
        // bool result instead of the stored dtype. Cast the dtype, then compare.
        if (static_cast<TypeId>(param->data_type_) == kNumberTypeFloat32) {
          param->data_type_ = kNumberTypeFloat16;
        }
        auto outputs = node->out_tensors();
        MS_ASSERT(outputs.size() == 1);
        auto output = outputs.front();
        MS_ASSERT(output != nullptr);
        if (output->data_type() == kNumberTypeFloat32) {
          output->set_data_type(kNumberTypeFloat16);
        }
      }
    }
    return RET_OK;
  }

 private:
  bool support_fp16_ = false;  // device + package FP16 support, set in Init()
};
#endif
219 
220 class CustomSubGraph : public SubGraphKernel {
221  public:
CustomSubGraph(std::vector<LiteKernel * > in_kernels,std::vector<LiteKernel * > out_kernels,std::vector<LiteKernel * > nodes,Kernel * kernel)222   CustomSubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
223                  std::vector<LiteKernel *> nodes, Kernel *kernel)
224       : SubGraphKernel(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
225     subgraph_type_ = kCustomSubGraph;
226     desc_.arch = kernel::KERNEL_ARCH::kCustom;
227   }
228 
~CustomSubGraph()229   ~CustomSubGraph() override { delete this->executor_; }
230   int Prepare() override;
Init()231   int Init() override { return SubGraphKernel::Init(); }
Execute()232   int Execute() override { return Execute(nullptr, nullptr); }
233   int Execute(const KernelCallBack &before, const KernelCallBack &after) override;
234 };
235 }  // namespace mindspore::kernel
236 #endif  // MINDSPORE_LITE_SRC_SUB_GRAPH_KERNEL_H_
237