• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.h"
18 #include "src/litert/kernel_registry.h"
19 #include "include/errorcode.h"
20 #include "nnacl/fp32_grad/binary_cross_entropy.h"
21 
22 using mindspore::lite::KernelRegistrar;
23 using mindspore::lite::RET_ERROR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
26 
27 namespace mindspore::kernel {
~BinaryCrossEntropyCPUKernel()28 BinaryCrossEntropyCPUKernel::~BinaryCrossEntropyCPUKernel() {
29   if (tmp_loss_ != nullptr) {
30     free(tmp_loss_);
31     tmp_loss_ = nullptr;
32   }
33 }
34 
ReSize()35 int BinaryCrossEntropyCPUKernel::ReSize() {
36   CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
37   CHECK_LESS_RETURN(out_tensors_.size(), 1);
38   CHECK_NULL_RETURN(in_tensors_.at(0));
39   CHECK_NULL_RETURN(in_tensors_.at(1));
40   if (in_tensors_.size() == C3NUM) {
41     weight_defined_ = true;
42     CHECK_NULL_RETURN(in_tensors_.at(C2NUM));
43   }
44   CHECK_NULL_RETURN(out_tensors_.at(0));
45   CHECK_NULL_RETURN(op_parameter_);
46 
47   auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
48   CHECK_NULL_RETURN(param_);
49   if (tmp_loss_ != nullptr) {
50     free(tmp_loss_);
51     tmp_loss_ = nullptr;
52   }
53   size_t input_size = in_tensors_.at(0)->ElementsNum();
54   tmp_loss_ = reinterpret_cast<float *>(malloc(input_size * sizeof(float)));
55   if (tmp_loss_ == nullptr) {
56     MS_LOG(ERROR) << "malloc tmp_loss_ for BinaryCrossEntropy op failed";
57     return RET_ERROR;
58   }
59 
60   return RET_OK;
61 }
62 
DoExecute(int task_id)63 int BinaryCrossEntropyCPUKernel::DoExecute(int task_id) {
64   auto logits = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
65   CHECK_NULL_RETURN(logits);
66   auto labels = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
67   CHECK_NULL_RETURN(labels);
68   auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
69   CHECK_NULL_RETURN(out);
70 
71   auto param_ = reinterpret_cast<BinaryCrossEntropyParameter *>(op_parameter_);
72   int reduction = param_->reduction;
73   size_t input_size = in_tensors_.at(0)->ElementsNum();
74   if (weight_defined_) {
75     weight_ = reinterpret_cast<float *>(in_tensors_.at(C2NUM)->MutableData());
76     CHECK_NULL_RETURN(weight_);
77   }
78   BinaryCrossEntropy(input_size, reduction, logits, labels, weight_, out, tmp_loss_, weight_defined_);
79   return RET_OK;
80 }
81 
BinaryCrossEntropyRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)82 int BinaryCrossEntropyRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
83   CHECK_NULL_RETURN(cdata);
84   auto bin_crs_ent_kernel = reinterpret_cast<BinaryCrossEntropyCPUKernel *>(cdata);
85   auto error_code = bin_crs_ent_kernel->DoExecute(task_id);
86   if (error_code != RET_OK) {
87     MS_LOG(ERROR) << "BinaryCrossEntropy error task_id[" << task_id << "] error_code[" << error_code << "]";
88     return RET_ERROR;
89   }
90   return RET_OK;
91 }
92 
Run()93 int BinaryCrossEntropyCPUKernel::Run() {
94   int error_code = ParallelLaunch(this->ms_context_, BinaryCrossEntropyRun, this, 1);
95   if (error_code != RET_OK) {
96     MS_LOG(ERROR) << "SigmoidCrossEntropyWithLogits function error error_code[" << error_code << "]";
97     return RET_ERROR;
98   }
99   return RET_OK;
100 }
101 
Prepare()102 int BinaryCrossEntropyCPUKernel::Prepare() { return ReSize(); }
103 
CpuBinaryCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,OpParameter * opParameter,const lite::InnerContext * ctx,const kernel::KernelKey & desc)104 kernel::LiteKernel *CpuBinaryCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
105                                                            const std::vector<lite::Tensor *> &outputs,
106                                                            OpParameter *opParameter, const lite::InnerContext *ctx,
107                                                            const kernel::KernelKey &desc) {
108   MS_ASSERT(opParameter != nullptr);
109   MS_ASSERT(desc.type == schema::PrimitiveType_BinaryCrossEntropy);
110   auto *kernel = new (std::nothrow) BinaryCrossEntropyCPUKernel(opParameter, inputs, outputs, ctx);
111   if (kernel == nullptr) {
112     MS_LOG(ERROR) << "new SigmoidCrossEntropyWithLogits failed";
113     return nullptr;
114   }
115   return kernel;
116 }
117 
// Passes CpuBinaryCrossEntropyFp32KernelCreator to REG_KERNEL together with the key
// (kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropy). REG_KERNEL is defined
// outside this file; presumably it registers the creator with the kernel registry.
118 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BinaryCrossEntropy, CpuBinaryCrossEntropyFp32KernelCreator)
119 }  // namespace mindspore::kernel
120