1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h"
18 #include "src/litert/kernel_registry.h"
19 #include "include/errorcode.h"
20 #include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
21
22 using mindspore::lite::KernelRegistrar;
23 using mindspore::lite::RET_ERROR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad;
26
27 namespace mindspore::kernel {
ReSize()28 int BinaryCrossEntropyGradCPUKernel::ReSize() {
29 CHECK_LESS_RETURN(in_tensors_.size(), C3NUM);
30 CHECK_LESS_RETURN(out_tensors_.size(), C1NUM);
31 CHECK_NULL_RETURN(in_tensors_.at(C0NUM));
32 CHECK_NULL_RETURN(in_tensors_.at(C1NUM));
33 CHECK_NULL_RETURN(in_tensors_.at(C2NUM));
34 if (in_tensors_.size() == C4NUM) {
35 weight_defined_ = true;
36 CHECK_NULL_RETURN(in_tensors_.at(C3NUM));
37 }
38 CHECK_NULL_RETURN(out_tensors_.at(0));
39 CHECK_NULL_RETURN(op_parameter_);
40 auto param_ = reinterpret_cast<BinaryCrossEntropyGradParameter *>(op_parameter_);
41 CHECK_NULL_RETURN(param_);
42
43 return RET_OK;
44 }
45
DoExecute(int task_id)46 int BinaryCrossEntropyGradCPUKernel::DoExecute(int task_id) {
47 auto input_x = reinterpret_cast<float *>(in_tensors_.at(C0NUM)->MutableData());
48 CHECK_NULL_RETURN(input_x);
49 auto input_y = reinterpret_cast<float *>(in_tensors_.at(C1NUM)->MutableData());
50 CHECK_NULL_RETURN(input_y);
51 auto dloss = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData());
52 CHECK_NULL_RETURN(dloss);
53 if (weight_defined_) {
54 weight_ = reinterpret_cast<float *>(in_tensors_.at(C3NUM)->MutableData());
55 CHECK_NULL_RETURN(weight_);
56 }
57 auto *out = reinterpret_cast<float *>(out_tensors_.at(C0NUM)->MutableData());
58 CHECK_NULL_RETURN(out);
59
60 auto param_ = reinterpret_cast<BinaryCrossEntropyGradParameter *>(op_parameter_);
61 int reduction = param_->reduction;
62 size_t input_size = in_tensors_.at(0)->ElementsNum();
63 BinaryCrossEntropyGrad(input_size, reduction, input_x, input_y, weight_, dloss, out, weight_defined_);
64 return RET_OK;
65 }
66
BinaryCrossEntropyGradRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)67 int BinaryCrossEntropyGradRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
68 CHECK_NULL_RETURN(cdata);
69 auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyGradCPUKernel *>(cdata);
70 auto error_code = bin_crs_ent_kernel->DoExecute(task_id);
71 if (error_code != RET_OK) {
72 MS_LOG(ERROR) << "BinaryCrossEntropyGrad error task_id[" << task_id << "] error_code[" << error_code << "]";
73 return RET_ERROR;
74 }
75 return RET_OK;
76 }
77
Run()78 int BinaryCrossEntropyGradCPUKernel::Run() {
79 int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyGradRun, this, 1);
80 if (error_code != RET_OK) {
81 MS_LOG(ERROR) << "BinaryCrossEntropyGrad function error error_code[" << error_code << "]";
82 return RET_ERROR;
83 }
84 return RET_OK;
85 }
86
Prepare()87 int BinaryCrossEntropyGradCPUKernel::Prepare() { return ReSize(); }
88
CpuBinaryCrossEntropyGradFp32KernelCreator(const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,OpParameter * opParameter,const lite::InnerContext * ctx,const kernel::KernelKey & desc)89 kernel::LiteKernel *CpuBinaryCrossEntropyGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
90 const std::vector<lite::Tensor *> &outputs,
91 OpParameter *opParameter, const lite::InnerContext *ctx,
92 const kernel::KernelKey &desc) {
93 MS_ASSERT(opParameter != nullptr);
94 MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropyGrad);
95 auto *kernel = new (std::nothrow) BinaryCrossEntropyGradCPUKernel(opParameter, inputs, outputs, ctx);
96 if (kernel == nullptr) {
97 MS_LOG(ERROR) << "new BinaryCrossEntropyGrad failed";
98 return nullptr;
99 }
100 return kernel;
101 }
102
// Registers CpuBinaryCrossEntropyGradFp32KernelCreator as the factory for BinaryCrossEntropyGrad on the CPU
// backend with float32 data.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropyGrad, CpuBinaryCrossEntropyGradFp32KernelCreator)
104 } // namespace mindspore::kernel
105