• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/litert/lite_kernel.h"
18 #include <algorithm>
19 #include "src/common/utils.h"
20 #include "src/litert/infer_manager.h"
21 
22 namespace mindspore::kernel {
23 using mindspore::lite::RET_ERROR;
24 using mindspore::lite::RET_OK;
25 
AllocWorkspace()26 void LiteKernel::AllocWorkspace() {
27   workspace_ = malloc(workspace_size());
28   if (workspace_ == nullptr) {
29     MS_LOG(ERROR) << "fail to alloc " << workspace_size() << "in kernel" << name();
30     return;
31   }
32   ws_allocated_ = true;
33 }
34 
FreeWorkspace()35 void LiteKernel::FreeWorkspace() {
36   if (ws_allocated_) {
37     free(workspace_);
38   }
39   workspace_ = nullptr;
40   ws_allocated_ = false;
41 }
42 
InferShape()43 int LiteKernel::InferShape() {
44   return lite::KernelInferShape(in_tensors_, out_tensors_, op_parameter_, ms_context_->allocator);
45 }
46 
PreProcess()47 int LiteKernel::PreProcess() {
48   if (!InferShapeDone()) {
49     auto ret = InferShape();
50     if (ret != 0) {
51       MS_LOG(ERROR) << "InferShape fail!";
52       return ret;
53     }
54     ret = ReSize();
55     if (ret != 0) {
56       MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
57       return ret;
58     }
59   }
60 
61   // check if inputs are valid
62   if (!CheckInputsValid()) {
63     MS_LOG(ERROR) << "The input is not valid.";
64     return RET_ERROR;
65   }
66   // check if parameters are valid
67   if (!CheckParamsValid()) {
68     MS_LOG(ERROR) << "The parameter is not valid.";
69     return RET_ERROR;
70   }
71   for (auto *output : this->out_tensors()) {
72     MS_ASSERT(output != nullptr);
73     if (registry_data_type_ == kNumberTypeFloat16 && output->data_type() == kNumberTypeFloat32) {
74       output->set_data_type(kNumberTypeFloat16);
75     }
76     auto ret = output->MallocData();
77     if (ret != RET_OK) {
78       MS_LOG(ERROR) << "MallocData failed";
79       return ret;
80     }
81     output->ResetRefCount();
82   }
83   return RET_OK;
84 }
85 
UpdateThreadNumProcess(int32_t kernel_type,int64_t per_unit_load_num,int64_t per_unit_store_num,int64_t unit_num)86 int LiteKernel::UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num,
87                                        int64_t unit_num) {
88   thread_num_ =
89     lite::UpdateThreadNum(kernel_type, per_unit_load_num, per_unit_store_num, unit_num, op_parameter_->thread_num_);
90   return lite::RET_OK;
91 }
92 
UpdateThreadNumPass(int32_t kernel_type,int64_t per_unit_load_num,int64_t per_unit_store_num,int64_t unit_num)93 int LiteKernel::UpdateThreadNumPass(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num,
94                                     int64_t unit_num) {
95 #ifdef DYNAMIC_THREAD_DISTRIBUTE
96   if (UpdateThreadNumProcess(kernel_type, per_unit_load_num, per_unit_store_num, unit_num) != lite::RET_OK) {
97     MS_LOG(ERROR) << "update thread num failed";
98     return lite::RET_ERROR;
99   }
100 #else
101   thread_num_ = op_parameter_->thread_num_ > 0 ? op_parameter_->thread_num_ : 1;
102 #endif
103 
104   return lite::RET_OK;
105 }
106 
Execute()107 int LiteKernel::Execute() {
108   auto ret = PreProcess();
109   if (lite::RET_OK != ret) {
110     MS_LOG(ERROR) << "run kernel PreProcess failed, name: " << this->name();
111     return ret;
112   }
113 
114   /* op_parameter_ is null : run in kernel mod */
115   if (op_parameter_ == nullptr || op_parameter_->is_zero_shape_ == false) {
116     ret = Run();
117     if (lite::RET_OK != ret) {
118       MS_LOG(ERROR) << "run kernel failed, name: " << this->name();
119       return ret;
120     }
121   }
122 
123   ret = PostProcess();
124   if (lite::RET_OK != ret) {
125     MS_LOG(ERROR) << "run kernel PostProcess failed, name: " << this->name();
126     return ret;
127   }
128   return lite::RET_OK;
129 }
130 }  // namespace mindspore::kernel
131