• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "include/api/kernel.h"
17 #include "include/errorcode.h"
18 #include "src/registry/kernel_interface_registry.h"
19 #include "src/common/log_adapter.h"
20 
21 namespace mindspore::kernel {
Initialize()22 void Kernel::Initialize() {
23   if (primitive_ == nullptr) {
24     return;
25   }
26   type_ = primitive_->value_type();
27   if (type_ == schema::PrimitiveType_Custom) {
28     auto param = primitive_->value_as_Custom();
29     if (param != nullptr && param->type() != nullptr) {
30       SetAttr("type", param->type()->str());
31     }
32   }
33 }
34 
InferShape()35 int Kernel::InferShape() {
36 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
37   std::shared_ptr<KernelInterface> kernel_interface = nullptr;
38   if (type() == schema::PrimitiveType_Custom) {
39     kernel_interface = registry::KernelInterfaceRegistry::Instance()->GetKernelInterface("", nullptr, this);
40   } else {
41     auto device_list = const_cast<mindspore::Context *>(context_)->MutableDeviceInfo();
42     for (auto &device : device_list) {
43       MS_CHECK_TRUE_RET(device != nullptr, lite::RET_NULL_PTR);
44       kernel_interface =
45         registry::KernelInterfaceRegistry::Instance()->GetKernelInterface(device->GetProvider(), nullptr, this);
46       if (kernel_interface != nullptr) {
47         break;
48       }
49     }
50   }
51 
52   if (kernel_interface == nullptr) {
53     MS_LOG(ERROR) << "op_type: " << schema::EnumNamePrimitiveType(type_) << " can not find infer interface.";
54     return lite::RET_NOT_SUPPORT;
55   }
56   auto ret = kernel_interface->Infer(&inputs_, &outputs_, static_cast<const schema::Primitive *>(primitive_), this);
57   if (ret == kLiteInferInvalid) {
58     for (auto output : outputs_) {
59       output.SetShape({-1});
60     }
61     return lite::RET_INFER_INVALID;
62   }
63   if (ret != kSuccess) {
64     MS_LOG(ERROR) << "op_type: " << schema::EnumNamePrimitiveType(type_) << " infer fail!ret: " << ret;
65     return lite::RET_ERROR;
66   }
67   return lite::RET_OK;
68 #endif
69   return lite::RET_NOT_SUPPORT;
70 }
71 }  // namespace mindspore::kernel
72