1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "src/control_flow/kernel/identity_kernel.h"
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>
#include <vector>
#include "src/tensor.h"
#include "src/litert/lite_kernel.h"
#include "src/common/tensor_util.h"
#include "src/common/prim_inner.h"
22
23 namespace mindspore::kernel {
Run()24 int IdentityKernel::Run() {
25 auto ret = lite::RET_OK;
26 for (size_t i = 0; i < in_tensors().size(); ++i) {
27 auto src_tensor = in_tensors()[i];
28 auto dst_tensor = out_tensors()[i];
29 if (NeedCastData(dst_tensor, src_tensor)) {
30 ret = CastTensorData(dst_tensor, src_tensor, support_fp16_);
31 MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity cast failed.");
32 continue;
33 }
34 if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
35 ret = SetTensorData(dst_tensor, src_tensor);
36 MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity set tensor data failed.");
37 } else {
38 ret = MoveTensorData(dst_tensor, src_tensor);
39 MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity move tensor data failed.");
40 }
41 }
42 return ret;
43 }
44
PreProcess()45 int IdentityKernel::PreProcess() {
46 auto ret = InferShape();
47 if (ret != RET_OK) {
48 MS_LOG(ERROR) << "infer shape failed.";
49 return ret;
50 }
51 ret = ReSize();
52 if (ret != RET_OK) {
53 MS_LOG(ERROR) << "resize failed.";
54 return ret;
55 }
56 return RET_OK;
57 }
58
InferShape()59 int IdentityKernel::InferShape() {
60 if (in_tensors().size() != out_tensors().size()) {
61 MS_LOG(ERROR) << "output kernel in_tensors size is not same as out_tensors size.";
62 return lite::RET_ERROR;
63 }
64 need_resize_.resize(in_tensors().size());
65 for (size_t i = 0; i < in_tensors().size(); ++i) {
66 auto src_tensor = in_tensors()[i];
67 auto dst_tensor = out_tensors()[i];
68 need_resize_[i] = !IsSameShape(src_tensor, dst_tensor);
69 auto ret = SetTensorShape(dst_tensor, src_tensor);
70 if (ret != RET_OK) {
71 MS_LOG(ERROR) << "set output shape failed.";
72 return ret;
73 }
74 }
75 return RET_OK;
76 }
77
PostProcess()78 int IdentityKernel::PostProcess() { return lite::RET_OK; }
79
ReSize()80 int IdentityKernel::ReSize() {
81 for (size_t i = 0; i < in_tensors().size(); ++i) {
82 if (need_resize_[i]) {
83 auto ret = lite::MallocTensorData(out_tensors_[i]);
84 MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "malloc dst tensor data failed.");
85 }
86 }
87 return RET_OK;
88 }
89
Create(std::vector<lite::Tensor * > in_tensors,std::vector<lite::Tensor * > out_tensors,const lite::InnerContext * ctx)90 KernelExec *IdentityKernel::Create(std::vector<lite::Tensor *> in_tensors, std::vector<lite::Tensor *> out_tensors,
91 const lite::InnerContext *ctx) {
92 auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
93 if (param == nullptr) {
94 MS_LOG(ERROR) << "malloc OpParameter failed.";
95 return nullptr;
96 }
97 (void)memset(param, 0, sizeof(OpParameter));
98 param->type_ = PrimType::PrimType_Inner_Identity;
99 auto lite_kernel = new IdentityKernel(param, in_tensors, out_tensors, ctx);
100 MS_CHECK_TRUE_MSG(lite_kernel != nullptr, nullptr, "new inner kernel failed.");
101 std::shared_ptr<kernel::Kernel> shared_kernel(lite_kernel);
102 if(shared_kernel != nullptr){
103 auto *kernel_exec = new KernelExec(shared_kernel);
104 kernel_exec->set_context(ctx);
105 return kernel_exec;
106 } else {
107 MS_LOG(ERROR) << "malloc shared_kernel failed.";
108 return nullptr;
109 }
110 }
111 } // namespace mindspore::kernel
112