1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.h"
#include <cstdint>
#include <cstdlib>
#include <limits>
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/fp32_grad/binary_cross_entropy.h"
21
22 using mindspore::lite::KernelRegistrar;
23 using mindspore::lite::RET_ERROR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
26
27 namespace mindspore::kernel {
~BinaryCrossEntropyCPUKernel()28 BinaryCrossEntropyCPUKernel::~BinaryCrossEntropyCPUKernel() {
29 if (tmp_loss_ != nullptr) {
30 free(tmp_loss_);
31 tmp_loss_ = nullptr;
32 }
33 }
34
ReSize()35 int BinaryCrossEntropyCPUKernel::ReSize() {
36 CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
37 CHECK_LESS_RETURN(out_tensors_.size(), 1);
38 CHECK_NULL_RETURN(in_tensors_.at(0));
39 CHECK_NULL_RETURN(in_tensors_.at(1));
40 if (in_tensors_.size() == C3NUM) {
41 weight_defined_ = true;
42 CHECK_NULL_RETURN(in_tensors_.at(C2NUM));
43 }
44 CHECK_NULL_RETURN(out_tensors_.at(0));
45 CHECK_NULL_RETURN(op_parameter_);
46
47 auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
48 CHECK_NULL_RETURN(param_);
49 if (tmp_loss_ != nullptr) {
50 free(tmp_loss_);
51 tmp_loss_ = nullptr;
52 }
53 size_t input_size = in_tensors_.at(0)->ElementsNum();
54 tmp_loss_ = reinterpret_cast<float *>(malloc(input_size * sizeof(float)));
55 if (tmp_loss_ == nullptr) {
56 MS_LOG(ERROR) << "malloc tmp_loss_ for BinaryCrossEntropy op failed";
57 return RET_ERROR;
58 }
59
60 return RET_OK;
61 }
62
DoExecute(int task_id)63 int BinaryCrossEntropyCPUKernel::DoExecute(int task_id) {
64 auto logits = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
65 CHECK_NULL_RETURN(logits);
66 auto labels = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
67 CHECK_NULL_RETURN(labels);
68 auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
69 CHECK_NULL_RETURN(out);
70
71 auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
72 int reduction = param_->reduction;
73 size_t input_size = in_tensors_.at(0)->ElementsNum();
74 if (weight_defined_) {
75 weight_ = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData());
76 CHECK_NULL_RETURN(weight_);
77 }
78 BinaryCrossEntropy(input_size, reduction, logits, labels, weight_, out, tmp_loss_, weight_defined_);
79 return RET_OK;
80 }
81
BinaryCrossEntropyRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)82 int BinaryCrossEntropyRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
83 CHECK_NULL_RETURN(cdata);
84 auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyCPUKernel *>(cdata);
85 auto error_code = bin_crs_ent_kernel->DoExecute(task_id);
86 if (error_code != RET_OK) {
87 MS_LOG(ERROR) << "BinaryCrossEntropy error task_id[" << task_id << "] error_code[" << error_code << "]";
88 return RET_ERROR;
89 }
90 return RET_OK;
91 }
92
Run()93 int BinaryCrossEntropyCPUKernel::Run() {
94 int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyRun, this, 1);
95 if (error_code != RET_OK) {
96 MS_LOG(ERROR) << "SigmoidCrossEntropyWithLogits function error error_code[" << error_code << "]";
97 return RET_ERROR;
98 }
99 return RET_OK;
100 }
101
Prepare()102 int BinaryCrossEntropyCPUKernel::Prepare() { return ReSize(); }
103
CpuBinaryCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,OpParameter * opParameter,const lite::InnerContext * ctx,const kernel::KernelKey & desc)104 kernel::LiteKernel *CpuBinaryCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
105 const std::vector<lite::Tensor *> &outputs,
106 OpParameter *opParameter, const lite::InnerContext *ctx,
107 const kernel::KernelKey &desc) {
108 MS_ASSERT(opParameter != nullptr);
109 MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropy);
110 auto *kernel = new (std::nothrow) BinaryCrossEntropyCPUKernel(opParameter, inputs, outputs, ctx);
111 if (kernel == nullptr) {
112 MS_LOG(ERROR) << "new SigmoidCrossEntropyWithLogits failed";
113 return nullptr;
114 }
115 return kernel;
116 }
117
118 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropy, CpuBinaryCrossEntropyFp32KernelCreator)
119 } // namespace mindspore::kernel
120