1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/api/kernel.h"
17 #include "include/errorcode.h"
18 #include "src/registry/kernel_interface_registry.h"
19 #include "src/common/log_adapter.h"
20
21 namespace mindspore::kernel {
Initialize()22 void Kernel::Initialize() {
23 if (primitive_ == nullptr) {
24 return;
25 }
26 type_ = primitive_->value_type();
27 if (type_ == schema::PrimitiveType_Custom) {
28 auto param = primitive_->value_as_Custom();
29 if (param != nullptr && param->type() != nullptr) {
30 SetAttr("type", param->type()->str());
31 }
32 }
33 }
34
InferShape()35 int Kernel::InferShape() {
36 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
37 std::shared_ptr<KernelInterface> kernel_interface = nullptr;
38 if (type() == schema::PrimitiveType_Custom) {
39 kernel_interface = registry::KernelInterfaceRegistry::Instance()->GetKernelInterface("", nullptr, this);
40 } else {
41 auto device_list = const_cast<mindspore::Context *>(context_)->MutableDeviceInfo();
42 for (auto &device : device_list) {
43 MS_CHECK_TRUE_RET(device != nullptr, lite::RET_NULL_PTR);
44 kernel_interface =
45 registry::KernelInterfaceRegistry::Instance()->GetKernelInterface(device->GetProvider(), nullptr, this);
46 if (kernel_interface != nullptr) {
47 break;
48 }
49 }
50 }
51
52 if (kernel_interface == nullptr) {
53 MS_LOG(ERROR) << "op_type: " << schema::EnumNamePrimitiveType(type_) << " can not find infer interface.";
54 return lite::RET_NOT_SUPPORT;
55 }
56 auto ret = kernel_interface->Infer(&inputs_, &outputs_, static_cast<const schema::Primitive *>(primitive_), this);
57 if (ret == kLiteInferInvalid) {
58 for (auto output : outputs_) {
59 output.SetShape({-1});
60 }
61 return lite::RET_INFER_INVALID;
62 }
63 if (ret != kSuccess) {
64 MS_LOG(ERROR) << "op_type: " << schema::EnumNamePrimitiveType(type_) << " infer fail!ret: " << ret;
65 return lite::RET_ERROR;
66 }
67 return lite::RET_OK;
68 #endif
69 return lite::RET_NOT_SUPPORT;
70 }
71 } // namespace mindspore::kernel
72