1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "src/litert/kernel/cpu/fp32_grad/sigmoid_cross_entropy_with_logits.h"
#include <cmath>
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
20
21 using mindspore::lite::KernelRegistrar;
22 using mindspore::lite::RET_ERROR;
23 using mindspore::lite::RET_OK;
24 using mindspore::schema::PrimitiveType_SigmoidCrossEntropyWithLogits;
25
26 namespace mindspore::kernel {
/**
 * Validates the kernel's wiring before execution.
 *
 * Requirements checked:
 *  - the OpParameter is set;
 *  - there are at least two inputs (DoExecute reads index 0 as logits and index 1 as labels);
 *  - there is at least one output;
 *  - none of those three tensors is null.
 *
 * No shape-dependent state is computed here.
 *
 * @return RET_OK when every check passes; otherwise the failing CHECK_* macro returns early with its
 *         error code.
 */
int SigmoidCrossEntropyWithLogitsCPUKernel::ReSize() {
  CHECK_NULL_RETURN(op_parameter_);
  // Arity checks precede the .at() lookups below.
  CHECK_LESS_RETURN(in_tensors_.size(), 2);
  CHECK_LESS_RETURN(out_tensors_.size(), 1);
  CHECK_NULL_RETURN(in_tensors_.at(0));
  CHECK_NULL_RETURN(in_tensors_.at(1));
  CHECK_NULL_RETURN(out_tensors_.at(0));
  return RET_OK;
}
36
DoExecute(int task_id)37 int SigmoidCrossEntropyWithLogitsCPUKernel::DoExecute(int task_id) {
38 auto logits = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
39 CHECK_NULL_RETURN(logits);
40 auto labels = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
41 CHECK_NULL_RETURN(labels);
42 auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
43 CHECK_NULL_RETURN(out);
44 const size_t tensor_len = in_tensors_.at(0)->ElementsNum();
45
46 const float zero = 0.0f;
47 const float one = 1.0f;
48 const float two = 2.0f;
49
50 for (uint64_t i = 0; i < tensor_len; ++i) {
51 if (logits[i] >= zero) {
52 out[i] = log1pf(exp(logits[i] - two * logits[i])) - logits[i] * (labels[i] - one);
53 } else {
54 out[i] = log1pf(exp(logits[i])) - logits[i] * labels[i];
55 }
56 }
57
58 return RET_OK;
59 }
60
SigmoidCrossEntropyWithLogitsRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)61 int SigmoidCrossEntropyWithLogitsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
62 CHECK_NULL_RETURN(cdata);
63 auto sig_crs_ent_kernel = reinterpret_cast<SigmoidCrossEntropyWithLogitsCPUKernel *>(cdata);
64 auto error_code = sig_crs_ent_kernel->DoExecute(task_id);
65 if (error_code != RET_OK) {
66 MS_LOG(ERROR) << "SigmoidCrossEntropyWithLogits error task_id[" << task_id << "] error_code[" << error_code << "]";
67 return RET_ERROR;
68 }
69 return RET_OK;
70 }
71
Run()72 int SigmoidCrossEntropyWithLogitsCPUKernel::Run() {
73 int error_code = ParallelLaunch(this->ms_context_, SigmoidCrossEntropyWithLogitsRun, this, 1);
74 if (error_code != RET_OK) {
75 MS_LOG(ERROR) << "SigmoidCrossEntropyWithLogits function error error_code[" << error_code << "]";
76 return RET_ERROR;
77 }
78 return RET_OK;
79 }
80
Prepare()81 int SigmoidCrossEntropyWithLogitsCPUKernel::Prepare() { return RET_OK; }
82
CpuSigmoidCrossEntropyWithLogitsFp32KernelCreator(const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,OpParameter * opParameter,const lite::InnerContext * ctx,const kernel::KernelKey & desc)83 kernel::LiteKernel *CpuSigmoidCrossEntropyWithLogitsFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
84 const std::vector<lite::Tensor *> &outputs,
85 OpParameter *opParameter,
86 const lite::InnerContext *ctx,
87 const kernel::KernelKey &desc) {
88 MS_ASSERT(opParameter != nullptr);
89 MS_ASSERT(desc.type == schema::PrimitiveType_SigmoidCrossEntropyWithLogits);
90 auto *kernel = new (std::nothrow) SigmoidCrossEntropyWithLogitsCPUKernel(opParameter, inputs, outputs, ctx);
91 if (kernel == nullptr) {
92 MS_LOG(ERROR) << "new SigmoidCrossEntropyWithLogits failed";
93 return nullptr;
94 }
95 return kernel;
96 }
97
// Register the fp32 creator above for PrimitiveType_SigmoidCrossEntropyWithLogits on the CPU backend.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SigmoidCrossEntropyWithLogits,
           CpuSigmoidCrossEntropyWithLogitsFp32KernelCreator)
100 } // namespace mindspore::kernel
101