• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "src/control_flow/kernel/identity_kernel.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>

#include "src/tensor.h"
#include "src/litert/lite_kernel.h"
#include "src/common/tensor_util.h"
#include "src/common/prim_inner.h"
22 
23 namespace mindspore::kernel {
Run()24 int IdentityKernel::Run() {
25   auto ret = lite::RET_OK;
26   for (size_t i = 0; i < in_tensors().size(); ++i) {
27     auto src_tensor = in_tensors()[i];
28     auto dst_tensor = out_tensors()[i];
29     if (NeedCastData(dst_tensor, src_tensor)) {
30       ret = CastTensorData(dst_tensor, src_tensor, support_fp16_);
31       MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity cast failed.");
32       continue;
33     }
34     if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
35       ret = SetTensorData(dst_tensor, src_tensor);
36       MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity set tensor data failed.");
37     } else {
38       ret = MoveTensorData(dst_tensor, src_tensor);
39       MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity move tensor data failed.");
40     }
41   }
42   return ret;
43 }
44 
PreProcess()45 int IdentityKernel::PreProcess() {
46   auto ret = InferShape();
47   if (ret != RET_OK) {
48     MS_LOG(ERROR) << "infer shape failed.";
49     return ret;
50   }
51   ret = ReSize();
52   if (ret != RET_OK) {
53     MS_LOG(ERROR) << "resize failed.";
54     return ret;
55   }
56   return RET_OK;
57 }
58 
InferShape()59 int IdentityKernel::InferShape() {
60   if (in_tensors().size() != out_tensors().size()) {
61     MS_LOG(ERROR) << "output kernel in_tensors size is not same as out_tensors size.";
62     return lite::RET_ERROR;
63   }
64   need_resize_.resize(in_tensors().size());
65   for (size_t i = 0; i < in_tensors().size(); ++i) {
66     auto src_tensor = in_tensors()[i];
67     auto dst_tensor = out_tensors()[i];
68     need_resize_[i] = !IsSameShape(src_tensor, dst_tensor);
69     auto ret = SetTensorShape(dst_tensor, src_tensor);
70     if (ret != RET_OK) {
71       MS_LOG(ERROR) << "set output shape failed.";
72       return ret;
73     }
74   }
75   return RET_OK;
76 }
77 
PostProcess()78 int IdentityKernel::PostProcess() { return lite::RET_OK; }
79 
ReSize()80 int IdentityKernel::ReSize() {
81   for (size_t i = 0; i < in_tensors().size(); ++i) {
82     if (need_resize_[i]) {
83       auto ret = lite::MallocTensorData(out_tensors_[i]);
84       MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "malloc dst tensor data failed.");
85     }
86   }
87   return RET_OK;
88 }
89 
Create(std::vector<lite::Tensor * > in_tensors,std::vector<lite::Tensor * > out_tensors,const lite::InnerContext * ctx)90 KernelExec *IdentityKernel::Create(std::vector<lite::Tensor *> in_tensors, std::vector<lite::Tensor *> out_tensors,
91                                    const lite::InnerContext *ctx) {
92   auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
93   if (param == nullptr) {
94     MS_LOG(ERROR) << "malloc OpParameter failed.";
95     return nullptr;
96   }
97   (void)memset(param, 0, sizeof(OpParameter));
98   param->type_ = PrimType::PrimType_Inner_Identity;
99   auto lite_kernel = new IdentityKernel(param, in_tensors, out_tensors, ctx);
100   MS_CHECK_TRUE_MSG(lite_kernel != nullptr, nullptr, "new inner kernel failed.");
101   std::shared_ptr<kernel::Kernel> shared_kernel(lite_kernel);
102   if(shared_kernel != nullptr){
103     auto *kernel_exec = new KernelExec(shared_kernel);
104     kernel_exec->set_context(ctx);
105     return kernel_exec;
106   } else {
107     MS_LOG(ERROR) << "malloc shared_kernel failed.";
108     return nullptr;
109   }
110 }
111 }  // namespace mindspore::kernel
112