1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "nnacl/kernel/nllloss.h"
#include <string.h>
#include "nnacl/kernel/default_kernel_base.h"
#include "nnacl/fp32/nllloss_fp32.h"
#include "nnacl/nllloss_parameter.h"
21
/**
 * Runs the fp32 NLLLoss computation for one invocation of the kernel.
 *
 * Inputs:  in_[0] logits (float data), in_[1] labels (int data), in_[2] weight (float data).
 * Outputs: out_[0] loss (float data), out_[1] total_weight (float data).
 * The reduction mode is read from the NLLLossParameter attached to the kernel.
 *
 * @param self kernel instance; expected to be the NLLLossStruct built by CreateNLLLoss.
 * @return NNACL_OK on success; otherwise a non-NNACL_OK error code, either from a failed
 *         pointer check here or from NLLLoss itself.
 */
int NlllossCompute(KernelBase *self) {
  NLLLossStruct *nllloss = (NLLLossStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(nllloss);

  /* Check the tensor handles themselves before touching their data_ fields. */
  TensorC *logits_tensor = self->in_[Index0];
  NNACL_CHECK_NULL_RETURN_ERR(logits_tensor);
  TensorC *labels_tensor = self->in_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(labels_tensor);
  TensorC *weight_tensor = self->in_[Index2];
  NNACL_CHECK_NULL_RETURN_ERR(weight_tensor);
  TensorC *loss_tensor = self->out_[Index0];
  NNACL_CHECK_NULL_RETURN_ERR(loss_tensor);
  TensorC *total_weight_tensor = self->out_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(total_weight_tensor);

  float *logits = logits_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(logits);
  int *labels = labels_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(labels);
  float *weight = weight_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(weight);
  float *loss = loss_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(loss);
  float *total_weight = total_weight_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(total_weight);

  /* The operator parameter carries the reduction mode; it must be attached. */
  NLLLossParameter *param = (NLLLossParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);

  return NLLLoss(logits, labels, weight, loss, total_weight, nllloss, param->reduction_type_);
}
40
/**
 * Validates the tensor counts and caches the logits dimensions.
 *
 * Requires at least three inputs (logits, labels, weight) and two outputs
 * (loss, total_weight). When the logits tensor is at least 2-D, batch_ and
 * class_num_ are taken from its first two dimensions. Otherwise they are
 * zeroed rather than read from shape_ entries beyond shape_size_, which carry
 * no meaningful value for this tensor.
 *
 * @param self kernel instance; expected to be the NLLLossStruct built by CreateNLLLoss.
 * @return NNACL_OK on success; NNACL_ERR (or the null-check error code) otherwise.
 */
int NlllossPrepare(KernelBase *self) {
  NLLLossStruct *nllloss = (NLLLossStruct *)self;
  /* Null-check before the first dereference of self (in_size_ below). */
  NNACL_CHECK_NULL_RETURN_ERR(nllloss);
  NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR);

  TensorC *logits_tensor = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(logits_tensor);

  if (logits_tensor->shape_size_ < DIMENSION_2D) {
    /* Shape is not (yet) 2-D: do not read past shape_size_. */
    nllloss->batch_ = 0;
    nllloss->class_num_ = 0;
    return NNACL_OK;
  }
  nllloss->batch_ = logits_tensor->shape_[Index0];
  nllloss->class_num_ = logits_tensor->shape_[Index1];
  return NNACL_OK;
}
52
CreateNLLLoss(OpParameter * param,int data_type)53 KernelBase *CreateNLLLoss(OpParameter *param, int data_type) {
54 NLLLossStruct *nllloss = (NLLLossStruct *)malloc(sizeof(NLLLossStruct));
55 NNACL_CHECK_NULL_RETURN_NULL(nllloss);
56 nllloss->base_.Release = DefaultRelease;
57 nllloss->base_.Prepare = NlllossPrepare;
58 nllloss->base_.Resize = DefaultResize;
59 nllloss->base_.Compute = NlllossCompute;
60 return (KernelBase *)nllloss;
61 }
62
63 REG_KERNEL_CREATOR(PrimType_NLLLoss, kNumberTypeFloat32, CreateNLLLoss)
64